import multiprocessing
import os
import socket
import sys
from collections.abc import Callable
from functools import wraps
from typing import Any

from .._futures import _future_watcher_wrapper, _new_cbscheduler
from .._granian import (
    ASGIWorker,
    IPCReceiverHandle,
    IPCSenderHandle,
    ProcInfoCollector,
    RSGIWorker,
    SocketHolder,
    WorkerSignal,
    WSGIWorker,
)
from .._internal import load_env
from .._types import SSLCtx
from ..asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
from ..rsgi import _callback_wrapper as _rsgi_call_wrap, _callbacks_from_target as _rsgi_cbs_from_target
from ..wsgi import _callback_wrapper as _wsgi_call_wrap
from .common import (
    WORKERS_METHODS,
    AbstractServer,
    AbstractWorker,
    HTTP1Settings,
    HTTP2Settings,
    HTTPModes,
    Interfaces,
    RuntimeModes,
    TaskImpl,
    configure_logging,
    logger,
    setproctitle,
)


multiprocessing.allow_connection_pickling()


class WorkerProcess(AbstractWorker):
    """Worker backed by a `multiprocessing.Process`, identified by its PID."""

    _idl = 'PID'

    def __init__(self, parent, idx, target, args):
        # NOTE: Python 3.14 defaults mp spawn method to 'forkserver' on Linux,
        #       which doesn't really play well with shared sockets.
        #       Anything other than 'fork' or 'spawn' falls back to 'spawn'.
        self._spawn_method = multiprocessing.get_start_method()
        if self._spawn_method not in {'fork', 'spawn'}:
            self._spawn_method = 'spawn'
        super().__init__(parent, idx, target, args)

    @staticmethod
    def wrap_target(target):
        """Wrap a worker entrypoint with the per-process bootstrap steps.

        The returned callable:
        - sets the process title (when a process name is given);
        - configures logging and loads the env files;
        - unpacks the `(socket_holder, socket_object)` pair it receives;
        - prepares a non-blocking IPC sender handle from the given pipe end (non-Windows only);
        - resolves the event loop implementation and loads the application callback.

        It then calls
        `target(worker_id, callback, sock, ipc_handle, loop, *args, **kwargs)`.
        """

        @wraps(target)
        def wrapped(
            worker_id,
            process_name,
            callback_loader,
            sock,
            ipc,
            loop_impl,
            log_enabled,
            log_level,
            log_config,
            env_files,
            *args,
            **kwargs,
        ):
            from granian._loops import loops

            if process_name:
                setproctitle.setproctitle(f'{process_name} worker-{worker_id}')
            configure_logging(log_level, log_config, log_enabled)
            load_env(env_files)

            _ipc_handle = None
            sock, _sso = sock
            if sys.platform == 'win32':
                # On Windows the holder is rebuilt from the inherited socket object's fd;
                # IPC is never set up on this platform.
                sock = SocketHolder(_sso.fileno())
            elif ipc:
                # Work on a duplicated, non-blocking fd rather than the pipe end itself.
                _ipc_fd = os.dup(ipc.fileno())
                os.set_blocking(_ipc_fd, False)
                _ipc_handle = IPCSenderHandle(_ipc_fd)

            loop = loops.get(loop_impl)
            callback = callback_loader()
            return target(worker_id, callback, sock, _ipc_handle, loop, *args, **kwargs)

        return wrapped

    def _spawn(self, target, args):
        # Create (but do not start here) the worker process using the selected start method.
        self.inner = multiprocessing.get_context(self._spawn_method).Process(
            name='granian-worker', target=target, args=args
        )

    def _id(self):
        return self.inner.pid

    def terminate(self):
        # Flag the interruption as parent-initiated before signalling the process.
        self.interrupt_by_parent = True
        self.inner.terminate()

    def kill(self):
        self.interrupt_by_parent = True
        self.inner.kill()


class MPServer(AbstractServer[WorkerProcess]):
    """Server that runs each worker as a separate `multiprocessing` process."""

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_asgi_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker process entrypoint serving an ASGI application (no lifespan)."""
        from granian._signals import set_loop_signals

        # No lifespan here, so the ASGI state passed to the wrapper is an empty dict.
        wcallback = _future_watcher_wrapper(_asgi_call_wrap(callback, scope_opts, {}, log_access_fmt))
        shutdown_event = set_loop_signals(loop)

        worker = ASGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            metrics,
        )
        # The serve method is picked by runtime mode and by whether the socket is a UDS.
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_asgi_lifespan_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker process entrypoint serving an ASGI application with lifespan support.

        The lifespan startup runs before serving. If it is interrupted, the
        process exits with status 1. The lifespan shutdown runs once
        `serve` returns.
        """
        from granian._signals import set_loop_signals

        lifespan_handler = LifespanProtocol(callback)
        wcallback = _future_watcher_wrapper(
            _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt)
        )
        shutdown_event = set_loop_signals(loop)

        loop.run_until_complete(lifespan_handler.startup())
        if lifespan_handler.interrupt:
            logger.error('ASGI lifespan startup failed', exc_info=lifespan_handler.exc)
            sys.exit(1)

        worker = ASGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            metrics,
        )
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)
        loop.run_until_complete(lifespan_handler.shutdown())

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_rsgi_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker process entrypoint serving an RSGI application.

        The target is split into the request callback plus init/del hooks.
        Both hooks are called with the loop: init before serving, del after
        `serve` returns.
        """
        from granian._signals import set_loop_signals

        callback, callback_init, callback_del = _rsgi_cbs_from_target(callback)
        wcallback = _future_watcher_wrapper(_rsgi_call_wrap(callback, log_access_fmt))
        shutdown_event = set_loop_signals(loop)
        callback_init(loop)

        worker = RSGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            metrics,
        )
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)
        callback_del(loop)

    @staticmethod
    @WorkerProcess.wrap_target
    def _spawn_wsgi_worker(
        worker_id: int,
        callback: Any,
        sock: Any,
        ipc: Any,
        loop: Any,
        runtime_mode: RuntimeModes,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
        metrics: Any,
    ):
        """Worker process entrypoint serving a WSGI application.

        Signals are handled via `set_sync_signals()` instead of the loop.
        The `websockets` flag is accepted for signature parity but is not
        passed to `WSGIWorker`.
        """
        from granian._signals import set_sync_signals

        wcallback = _wsgi_call_wrap(callback, scope_opts, log_access_fmt)
        shutdown_event = set_sync_signals()

        worker = WSGIWorker(
            worker_id,
            sock,
            ipc,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            static_path,
            *ssl_ctx,
            metrics,
        )
        serve = getattr(worker, WORKERS_METHODS[runtime_mode][sock.is_uds()])
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        serve(scheduler, loop, shutdown_event)

    def _init_shared_socket(self):
        super()._init_shared_socket()
        # Wrap the shared fd in a Python socket object and mark it inheritable for child processes.
        sock = socket.socket(fileno=self._sfd)
        sock.set_inheritable(True)
        self._sso = sock

    def _write_pidfile(self):
        super()._write_pidfile()
        self._rss_collector = ProcInfoCollector()

    def _unlink_pidfile(self):
        # `detach()` releases the fd from the socket object without closing it.
        self._sso.detach()
        super()._unlink_pidfile()

    def _start_ipc(self):
        """Create one IPC pipe per worker and, when metrics are enabled, start the receivers.

        `self._ipc` maps the worker index to a tuple. On Windows it is a
        `(None, None)` placeholder. Elsewhere it is
        `(IPCReceiverHandle, sender connection, receiver connection)`.
        """
        self._ipc = {}
        self._ipc_sig = WorkerSignal()

        # NOTE: for "reasons", on Windows the call to `os.set_blocking` fails with an OS err 9.
        #       I'm not sure what the problem is, thus on Windows we just disable
        #       IPC – and, consequentially, metrics – entirely; or, at least,
        #       until someone finds out a way to make that call work.
        #       And, to quote John Malkovich: "fuck Microsoft!"
        #       (https://www.youtube.com/watch?v=2zpCOYkdvTQ)
        if sys.platform == 'win32':
            for idx in range(self.workers):
                self._ipc[idx] = (None, None)
            return

        for idx in range(self.workers):
            rx, tx = multiprocessing.Pipe(False)
            rxd = os.dup(rx.fileno())
            # WARN: on Windows, `os.set_blocking` is available only on Py >= 3.12.
            #       Doesn't really matter, given the call fails when available U.U
            os.set_blocking(rxd, False)
            # `rx` is stored too, presumably to keep the original receiving end referenced.
            self._ipc[idx] = (IPCReceiverHandle(idx, rxd), tx, rx)

        # NOTE: given we use IPC only for metrics right now, let's run the receivers
        #       only if metrics collection is actually enabled.
        if self.metrics_enabled:
            for pipe in self._ipc.values():
                pipe[0].run(self._ipc_sig, self._metrics)

    def _stop_ipc(self):
        self._ipc_sig.set()

    def _handle_rss_signal(self, spawn_target, target_loader):
        """Sample workers' memory usage and gracefully respawn those over the threshold.

        A worker is respawned once its usage has been at or above
        `self.workers_rss` for `self.rss_samples` consecutive samplings.
        Per-PID counters live in `self._rss_wrk_samples`. A counter resets
        to 0 when usage drops below the threshold, and is dropped when the
        worker is scheduled for restart.
        """
        wpids = {wrk._id(): wrk for wrk in self.wrks}
        try:
            rss_data = self._rss_collector.memory(list(wpids))
        except Exception:
            # Best effort: a failed sampling round is logged and skipped.
            logger.warning('Unable to collect resource usage for workers')
            return

        logger.debug(f'Collected resource usages for workers: {rss_data}')
        cycle_samples = {}
        to_restart = []
        for wpid, wmem in rss_data.items():
            if wmem >= self.workers_rss:
                samples = self._rss_wrk_samples.get(wpid, 0) + 1
                if samples >= self.rss_samples:
                    wrk = wpids[wpid]
                    logger.info(f'worker-{wrk.idx + 1} RSS over threshold, gracefully respawning..')
                    to_restart.append(wrk.idx)
                else:
                    cycle_samples[wpid] = samples
            else:
                cycle_samples[wpid] = 0

        self._rss_wrk_samples.clear()
        self._rss_wrk_samples.update(cycle_samples)

        if to_restart:
            self._respawn_workers(to_restart, spawn_target, target_loader, delay=self.respawn_interval)
            self._metrics.incr_respawn_rss(len(to_restart))

    def _spawn_worker(self, idx, target, callback_loader) -> WorkerProcess:
        # The argument order must match `WorkerProcess.wrap_target.wrapped`'s positional
        # parameters, followed by the `_spawn_*_worker` parameters after `loop`.
        return WorkerProcess(
            parent=self,
            idx=idx,
            target=target,
            args=(
                idx + 1,
                self.process_name,
                callback_loader,
                (self._shd, self._sso),
                # NOTE: given we use IPC only for metrics right now, let's share the pipe
                #       only if metrics collection is actually enabled.
                self._ipc[idx][1] if self.metrics_enabled else None,
                self.loop,
                self.log_enabled,
                self.log_level,
                self.log_config,
                self.env_files,
                self.runtime_mode,
                self.runtime_threads,
                self.runtime_blocking_threads,
                self.blocking_threads,
                self.blocking_threads_idle_timeout,
                self.backpressure,
                self.task_impl,
                self.http,
                self.http1_settings,
                self.http2_settings,
                self.websockets,
                self.static_path,
                self.log_access_format if self.log_access else None,
                self.ssl_ctx,
                {'url_path_prefix': self.url_path_prefix},
                (self.metrics_scrape_interval if self.metrics_enabled else None, None),
            ),
        )

    def serve(
        self,
        spawn_target: Callable[..., None] | None = None,
        target_loader: Callable[..., Callable[..., Any]] | None = None,
        wrap_loader: bool = True,
    ):
        """Run pre-flight checks, then delegate to `AbstractServer.serve`.

        For WSGI, this warns when `blocking_threads` exceeds `2 * cpu_count + 1`.
        On Windows, metrics are disabled because IPC is unavailable there.
        """
        if self.interface == Interfaces.WSGI:
            if self.blocking_threads > (multiprocessing.cpu_count() * 2 + 1):
                logger.warning(
                    f'Configuration allows spawning up to {self.blocking_threads} Python threads, '
                    'which seems quite high compared to the number of CPU cores available. '
                    'Consider reviewing your configuration and using `backpressure` to limit '
                    'the concurrency on the Python interpreter. '
                    'If this configuration is intentional, you can safely ignore this message.'
                )

        if self.metrics_enabled and sys.platform == 'win32':
            self.metrics_enabled = False
            # `Logger.warn` is a deprecated alias; use `warning` like the rest of this module.
            logger.warning('Metrics are not available in Windows, ignoring.')

        super().serve(spawn_target, target_loader, wrap_loader)