458 lines
15 KiB
Python
458 lines
15 KiB
Python
import multiprocessing
|
||
import os
|
||
import socket
|
||
import sys
|
||
from collections.abc import Callable
|
||
from functools import wraps
|
||
from typing import Any
|
||
|
||
from .._futures import _future_watcher_wrapper, _new_cbscheduler
|
||
from .._granian import (
|
||
ASGIWorker,
|
||
IPCReceiverHandle,
|
||
IPCSenderHandle,
|
||
ProcInfoCollector,
|
||
RSGIWorker,
|
||
SocketHolder,
|
||
WorkerSignal,
|
||
WSGIWorker,
|
||
)
|
||
from .._internal import load_env
|
||
from .._types import SSLCtx
|
||
from ..asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
|
||
from ..rsgi import _callback_wrapper as _rsgi_call_wrap, _callbacks_from_target as _rsgi_cbs_from_target
|
||
from ..wsgi import _callback_wrapper as _wsgi_call_wrap
|
||
from .common import (
|
||
WORKERS_METHODS,
|
||
AbstractServer,
|
||
AbstractWorker,
|
||
HTTP1Settings,
|
||
HTTP2Settings,
|
||
HTTPModes,
|
||
Interfaces,
|
||
RuntimeModes,
|
||
TaskImpl,
|
||
configure_logging,
|
||
logger,
|
||
setproctitle,
|
||
)
|
||
|
||
|
||
# Allow Connection/socket objects to be pickled, so the shared listener
# socket and IPC pipe ends can be passed to spawned worker processes.
multiprocessing.allow_connection_pickling()
|
||
|
||
|
||
class WorkerProcess(AbstractWorker):
    """A server worker backed by a `multiprocessing.Process`."""

    _idl = 'PID'

    def __init__(self, parent, idx, target, args):
        # NOTE: Python 3.14 defaults mp spawn method to 'forkserver' on Linux,
        # which doesn't really play well with shared sockets.
        method = multiprocessing.get_start_method()
        self._spawn_method = method if method in ('fork', 'spawn') else 'spawn'
        super().__init__(parent, idx, target, args)

    @staticmethod
    def wrap_target(target):
        """Decorate `target` with the per-process bootstrap.

        The returned callable runs inside the spawned worker process and
        performs setup (process title, logging, env files, socket/IPC
        handles, loop selection) before delegating to `target`.
        """

        @wraps(target)
        def wrapped(
            worker_id,
            process_name,
            callback_loader,
            sock,
            ipc,
            loop_impl,
            log_enabled,
            log_level,
            log_config,
            env_files,
            *args,
            **kwargs,
        ):
            from granian._loops import loops

            if process_name:
                setproctitle.setproctitle(f'{process_name} worker-{worker_id}')

            configure_logging(log_level, log_config, log_enabled)
            load_env(env_files)

            sock, inherited_sock = sock
            ipc_handle = None
            if sys.platform == 'win32':
                # On Windows the socket holder is rebuilt from the inherited
                # socket object's file descriptor.
                sock = SocketHolder(inherited_sock.fileno())
            elif ipc:
                # Duplicate the pipe fd and make it non-blocking before
                # handing it to the native IPC sender.
                ipc_fd = os.dup(ipc.fileno())
                os.set_blocking(ipc_fd, False)
                ipc_handle = IPCSenderHandle(ipc_fd)

            event_loop = loops.get(loop_impl)
            callback = callback_loader()
            return target(worker_id, callback, sock, ipc_handle, event_loop, *args, **kwargs)

        return wrapped

    def _spawn(self, target, args):
        # Build the child through an explicit context so the start method
        # picked in `__init__` is honored regardless of the global default.
        ctx = multiprocessing.get_context(self._spawn_method)
        self.inner = ctx.Process(name='granian-worker', target=target, args=args)

    def _id(self):
        # The OS pid of the child process (matches the `_idl` label above).
        return self.inner.pid

    def terminate(self):
        self.interrupt_by_parent = True
        self.inner.terminate()

    def kill(self):
        self.interrupt_by_parent = True
        self.inner.kill()
|
||
|
||
|
||
class MPServer(AbstractServer[WorkerProcess]):
    """Multi-process server implementation.

    Each worker runs in its own OS process (see `WorkerProcess`), sharing
    the parent's listening socket. Per-worker pipes carry metrics back to
    the parent (IPC), except on Windows where IPC is disabled entirely.
    """

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_asgi_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        # NOTE(review): 4-tuple here vs 3-tuple in the other spawners below —
        # looks inconsistent; confirm against the native worker signature.
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker entrypoint for ASGI applications without lifespan handling."""
        from granian._signals import set_loop_signals

        wcallback = _future_watcher_wrapper(_asgi_call_wrap(callback, scope_opts, {}, log_access_fmt))
        shutdown_event = set_loop_signals(loop)

        worker = ASGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            metrics,
        )
        # Pick the serve method matching the runtime mode and socket kind (TCP vs UDS).
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_asgi_lifespan_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker entrypoint for ASGI applications with lifespan handling.

        Runs the lifespan startup before serving (exiting the process on
        failure) and the lifespan shutdown after the serve loop returns.
        """
        from granian._signals import set_loop_signals

        lifespan_handler = LifespanProtocol(callback)
        wcallback = _future_watcher_wrapper(
            _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt)
        )
        shutdown_event = set_loop_signals(loop)

        loop.run_until_complete(lifespan_handler.startup())
        if lifespan_handler.interrupt:
            logger.error('ASGI lifespan startup failed', exc_info=lifespan_handler.exc)
            sys.exit(1)

        worker = ASGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            metrics,
        )
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)
        loop.run_until_complete(lifespan_handler.shutdown())

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_rsgi_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker entrypoint for RSGI applications.

        Invokes the target's init callback before serving and its
        teardown callback after the serve loop returns.
        """
        from granian._signals import set_loop_signals

        callback, callback_init, callback_del = _rsgi_cbs_from_target(callback)
        wcallback = _future_watcher_wrapper(_rsgi_call_wrap(callback, log_access_fmt))
        shutdown_event = set_loop_signals(loop)
        callback_init(loop)

        worker = RSGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            metrics,
        )
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)
        callback_del(loop)

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_wsgi_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        # NOTE(review): accepted for signature parity with the async spawners,
        # but not forwarded to WSGIWorker (WSGI has no websocket support).
        websockets: bool,
        static_path: tuple[str, str, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker entrypoint for WSGI applications (sync signal handling)."""
        from granian._signals import set_sync_signals

        wcallback = _wsgi_call_wrap(callback, scope_opts, log_access_fmt)
        shutdown_event = set_sync_signals()

        worker = WSGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            static_path,
            *ssl_ctx,
            metrics,
        )
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)

    def _init_shared_socket(self):
        """Create the shared listener and keep an inheritable Python socket for it."""
        super()._init_shared_socket()
        sock = socket.socket(fileno=self._sfd)
        # Children must inherit the fd across spawn.
        sock.set_inheritable(True)
        self._sso = sock

    def _write_pidfile(self):
        super()._write_pidfile()
        # Collector used by `_handle_rss_signal` to sample worker memory usage.
        self._rss_collector = ProcInfoCollector()

    def _unlink_pidfile(self):
        # Detach (don't close) so the underlying fd outlives this wrapper.
        self._sso.detach()
        super()._unlink_pidfile()

    def _start_ipc(self):
        """Set up per-worker IPC pipes (metrics channel) and their receivers."""
        self._ipc = {}
        self._ipc_sig = WorkerSignal()

        # NOTE: for "reasons", on Windows the call to `os.set_blocking` fails with an OS err 9.
        # I'm not sure what the problem is, thus on Windows we just disable
        # IPC – and, consequentially, metrics – entirely; or, at least,
        # until until someone finds out a way to make that call work.
        # And, to quote John Malkovich: "fuck Microsoft!"
        # (https://www.youtube.com/watch?v=2zpCOYkdvTQ)
        if sys.platform == 'win32':
            for idx in range(self.workers):
                self._ipc[idx] = (None, None)
            return

        for idx in range(self.workers):
            rx, tx = multiprocessing.Pipe(False)
            rxd = os.dup(rx.fileno())
            # WARN: on Windows, `os.set_blocking` is available only on Py >= 3.12.
            # Doesn't really matter, given the call fails when available U.U
            os.set_blocking(rxd, False)
            self._ipc[idx] = (IPCReceiverHandle(idx, rxd), tx, rx)

        # NOTE: given we use IPC only for metrics right now, let's run the receivers
        # only if metrics collection is actually enabled.
        if self.metrics_enabled:
            for pipe in self._ipc.values():
                pipe[0].run(self._ipc_sig, self._metrics)

    def _stop_ipc(self):
        # Signal all IPC receivers to stop.
        self._ipc_sig.set()

    def _handle_rss_signal(self, spawn_target, target_loader):
        """Sample workers' RSS and gracefully respawn those over threshold.

        A worker is respawned only after `rss_samples` consecutive
        over-threshold samples; counters reset for workers back under it.
        """
        wpids = {wrk._id(): wrk for wrk in self.wrks}
        try:
            rss_data = self._rss_collector.memory(list(wpids.keys()))
        except Exception:
            logger.warning('Unable to collect resource usage for workers')
            return
        logger.debug(f'Collected resource usages for workers: {rss_data}')
        cycle_samples = {}
        to_restart = []
        for wpid, wmem in rss_data.items():
            if wmem >= self.workers_rss:
                samples = self._rss_wrk_samples.get(wpid, 0) + 1
                if samples >= self.rss_samples:
                    wrk = wpids[wpid]
                    logger.info(f'worker-{wrk.idx + 1} RSS over threshold, gracefully respawning..')
                    to_restart.append(wrk.idx)
                else:
                    cycle_samples[wpid] = samples
            else:
                cycle_samples[wpid] = 0
        # Rebuild the sample map so stale pids (dead workers) drop out.
        self._rss_wrk_samples.clear()
        self._rss_wrk_samples.update(cycle_samples)
        if to_restart:
            self._respawn_workers(to_restart, spawn_target, target_loader, delay=self.respawn_interval)
            self._metrics.incr_respawn_rss(len(to_restart))

    def _spawn_worker(self, idx, target, callback_loader) -> WorkerProcess:
        """Build the `WorkerProcess` for slot `idx` with the full arg tuple."""
        return WorkerProcess(
            parent=self,
            idx=idx,
            target=target,
            args=(
                idx + 1,
                self.process_name,
                callback_loader,
                (self._shd, self._sso),
                # NOTE: given we use IPC only for metrics right now, let's share the pipe
                # only if metrics collection is actually enabled.
                self._ipc[idx][1] if self.metrics_enabled else None,
                self.loop,
                self.log_enabled,
                self.log_level,
                self.log_config,
                self.env_files,
                self.runtime_mode,
                self.runtime_threads,
                self.runtime_blocking_threads,
                self.blocking_threads,
                self.blocking_threads_idle_timeout,
                self.backpressure,
                self.task_impl,
                self.http,
                self.http1_settings,
                self.http2_settings,
                self.websockets,
                self.static_path,
                self.log_access_format if self.log_access else None,
                self.ssl_ctx,
                {'url_path_prefix': self.url_path_prefix},
                (self.metrics_scrape_interval if self.metrics_enabled else None, None),
            ),
        )

    def serve(
        self,
        spawn_target: Callable[..., None] | None = None,
        target_loader: Callable[..., Callable[..., Any]] | None = None,
        wrap_loader: bool = True,
    ):
        """Run the server, emitting configuration warnings first."""
        if self.interface == Interfaces.WSGI:
            if self.blocking_threads > (multiprocessing.cpu_count() * 2 + 1):
                logger.warning(
                    f'Configuration allows spawning up to {self.blocking_threads} Python threads, '
                    'which seems quite high compared to the number of CPU cores available. '
                    'Consider reviewing your configuration and using `backpressure` to limit '
                    'the concurrency on the Python interpreter. '
                    'If this configuration is intentional, you can safely ignore this message.'
                )

        if self.metrics_enabled and sys.platform == 'win32':
            self.metrics_enabled = False
            # FIX: `Logger.warn` is a deprecated alias of `warning` (raises
            # DeprecationWarning on modern Python) — use `warning` directly.
            logger.warning('Metrics are not available in Windows, ignoring.')

        super().serve(spawn_target, target_loader, wrap_loader)
|