from __future__ import annotations

import errno
import multiprocessing
import os
import ssl
import sys
import threading
import time
from collections.abc import Callable, Sequence
from functools import partial
from pathlib import Path
from typing import Any, Generic, TypeVar

from .._compat import _PY_312, _PYV
from .._granian import MetricsAggregator, MetricsExporter, WorkerSignal
from .._imports import dotenv, setproctitle, watchfiles
from .._internal import build_env_loader, load_target
from .._signals import set_main_signals
from ..constants import HTTPModes, Interfaces, Loops, RuntimeModes, SSLProtocols, TaskImpl
from ..errors import ConfigurationError, PidFileError
from ..http import HTTP1Settings, HTTP2Settings
from ..log import DEFAULT_ACCESSLOG_FMT, LogLevels, configure_logging, logger
from ..net import SocketSpec, UnixSocketSpec

WT = TypeVar('WT')

#: Maps runtime mode -> (uses unix socket?) -> Rust-side serve method name.
WORKERS_METHODS = {
    RuntimeModes.mt: {False: 'serve_mtr', True: 'serve_mtr_uds'},
    RuntimeModes.st: {False: 'serve_str', True: 'serve_str_uds'},
}


class AbstractWorker:
    """Wrapper around a single worker execution unit (process/thread in subclasses).

    Subclasses must implement `_spawn`, `_id`, `terminate` and `kill`, and expose
    the spawned unit as `self.inner` (an object with `start`/`join`/`is_alive`).
    """

    # label used in the spawn log message ("id", "pid", ...)
    _idl = 'id'

    def __init__(self, parent: AbstractServer, idx: int, target: Any, args: Any):
        self.parent = parent
        self.idx = idx
        # set by the parent before an intentional stop, so `_watcher`
        # can distinguish expected exits from crashes
        self.interrupt_by_parent = False
        self.birth = time.monotonic()
        self._spawn(target, args)

    def _spawn(self, target, args):
        """Create `self.inner` for the given target. Subclass responsibility."""
        raise NotImplementedError

    def _id(self):
        """Return the identifier of the spawned unit. Subclass responsibility."""
        raise NotImplementedError

    def _watcher(self):
        # Blocks until the worker exits; if the exit was not requested by the
        # parent, report the crash and wake the parent's main loop.
        self.inner.join()
        if not self.interrupt_by_parent:
            logger.error(f'Unexpected exit from worker-{self.idx + 1}')
            if self.parent.reload_on_changes and self.parent.reload_ignore_worker_failure:
                return
            self.parent.interrupt_children.append(self.idx)
            self.parent.main_loop_interrupt.set()

    def _watch(self):
        # Background thread observing the worker for unexpected termination.
        watcher = threading.Thread(target=self._watcher)
        watcher.start()

    def start(self):
        self.inner.start()
        logger.info(f'Spawning worker-{self.idx + 1} with {self._idl}: {self._id()}')
        self._watch()

    def is_alive(self):
        return self.inner.is_alive()

    def terminate(self):
        """Gracefully stop the worker. Subclass responsibility."""
        raise NotImplementedError

    def kill(self):
        """Forcefully stop the worker. Subclass responsibility."""
        raise NotImplementedError

    def join(self, timeout=None):
        self.inner.join(timeout=timeout)


class AbstractServer(Generic[WT]):
    """Base server orchestrating worker spawning, lifecycle and reloads.

    Concrete subclasses provide the worker type `WT` and implement
    `_spawn_worker` plus the per-interface spawn targets referenced in `serve`.
    """

    def __init__(
        self,
        target: str,
        address: str = '127.0.0.1',
        port: int = 8000,
        uds: Path | None = None,
        uds_permissions: int | None = None,
        interface: Interfaces = Interfaces.RSGI,
        workers: int = 1,
        blocking_threads: int | None = None,
        blocking_threads_idle_timeout: int = 30,
        runtime_threads: int = 1,
        runtime_blocking_threads: int | None = None,
        runtime_mode: RuntimeModes = RuntimeModes.auto,
        loop: Loops = Loops.auto,
        task_impl: TaskImpl = TaskImpl.asyncio,
        http: HTTPModes = HTTPModes.auto,
        websockets: bool = True,
        backlog: int = 1024,
        backpressure: int | None = None,
        http1_settings: HTTP1Settings | None = None,
        http2_settings: HTTP2Settings | None = None,
        log_enabled: bool = True,
        log_level: LogLevels = LogLevels.info,
        log_dictconfig: dict[str, Any] | None = None,
        log_access: bool = False,
        log_access_format: str | None = None,
        ssl_cert: Path | None = None,
        ssl_key: Path | None = None,
        ssl_key_password: str | None = None,
        ssl_protocol_min: SSLProtocols = SSLProtocols.tls13,
        ssl_ca: Path | None = None,
        ssl_crl: list[Path] | None = None,
        ssl_client_verify: bool = False,
        url_path_prefix: str | None = None,
        respawn_failed_workers: bool = False,
        respawn_interval: float = 3.5,
        rss_sample_interval: int = 30,
        rss_samples: int = 1,
        workers_lifetime: int | None = None,
        workers_max_rss: int | None = None,
        workers_kill_timeout: int | None = None,
        factory: bool = False,
        working_dir: Path | None = None,
        env_files: Sequence[Path] | None = None,
        static_path_route: Sequence[str] | None = None,
        static_path_mount: Sequence[Path] | None = None,
        static_path_dir_to_file: str | None = None,
        static_path_expires: int = 86400,
        metrics_enabled: bool = False,
        metrics_scrape_interval: int = 15,
        metrics_address: str = '127.0.0.1',
        metrics_port: int = 9090,
        reload: bool = False,
        reload_paths: Sequence[Path] | None = None,
        reload_ignore_dirs: Sequence[str] | None = None,
        reload_ignore_patterns: Sequence[str] | None = None,
        reload_ignore_paths: Sequence[Path] | None = None,
        reload_filter: type[watchfiles.BaseFilter] | None = None,
        reload_tick: int = 50,
        reload_ignore_worker_failure: bool = False,
        process_name: str | None = None,
        pid_file: Path | None = None,
    ):
        self.target = target
        self.bind_addr = address
        self.bind_port = port
        self.bind_uds = uds.resolve() if uds else None
        self.uds_permissions = uds_permissions
        self.interface = interface
        # clamp user-provided values to sane minimums
        self.workers = max(1, workers)
        self.runtime_threads = max(1, runtime_threads)
        self.runtime_blocking_threads = 512 if runtime_blocking_threads is None else max(1, runtime_blocking_threads)
        self.runtime_mode = runtime_mode
        self.loop = loop
        self.task_impl = task_impl
        self.http = http
        self.websockets = websockets
        self.backlog = max(128, backlog)
        # default backpressure: share the backlog evenly between workers
        self.backpressure = max(1, backpressure or self.backlog // self.workers)
        # WSGI gets a thread pool sized from backpressure; async interfaces get 1
        self.blocking_threads = (
            blocking_threads
            if blocking_threads is not None
            else max(1, (self.backpressure // 2) if self.interface == Interfaces.WSGI else 1)
        )
        self.blocking_threads_idle_timeout = blocking_threads_idle_timeout
        self.http1_settings = http1_settings
        self.http2_settings = http2_settings
        self.log_enabled = log_enabled
        self.log_level = log_level
        self.log_config = log_dictconfig
        self.log_access = log_access
        self.log_access_format = log_access_format or DEFAULT_ACCESSLOG_FMT
        self.url_path_prefix = url_path_prefix
        self.respawn_failed_workers = respawn_failed_workers
        self.reload_on_changes = reload
        self.respawn_interval = respawn_interval
        self.rss_sample_interval = rss_sample_interval
        self.rss_samples = rss_samples
        self._rss_wrk_samples = {}
        self.workers_lifetime = workers_lifetime
        # workers_max_rss is given in MiB; store it in bytes
        self.workers_rss = workers_max_rss * 1024 * 1024 if workers_max_rss else None
        self.workers_kill_timeout = workers_kill_timeout
        self.factory = factory
        self.working_dir = working_dir
        self.env_files = env_files or ()
        self.static_path = None
        self.metrics_enabled = metrics_enabled
        self.metrics_scrape_interval = metrics_scrape_interval
        self.metrics_address = metrics_address
        self.metrics_port = metrics_port
        self.reload_paths = reload_paths or [Path.cwd()]
        self.reload_ignore_paths = reload_ignore_paths or ()
        self.reload_ignore_dirs = reload_ignore_dirs or ()
        self.reload_ignore_patterns = reload_ignore_patterns or ()
        self.reload_filter = reload_filter
        self.reload_tick = reload_tick
        self.reload_ignore_worker_failure = reload_ignore_worker_failure
        self.process_name = process_name
        self.pid_file = pid_file
        self.hooks_startup = []
        self.hooks_reload = []
        self.hooks_shutdown = []

        configure_logging(self.log_level, self.log_config, self.log_enabled)

        if static_path_mount:
            self._init_static_mounts(
                static_path_route or [],
                static_path_mount,
                static_path_dir_to_file,
                (str(static_path_expires) if static_path_expires else None),
            )

        self.build_ssl_context(
            ssl_cert, ssl_key, ssl_key_password, ssl_protocol_min, ssl_ca, ssl_crl or [], ssl_client_verify
        )

        # shared-socket state, populated by `_init_shared_socket`
        self._ssp = None
        self._shd = None
        self._sfd = None
        self._metrics = MetricsAggregator(self.workers)
        self._metrics_exporter = MetricsExporter(self._metrics)
        self.wrks: list[WT] = []
        # main-loop coordination flags, toggled by signal handlers and watchers
        self.main_loop_interrupt = threading.Event()
        self.interrupt_signal = False
        self.interrupt_children = []
        self.respawned_wrks = {}
        self.reload_signal = False
        self.lifetime_signal = False
        self.rss_signal = False
        self.pid = None
        self._env_loader = build_env_loader()

    def _init_static_mounts(
        self,
        routes: Sequence[str],
        paths: Sequence[Path],
        dir_to_file: str | None,
        expires: str | None,
    ):
        """Validate and store static file mount configuration.

        With a single path and no routes, mount it at the default '/static'.
        Otherwise routes and paths must pair up one-to-one.

        Raises:
            ConfigurationError: when routes and mounts lengths differ.
        """
        if not paths:
            return
        if len(paths) == 1 and not routes:
            self.static_path = (
                [('/static', str(paths[0].resolve()))],
                dir_to_file,
                expires,
            )
            return
        if len(paths) != len(routes):
            logger.error('Static path routes and mounts should have the same length')
            raise ConfigurationError('static_path')
        self.static_path = (
            [(routes[idx], str(path.resolve())) for idx, path in enumerate(paths)],
            dir_to_file,
            expires,
        )

    def build_ssl_context(
        self,
        cert: Path | None,
        key: Path | None,
        password: str | None,
        proto: SSLProtocols,
        ca: Path | None,
        crl: list[Path],
        client_verify: bool,
    ):
        """Build the `ssl_ctx` tuple consumed by the Rust runtime.

        Tuple layout: (enabled, cert, key, password, min proto, ca, crl list,
        client verification flag).
        """
        if not (cert and key):
            self.ssl_ctx = (False, None, None, None, str(proto), None, [], False)
            return

        # NOTE(review): `ctx` itself is never used afterwards; loading the chain
        # here eagerly validates the cert/key pair so a misconfiguration fails at
        # startup rather than in the workers — presumably intentional, confirm
        # before removing.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(str(cert.resolve()), str(key.resolve()), password)

        if client_verify and not ca:
            logger.warning('SSL client verification requires a CA certificate, ignoring')
            client_verify = False

        self.ssl_ctx = (
            True,
            str(cert.resolve()),
            str(key.resolve()),
            password,
            str(proto),
            str(ca.resolve()) if ca else None,
            [str(item.resolve()) for item in crl],
            client_verify,
        )

    @property
    def _bind_addr_fmt(self):
        # Human-readable bind address for log messages.
        return f'unix:{self.bind_uds}' if self.bind_uds else f'{self.bind_addr}:{self.bind_port}'

    @staticmethod
    def _call_hooks(hooks):
        for hook in hooks:
            hook()

    def on_startup(self, hook: Callable[[], Any]) -> Callable[[], Any]:
        """Register a hook to run once before workers are spawned."""
        self.hooks_startup.append(hook)
        return hook

    def on_reload(self, hook: Callable[[], Any]) -> Callable[[], Any]:
        """Register a hook to run before workers are respawned on reload."""
        self.hooks_reload.append(hook)
        return hook

    def on_shutdown(self, hook: Callable[[], Any]) -> Callable[[], Any]:
        """Register a hook to run during server shutdown."""
        self.hooks_shutdown.append(hook)
        return hook

    def _init_shared_socket(self):
        # Build the listening socket once in the main process; workers inherit
        # its file descriptor.
        if self.bind_uds:
            self._ssp = UnixSocketSpec(str(self.bind_uds), self.backlog, self.uds_permissions)
        else:
            self._ssp = SocketSpec(self.bind_addr, self.bind_port, self.backlog)
        self._shd = self._ssp.build()
        self._sfd = self._shd.get_fd()

    def signal_handler_interrupt(self, *args, **kwargs):
        self.interrupt_signal = True
        self.main_loop_interrupt.set()

    def signal_handler_reload(self, *args, **kwargs):
        self.reload_signal = True
        self.main_loop_interrupt.set()

    def _spawn_worker(self, idx, target, callback_loader, socket_loader) -> WT:
        """Create (not start) a single worker. Subclass responsibility."""
        raise NotImplementedError

    def _spawn_workers(self, spawn_target, target_loader):
        for idx in range(self.workers):
            wrk = self._spawn_worker(idx=idx, target=spawn_target, callback_loader=target_loader)
            wrk.start()
            self.wrks.append(wrk)
        self._metrics.incr_spawn(self.workers)

    def _respawn_workers(self, workers, spawn_target, target_loader, delay: float = 0):
        """Replace the given workers one by one, new worker first.

        `delay` keeps the old worker serving while the new one warms up.
        """
        for idx in workers:
            self.respawned_wrks[idx] = time.monotonic()
            logger.info(f'Respawning worker-{idx + 1}')
            old_wrk = self.wrks.pop(idx)
            wrk = self._spawn_worker(idx=idx, target=spawn_target, callback_loader=target_loader)
            wrk.start()
            self.wrks.insert(idx, wrk)
            time.sleep(delay)
            logger.info(f'Stopping old worker-{idx + 1}')
            old_wrk.terminate()
            old_wrk.join(self.workers_kill_timeout)
            if self.workers_kill_timeout:
                # the worker might still be reported alive after `join`, let's context switch
                if old_wrk.is_alive():
                    time.sleep(0.001)
                if old_wrk.is_alive():
                    logger.warning(f'Killing old worker-{idx + 1} after it refused to gracefully stop')
                    old_wrk.kill()
                    old_wrk.join()
        self._metrics.incr_spawn(len(workers))

    def _stop_workers(self):
        for wrk in self.wrks:
            wrk.terminate()
        for wrk in self.wrks:
            wrk.join(self.workers_kill_timeout)
            if self.workers_kill_timeout:
                # the worker might still be reported alive after `join`, let's context switch
                if wrk.is_alive():
                    time.sleep(0.001)
                if wrk.is_alive():
                    # NOTE: use 1-based numbering, consistent with every other worker log
                    logger.warning(f'Killing worker-{wrk.idx + 1} after it refused to gracefully stop')
                    wrk.kill()
                    wrk.join()
        self.wrks.clear()

    def _workers_lifetime_watcher(self, ttl):
        # Sleep until the next lifetime deadline, then wake the main loop.
        time.sleep(ttl)
        self.lifetime_signal = True
        self.main_loop_interrupt.set()

    def _workers_rss_watcher(self):
        # Sleep until the next RSS sampling tick, then wake the main loop.
        time.sleep(self.rss_sample_interval)
        self.rss_signal = True
        self.main_loop_interrupt.set()

    def _watch_workers_lifetime(self, ttl):
        waker = threading.Thread(target=self._workers_lifetime_watcher, args=(ttl,), daemon=True)
        waker.start()

    def _watch_workers_rss(self):
        waker = threading.Thread(target=self._workers_rss_watcher, daemon=True)
        waker.start()

    def _write_pid(self):
        with self.pid_file.open('w') as pid_file:
            pid_file.write(str(self.pid))

    def _write_pidfile(self):
        """Write the PID file, refusing to clobber one owned by a live process.

        Raises:
            PidFileError: when an existing PID file is unreadable or belongs to
                another running process.
        """
        if not self.pid_file:
            return

        existing_pid = None
        if self.pid_file.exists():
            try:
                with self.pid_file.open('r') as pid_file:
                    existing_pid = int(pid_file.read())
            except Exception:
                logger.error(f'Unable to read existing PID file {self.pid_file}')
                raise PidFileError

        if existing_pid is not None and existing_pid != self.pid:
            # signal 0 probes for process existence without sending anything
            existing_process = True
            try:
                os.kill(existing_pid, 0)
            except OSError as e:
                if e.args[0] == errno.ESRCH:
                    existing_process = False
            if existing_process:
                logger.error(f'The PID file {self.pid_file} already exists for {existing_pid}')
                raise PidFileError

        self._write_pid()

    def _unlink_pidfile(self):
        """Remove the unix socket and PID file, but only if the PID file is ours."""
        if self.bind_uds and self.bind_uds.exists():
            self.bind_uds.unlink()

        if not (self.pid_file and self.pid_file.exists()):
            return

        try:
            with self.pid_file.open('r') as pid_file:
                file_pid = int(pid_file.read())
        except Exception:
            # best-effort cleanup: never fail shutdown over an unreadable PID file
            logger.error(f'Unable to read PID file {self.pid_file}')
            return

        if file_pid == self.pid:
            self.pid_file.unlink()

    def _start_ipc(self):
        # Optional IPC setup; no-op in the base server.
        pass

    def _stop_ipc(self):
        # Optional IPC teardown; no-op in the base server.
        pass

    def _start_metrics(self):
        self._metrics_sig = WorkerSignal()
        self._metrics_exporter.run(
            SocketSpec(self.metrics_address, self.metrics_port, 128).build(),
            self._metrics_sig,
        )

    def _stop_metrics(self):
        self._metrics_sig.set()

    def startup(self, spawn_target, target_loader):
        """Bring the server up: PID file, signals, socket, hooks and workers."""
        self.pid = os.getpid()
        logger.info(f'Starting granian (main PID: {self.pid})')
        self._write_pidfile()
        set_main_signals(self.signal_handler_interrupt, self.signal_handler_reload)
        self._init_shared_socket()
        self._start_ipc()
        proto = 'https' if self.ssl_ctx[0] else 'http'
        logger.info(f'Listening at: {proto}://{self._bind_addr_fmt}')

        self._env_loader(self.env_files)
        self._call_hooks(self.hooks_startup)
        self._spawn_workers(spawn_target, target_loader)

        if self.workers_lifetime is not None:
            self._watch_workers_lifetime(self.workers_lifetime)
        if self.workers_rss is not None:
            self._watch_workers_rss()
        if self.metrics_enabled:
            self._start_metrics()

    def shutdown(self, exit_code=0):
        """Tear everything down; exits the process when `exit_code` is non-zero
        or when workers died unexpectedly."""
        logger.info('Shutting down granian')
        if self.metrics_enabled:
            self._stop_metrics()
        self._stop_workers()
        self._stop_ipc()
        self._call_hooks(self.hooks_shutdown)
        self._unlink_pidfile()
        if not exit_code and self.interrupt_children:
            exit_code = 1
        if exit_code:
            sys.exit(exit_code)

    def _reload(self, spawn_target, target_loader):
        """Handle SIGHUP: re-read env files, run hooks, respawn all workers."""
        logger.info('HUP signal received, gracefully respawning workers..')
        workers = list(range(self.workers))
        self.reload_signal = False
        self.respawned_wrks.clear()
        self.main_loop_interrupt.clear()
        self._env_loader(self.env_files)
        self._call_hooks(self.hooks_reload)
        return self._respawn_workers(workers, spawn_target, target_loader, delay=self.respawn_interval)

    def _handle_rss_signal(self, spawn_target, target_loader):
        """React to an RSS sampling tick. Subclass responsibility."""
        raise NotImplementedError

    def _serve_loop(self, spawn_target, target_loader):
        """Main supervision loop: wait for events and dispatch on the flags set
        by signal handlers and watcher threads."""
        while True:
            self.main_loop_interrupt.wait()

            if self.interrupt_signal:
                break

            if self.interrupt_children:
                if not self.respawn_failed_workers:
                    break
                cycle = time.monotonic()
                # a worker crashing within 5.5s of its own respawn means a loop
                if any(cycle - self.respawned_wrks.get(idx, 0) <= 5.5 for idx in self.interrupt_children):
                    logger.error('Worker crash loop detected, exiting')
                    break
                workers = list(self.interrupt_children)
                self.interrupt_children.clear()
                self.respawned_wrks.clear()
                self.main_loop_interrupt.clear()
                self._respawn_workers(workers, spawn_target, target_loader)
                self._metrics.incr_respawn_err(1)

            if self.reload_signal:
                self._reload(spawn_target, target_loader)

            if self.lifetime_signal or self.rss_signal:
                self.main_loop_interrupt.clear()
                if self.lifetime_signal:
                    self.lifetime_signal = False
                    # respawn slightly early (95% of lifetime) to avoid drift
                    ttl = self.workers_lifetime * 0.95
                    now = time.monotonic()
                    etas = [self.workers_lifetime]
                    for worker in list(self.wrks):
                        if (now - worker.birth) >= ttl:
                            logger.info(f'worker-{worker.idx + 1} lifetime expired, gracefully respawning..')
                            self._respawn_workers(
                                [worker.idx], spawn_target, target_loader, delay=self.respawn_interval
                            )
                            self._metrics.incr_respawn_ttl(1)
                        else:
                            elapsed = now - worker.birth
                            remaining = self.workers_lifetime - elapsed
                            etas.append(max(60, int(remaining)))
                    # re-arm the watcher for the soonest upcoming expiry
                    next_tick = min(etas)
                    self._watch_workers_lifetime(next_tick)
                if self.rss_signal:
                    self.rss_signal = False
                    self._handle_rss_signal(spawn_target, target_loader)
                    self._watch_workers_rss()

    def _serve(self, spawn_target, target_loader):
        self.startup(spawn_target, target_loader)
        self._serve_loop(spawn_target, target_loader)
        self.shutdown()

    def _serve_with_reloader(self, spawn_target, target_loader):
        """Serve while watching the filesystem; respawn workers on changes."""
        if watchfiles is None:
            logger.error('Using --reload requires the granian[reload] extra')
            sys.exit(1)

        # Use given or default filter rules
        reload_filter_cls = self.reload_filter or watchfiles.filters.DefaultFilter
        # Extend `reload_filter` with provided args
        reload_filter_cls.ignore_dirs = (*reload_filter_cls.ignore_dirs, *self.reload_ignore_dirs)
        reload_filter_cls.ignore_entity_patterns = (
            *reload_filter_cls.ignore_entity_patterns,
            *self.reload_ignore_patterns,
        )
        reload_filter_cls.ignore_paths = (*reload_filter_cls.ignore_paths, *self.reload_ignore_paths)
        # Construct new filter
        reload_filter = reload_filter_cls()

        self.startup(spawn_target, target_loader)

        serve_loop = True
        while serve_loop:
            try:
                for changes in watchfiles.watch(
                    *self.reload_paths,
                    watch_filter=reload_filter,
                    stop_event=self.main_loop_interrupt,
                    step=self.reload_tick,
                ):
                    logger.info('Changes detected, reloading workers..')
                    for change, file in changes:
                        logger.info(f'{change.raw_str().capitalize()}: {file}')
                    self._env_loader(self.env_files)
                    self._call_hooks(self.hooks_reload)
                    self._stop_workers()
                    self._spawn_workers(spawn_target, target_loader)
            except StopIteration:
                pass

            # the watch loop also stops on SIGHUP: respawn and keep serving
            if self.reload_signal:
                self._reload(spawn_target, target_loader)
            else:
                serve_loop = False

        self.shutdown()

    def serve(
        self,
        spawn_target: Callable[..., None] | None = None,
        target_loader: Callable[..., Callable[..., Any]] | None = None,
        wrap_loader: bool = True,
    ):
        """Validate configuration and run the server until shutdown.

        Raises:
            ConfigurationError: for unsupported or inconsistent settings.
        """
        default_spawners = {
            Interfaces.ASGI: self._spawn_asgi_lifespan_worker,
            Interfaces.ASGINL: self._spawn_asgi_worker,
            Interfaces.RSGI: self._spawn_rsgi_worker,
            Interfaces.WSGI: self._spawn_wsgi_worker,
        }
        if target_loader:
            if wrap_loader:
                target_loader = partial(target_loader, self.target)
        else:
            target_loader = partial(load_target, self.target, wd=self.working_dir, factory=self.factory)

        if not spawn_target:
            spawn_target = default_spawners[self.interface]
            if sys.platform == 'win32' and self.workers > 1:
                self.workers = 1
                # `Logger.warn` is a deprecated alias; use `warning` like the
                # rest of this module
                logger.warning(
                    'Due to a bug in Windows unblocking socket implementation '
                    "granian can't support multiple workers on this platform. "
                    'Number of workers will now fallback to 1.'
                )

        if self.bind_uds and sys.platform == 'win32':
            logger.error('Unix Domain sockets are not available on Windows')
            raise ConfigurationError('uds')

        if self.interface != Interfaces.WSGI and self.blocking_threads > 1:
            logger.error('Blocking threads > 1 is not supported on ASGI and RSGI')
            raise ConfigurationError('blocking_threads')

        if self.websockets:
            if self.interface == Interfaces.WSGI:
                self.websockets = False
                logger.info('Websockets are not supported on WSGI, ignoring')
            if self.http == HTTPModes.http2:
                logger.info('Websockets are not supported on HTTP/2 only, ignoring')

        if setproctitle is not None:
            self.process_name = self.process_name or (f'granian {self.interface} {self._bind_addr_fmt} {self.target}')
            setproctitle.setproctitle(self.process_name)
        elif self.process_name is not None:
            logger.error('Setting process name requires the granian[pname] extra')
            raise ConfigurationError('process_name')

        if self.env_files and dotenv is None:
            logger.error('Environment file(s) usage requires the granian[dotenv] extra')
            raise ConfigurationError('env_files')

        if self.workers_lifetime is not None:
            if self.workers_lifetime < 60:
                logger.error('Workers lifetime cannot be less than 60 seconds')
                raise ConfigurationError('workers_lifetime')
            if self.reload_on_changes:
                self.workers_lifetime = None
                logger.info('Workers lifetime is not available in combination with changes reloader, ignoring')

        if self.workers_rss is not None:
            if self.reload_on_changes:
                self.workers_rss = None
                logger.info('The resource monitor is not available in combination with changes reloader, ignoring')

        if self.metrics_enabled:
            if self.reload_on_changes:
                self.metrics_enabled = False
                logger.info('Metrics are not available in combination with changes reloader, ignoring')

        if self.blocking_threads_idle_timeout < 5 or self.blocking_threads_idle_timeout > 600:
            logger.error('Blocking threads idle timeout must be between 5 and 600 seconds')
            raise ConfigurationError('blocking_threads_idle_timeout')

        cpus = multiprocessing.cpu_count()
        if self.workers > cpus:
            logger.warning(
                'Configured number of workers appears to be higher than the amount of CPU cores available. '
                'Mind that such value might actually decrease the overall throughput of the server. '
                f'Consider using {cpus} workers and tune threads configuration instead'
            )
        if self.runtime_threads > cpus:
            logger.warning(
                'Configured number of Rust threads appears to be too high given the amount of CPU cores available. '
                'Mind that Rust threads are not involved in Python code execution, and they almost never be the '
                'limiting factor in scaling. Consider configuring the amount of blocking threads instead'
            )

        if self.runtime_mode == RuntimeModes.auto:
            # single-threaded runtime only fits plain RSGI with 1 runtime thread
            self.runtime_mode = RuntimeModes.st
            if any(
                [
                    self.interface != Interfaces.RSGI,
                    self.runtime_threads > 1,
                    self.http == HTTPModes.http2,
                ]
            ):
                self.runtime_mode = RuntimeModes.mt

        if self.task_impl == TaskImpl.rust:
            if _PYV >= _PY_312:
                self.task_impl = TaskImpl.asyncio
                logger.warning('Rust task implementation is not available on Python >= 3.12, falling back to asyncio')
            else:
                logger.warning('Rust task implementation is experimental!')

        serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
        serve_method(spawn_target, target_loader)