eptm_dashboard/.venv/lib/python3.12/site-packages/granian/server/common.py

720 lines
27 KiB
Python

from __future__ import annotations
import errno
import multiprocessing
import os
import ssl
import sys
import threading
import time
from collections.abc import Callable, Sequence
from functools import partial
from pathlib import Path
from typing import Any, Generic, TypeVar
from .._compat import _PY_312, _PYV
from .._granian import MetricsAggregator, MetricsExporter, WorkerSignal
from .._imports import dotenv, setproctitle, watchfiles
from .._internal import build_env_loader, load_target
from .._signals import set_main_signals
from ..constants import HTTPModes, Interfaces, Loops, RuntimeModes, SSLProtocols, TaskImpl
from ..errors import ConfigurationError, PidFileError
from ..http import HTTP1Settings, HTTP2Settings
from ..log import DEFAULT_ACCESSLOG_FMT, LogLevels, configure_logging, logger
from ..net import SocketSpec, UnixSocketSpec
# Type variable for the concrete worker class handled by an `AbstractServer` subclass
# (used as `Generic[WT]`, for `self.wrks` and as the return type of `_spawn_worker`).
WT = TypeVar('WT')
# Names of worker serve methods, keyed first by runtime mode and then by a boolean.
# NOTE(review): the `_uds` suffix suggests the boolean means "bound to a Unix domain socket";
# the lookup itself is not in this module — confirm against the worker implementations.
WORKERS_METHODS = {
    RuntimeModes.mt: {False: 'serve_mtr', True: 'serve_mtr_uds'},
    RuntimeModes.st: {False: 'serve_str', True: 'serve_str_uds'},
}
class AbstractWorker:
    """Base wrapper around one worker handle owned by a server.

    Subclasses implement `_spawn` (expected to set `self.inner`, which every
    other method relies on), `_id`, `terminate` and `kill`.
    """

    # Label printed next to the value of `_id()` in the spawn log line.
    _idl = 'id'

    def __init__(self, parent: AbstractServer, idx: int, target: Any, args: Any):
        self.birth = time.monotonic()
        self.interrupt_by_parent = False
        self.idx = idx
        self.parent = parent
        self._spawn(target, args)

    def _spawn(self, target, args):
        """Create the underlying worker handle; must be provided by subclasses."""
        raise NotImplementedError

    def _id(self):
        """Return the identifier shown in the spawn log line; must be provided by subclasses."""
        raise NotImplementedError

    def _watcher(self):
        """Wait for the worker to exit and report unexpected exits to the parent server."""
        self.inner.join()
        if self.interrupt_by_parent:
            # The exit was requested by the parent: nothing to report.
            return

        logger.error(f'Unexpected exit from worker-{self.idx + 1}')
        server = self.parent
        if server.reload_on_changes and server.reload_ignore_worker_failure:
            return
        server.interrupt_children.append(self.idx)
        server.main_loop_interrupt.set()

    def _watch(self):
        """Run `_watcher` in a background thread."""
        threading.Thread(target=self._watcher).start()

    def start(self):
        """Start the worker, log its identifier and begin watching for its exit."""
        self.inner.start()
        logger.info(f'Spawning worker-{self.idx + 1} with {self._idl}: {self._id()}')
        self._watch()

    def is_alive(self):
        """Tell whether the underlying worker handle is still alive."""
        return self.inner.is_alive()

    def terminate(self):
        """Ask the worker to stop; must be provided by subclasses."""
        raise NotImplementedError

    def kill(self):
        """Forcefully stop the worker; must be provided by subclasses."""
        raise NotImplementedError

    def join(self, timeout=None):
        """Wait for the worker to exit, for at most `timeout` seconds when given."""
        self.inner.join(timeout=timeout)
class AbstractServer(Generic[WT]):
def __init__(
    self,
    target: str,
    address: str = '127.0.0.1',
    port: int = 8000,
    uds: Path | None = None,
    uds_permissions: int | None = None,
    interface: Interfaces = Interfaces.RSGI,
    workers: int = 1,
    blocking_threads: int | None = None,
    blocking_threads_idle_timeout: int = 30,
    runtime_threads: int = 1,
    runtime_blocking_threads: int | None = None,
    runtime_mode: RuntimeModes = RuntimeModes.auto,
    loop: Loops = Loops.auto,
    task_impl: TaskImpl = TaskImpl.asyncio,
    http: HTTPModes = HTTPModes.auto,
    websockets: bool = True,
    backlog: int = 1024,
    backpressure: int | None = None,
    http1_settings: HTTP1Settings | None = None,
    http2_settings: HTTP2Settings | None = None,
    log_enabled: bool = True,
    log_level: LogLevels = LogLevels.info,
    log_dictconfig: dict[str, Any] | None = None,
    log_access: bool = False,
    log_access_format: str | None = None,
    ssl_cert: Path | None = None,
    ssl_key: Path | None = None,
    ssl_key_password: str | None = None,
    ssl_protocol_min: SSLProtocols = SSLProtocols.tls13,
    ssl_ca: Path | None = None,
    ssl_crl: list[Path] | None = None,
    ssl_client_verify: bool = False,
    url_path_prefix: str | None = None,
    respawn_failed_workers: bool = False,
    respawn_interval: float = 3.5,
    rss_sample_interval: int = 30,
    rss_samples: int = 1,
    workers_lifetime: int | None = None,
    workers_max_rss: int | None = None,
    workers_kill_timeout: int | None = None,
    factory: bool = False,
    working_dir: Path | None = None,
    env_files: Sequence[Path] | None = None,
    static_path_route: Sequence[str] | None = None,
    static_path_mount: Sequence[Path] | None = None,
    static_path_dir_to_file: str | None = None,
    static_path_expires: int = 86400,
    metrics_enabled: bool = False,
    metrics_scrape_interval: int = 15,
    metrics_address: str = '127.0.0.1',
    metrics_port: int = 9090,
    reload: bool = False,
    reload_paths: Sequence[Path] | None = None,
    reload_ignore_dirs: Sequence[str] | None = None,
    reload_ignore_patterns: Sequence[str] | None = None,
    reload_ignore_paths: Sequence[Path] | None = None,
    reload_filter: type[watchfiles.BaseFilter] | None = None,
    reload_tick: int = 50,
    reload_ignore_worker_failure: bool = False,
    process_name: str | None = None,
    pid_file: Path | None = None,
):
    """Store the server configuration and prepare its runtime state.

    Nothing is bound or spawned here: the listening socket is created and the
    workers are started by `startup`. Most cross-option validation happens in
    `serve`.

    Side effects: logging is configured, static mounts are resolved, and when
    both `ssl_cert` and `ssl_key` are given the pair is loaded once through
    `build_ssl_context`.

    Raises:
        ConfigurationError: from `_init_static_mounts`, when the static routes
            and mounts have different lengths.
    """
    self.target = target
    self.bind_addr = address
    self.bind_port = port
    self.bind_uds = uds.resolve() if uds else None
    self.uds_permissions = uds_permissions
    self.interface = interface
    # Worker and runtime thread counts are clamped to at least 1.
    self.workers = max(1, workers)
    self.runtime_threads = max(1, runtime_threads)
    # Defaults to 512 when not given, otherwise clamped to at least 1.
    self.runtime_blocking_threads = 512 if runtime_blocking_threads is None else max(1, runtime_blocking_threads)
    self.runtime_mode = runtime_mode
    self.loop = loop
    self.task_impl = task_impl
    self.http = http
    self.websockets = websockets
    # The backlog is never lower than 128.
    self.backlog = max(128, backlog)
    # When unset (or 0) the backpressure is the backlog split evenly across workers, at least 1.
    self.backpressure = max(1, backpressure or self.backlog // self.workers)
    # When unset: half of the backpressure on WSGI, a single thread on every other interface.
    self.blocking_threads = (
        blocking_threads
        if blocking_threads is not None
        else max(1, (self.backpressure // 2) if self.interface == Interfaces.WSGI else 1)
    )
    self.blocking_threads_idle_timeout = blocking_threads_idle_timeout
    self.http1_settings = http1_settings
    self.http2_settings = http2_settings
    self.log_enabled = log_enabled
    self.log_level = log_level
    self.log_config = log_dictconfig
    self.log_access = log_access
    self.log_access_format = log_access_format or DEFAULT_ACCESSLOG_FMT
    self.url_path_prefix = url_path_prefix
    self.respawn_failed_workers = respawn_failed_workers
    self.reload_on_changes = reload
    self.respawn_interval = respawn_interval
    self.rss_sample_interval = rss_sample_interval
    self.rss_samples = rss_samples
    self._rss_wrk_samples = {}
    self.workers_lifetime = workers_lifetime
    # NOTE(review): the `* 1024 * 1024` factor suggests `workers_max_rss` is given in MiB
    # and stored in bytes — confirm against the RSS monitor implementation.
    self.workers_rss = workers_max_rss * 1024 * 1024 if workers_max_rss else None
    self.workers_kill_timeout = workers_kill_timeout
    self.factory = factory
    self.working_dir = working_dir
    self.env_files = env_files or ()
    # Replaced by `_init_static_mounts` below when static mounts are configured.
    self.static_path = None
    self.metrics_enabled = metrics_enabled
    self.metrics_scrape_interval = metrics_scrape_interval
    self.metrics_address = metrics_address
    self.metrics_port = metrics_port
    # Without explicit paths the reloader watches the current working directory.
    self.reload_paths = reload_paths or [Path.cwd()]
    self.reload_ignore_paths = reload_ignore_paths or ()
    self.reload_ignore_dirs = reload_ignore_dirs or ()
    self.reload_ignore_patterns = reload_ignore_patterns or ()
    self.reload_filter = reload_filter
    self.reload_tick = reload_tick
    self.reload_ignore_worker_failure = reload_ignore_worker_failure
    self.process_name = process_name
    self.pid_file = pid_file
    # Hook registries, filled through `on_startup` / `on_reload` / `on_shutdown`.
    self.hooks_startup = []
    self.hooks_reload = []
    self.hooks_shutdown = []
    # Logging is configured first so the helpers called below can report errors.
    configure_logging(self.log_level, self.log_config, self.log_enabled)
    if static_path_mount:
        self._init_static_mounts(
            static_path_route or [],
            static_path_mount,
            static_path_dir_to_file,
            # The expiration is handed over as a string; a falsy value becomes None.
            (str(static_path_expires) if static_path_expires else None),
        )
    # Sets `self.ssl_ctx`.
    self.build_ssl_context(
        ssl_cert, ssl_key, ssl_key_password, ssl_protocol_min, ssl_ca, ssl_crl or [], ssl_client_verify
    )
    # Socket spec, built socket and its file descriptor: populated by `_init_shared_socket`.
    self._ssp = None
    self._shd = None
    self._sfd = None
    self._metrics = MetricsAggregator(self.workers)
    self._metrics_exporter = MetricsExporter(self._metrics)
    self.wrks: list[WT] = []
    # Event the main loop blocks on; the flags below tell it why it was woken up.
    self.main_loop_interrupt = threading.Event()
    self.interrupt_signal = False
    # Indexes of workers that exited unexpectedly (appended by the workers' watcher threads).
    self.interrupt_children = []
    # Worker index -> monotonic timestamp of its last respawn (see `_respawn_workers`).
    self.respawned_wrks = {}
    self.reload_signal = False
    self.lifetime_signal = False
    self.rss_signal = False
    # Set to the main process PID by `startup`.
    self.pid = None
    self._env_loader = build_env_loader()
def _init_static_mounts(
self,
routes: Sequence[str],
paths: Sequence[Path],
dir_to_file: str | None,
expires: str | None,
):
if not paths:
return
if len(paths) == 1 and not routes:
self.static_path = (
[('/static', str(paths[0].resolve()))],
dir_to_file,
expires,
)
return
if len(paths) != len(routes):
logger.error('Static path routes and mounts should have the same length')
raise ConfigurationError('static_path')
self.static_path = (
[(routes[idx], str(path.resolve())) for idx, path in enumerate(paths)],
dir_to_file,
expires,
)
def build_ssl_context(
    self,
    cert: Path | None,
    key: Path | None,
    password: str | None,
    proto: SSLProtocols,
    ca: Path | None,
    crl: list[Path],
    client_verify: bool,
):
    """Set `self.ssl_ctx`, the tuple describing the TLS configuration.

    The tuple layout is `(enabled, cert, key, password, protocol, ca, crl, client_verify)`
    with every path resolved and converted to a string. TLS is reported as
    disabled unless both `cert` and `key` are provided.
    """
    if not cert or not key:
        self.ssl_ctx = (False, None, None, None, str(proto), None, [], False)
        return

    cert_path = str(cert.resolve())
    key_path = str(key.resolve())

    # unneeded? -- the context is discarded right away: only the load is performed,
    # so a certificate/key pair that cannot be loaded raises here.
    probe = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    probe.load_cert_chain(cert_path, key_path, password)

    #: build ctx
    if client_verify and not ca:
        logger.warning('SSL client verification requires a CA certificate, ignoring')
        client_verify = False

    ca_path = str(ca.resolve()) if ca else None
    crl_paths = [str(item.resolve()) for item in crl]
    self.ssl_ctx = (True, cert_path, key_path, password, str(proto), ca_path, crl_paths, client_verify)
@property
def _bind_addr_fmt(self):
return f'unix:{self.bind_uds}' if self.bind_uds else f'{self.bind_addr}:{self.bind_port}'
@staticmethod
def _call_hooks(hooks):
for hook in hooks:
hook()
def on_startup(self, hook: Callable[[], Any]) -> Callable[[], Any]:
    """Register `hook` to be called by `startup`; returns the hook so it can be used as a decorator."""
    self.hooks_startup += [hook]
    return hook
def on_reload(self, hook: Callable[[], Any]) -> Callable[[], Any]:
    """Register `hook` to be called on every workers reload; returns the hook so it can be used as a decorator."""
    self.hooks_reload += [hook]
    return hook
def on_shutdown(self, hook: Callable[[], Any]) -> Callable[[], Any]:
    """Register `hook` to be called by `shutdown`; returns the hook so it can be used as a decorator."""
    self.hooks_shutdown += [hook]
    return hook
def _init_shared_socket(self):
    """Build the listening socket and keep its spec, handle and file descriptor on the server.

    A Unix socket spec is used when `bind_uds` is set, a regular address/port
    spec otherwise.
    """
    spec = (
        UnixSocketSpec(str(self.bind_uds), self.backlog, self.uds_permissions)
        if self.bind_uds
        else SocketSpec(self.bind_addr, self.bind_port, self.backlog)
    )
    self._ssp = spec
    self._shd = spec.build()
    self._sfd = self._shd.get_fd()
def signal_handler_interrupt(self, *args, **kwargs):
    """Record an interrupt request and wake the main loop; any argument is ignored."""
    self.interrupt_signal = True
    wake = self.main_loop_interrupt
    wake.set()
def signal_handler_reload(self, *args, **kwargs):
    """Record a reload request and wake the main loop; any argument is ignored."""
    self.reload_signal = True
    wake = self.main_loop_interrupt
    wake.set()
def _spawn_worker(self, idx, target, callback_loader, socket_loader) -> WT:
    """Create (without starting) the worker for slot `idx`; must be implemented by subclasses.

    NOTE(review): `_spawn_workers` and `_respawn_workers` call this with only
    `idx`, `target` and `callback_loader`, so concrete implementations
    presumably do not require `socket_loader` — confirm against the subclasses.
    """
    raise NotImplementedError
def _spawn_workers(self, spawn_target, target_loader):
for idx in range(self.workers):
wrk = self._spawn_worker(idx=idx, target=spawn_target, callback_loader=target_loader)
wrk.start()
self.wrks.append(wrk)
self._metrics.incr_spawn(self.workers)
def _respawn_workers(self, workers, spawn_target, target_loader, delay: float = 0):
    """Replace the workers at the given indexes, one after the other.

    For each index the replacement is spawned and started first; after `delay`
    seconds the previous worker is terminated and joined. When a kill timeout
    is configured, a worker still alive after that join is killed.
    """
    for slot in workers:
        self.respawned_wrks[slot] = time.monotonic()
        logger.info(f'Respawning worker-{slot + 1}')
        previous = self.wrks.pop(slot)
        replacement = self._spawn_worker(idx=slot, target=spawn_target, callback_loader=target_loader)
        replacement.start()
        self.wrks.insert(slot, replacement)
        time.sleep(delay)
        logger.info(f'Stopping old worker-{slot + 1}')
        previous.terminate()
        previous.join(self.workers_kill_timeout)
        if not self.workers_kill_timeout:
            continue
        # the worker might still be reported alive after `join`, let's context switch
        if previous.is_alive():
            time.sleep(0.001)
        if previous.is_alive():
            logger.warning(f'Killing old worker-{slot + 1} after it refused to gracefully stop')
            previous.kill()
            previous.join()
    self._metrics.incr_spawn(len(workers))
def _stop_workers(self):
for wrk in self.wrks:
wrk.terminate()
for wrk in self.wrks:
wrk.join(self.workers_kill_timeout)
if self.workers_kill_timeout:
# the worker might still be reported after `join`, let's context switch
if wrk.is_alive():
time.sleep(0.001)
if wrk.is_alive():
logger.warning(f'Killing worker-{wrk.idx} after it refused to gracefully stop')
wrk.kill()
wrk.join()
self.wrks.clear()
def _workers_lifetime_watcher(self, ttl):
time.sleep(ttl)
self.lifetime_signal = True
self.main_loop_interrupt.set()
def _workers_rss_watcher(self):
time.sleep(self.rss_sample_interval)
self.rss_signal = True
self.main_loop_interrupt.set()
def _watch_workers_lifetime(self, ttl):
waker = threading.Thread(target=self._workers_lifetime_watcher, args=(ttl,), daemon=True)
waker.start()
def _watch_workers_rss(self):
waker = threading.Thread(target=self._workers_rss_watcher, daemon=True)
waker.start()
def _write_pid(self):
with self.pid_file.open('w') as pid_file:
pid_file.write(str(self.pid))
def _write_pidfile(self):
if not self.pid_file:
return
existing_pid = None
if self.pid_file.exists():
try:
with self.pid_file.open('r') as pid_file:
existing_pid = int(pid_file.read())
except Exception:
logger.error(f'Unable to read existing PID file {self.pid_file}')
raise PidFileError
if existing_pid is not None and existing_pid != self.pid:
existing_process = True
try:
os.kill(existing_pid, 0)
except OSError as e:
if e.args[0] == errno.ESRCH:
existing_process = False
if existing_process:
logger.error(f'The PID file {self.pid_file} already exists for {existing_pid}')
raise PidFileError
self._write_pid()
def _unlink_pidfile(self):
if self.bind_uds and self.bind_uds.exists():
self.bind_uds.unlink()
if not (self.pid_file and self.pid_file.exists()):
return
try:
with self.pid_file.open('r') as pid_file:
file_pid = int(pid_file.read())
except Exception:
logger.error(f'Unable to read PID file {self.pid_file}')
return
if file_pid == self.pid:
self.pid_file.unlink()
def _start_ipc(self):
    """Hook called by `startup` right after the shared socket is initialised; a no-op here."""
    pass
def _stop_ipc(self):
    """Hook called by `shutdown` right after the workers are stopped; a no-op here."""
    pass
def _start_metrics(self):
    """Run the metrics exporter on its own socket, keeping the signal later used to stop it."""
    self._metrics_sig = WorkerSignal()
    # The exporter socket is built with a fixed backlog of 128.
    exporter_socket = SocketSpec(self.metrics_address, self.metrics_port, 128).build()
    self._metrics_exporter.run(exporter_socket, self._metrics_sig)
def _stop_metrics(self):
self._metrics_sig.set()
def startup(self, spawn_target, target_loader):
    """Bring the server up and spawn the workers.

    In order: record the PID and write the PID file, install the main signal
    handlers, create the shared socket, start IPC, load the environment
    files, run the startup hooks, spawn the workers, and finally start the
    optional lifetime, RSS and metrics helpers.
    """
    self.pid = os.getpid()
    logger.info(f'Starting granian (main PID: {self.pid})')
    # Runs before anything else is set up; `_write_pidfile` raises PidFileError on conflicts.
    self._write_pidfile()
    set_main_signals(self.signal_handler_interrupt, self.signal_handler_reload)
    self._init_shared_socket()
    self._start_ipc()
    # The first item of `ssl_ctx` is the "TLS enabled" flag set by `build_ssl_context`.
    proto = 'https' if self.ssl_ctx[0] else 'http'
    logger.info(f'Listening at: {proto}://{self._bind_addr_fmt}')
    # Environment files and startup hooks are processed before any worker is spawned.
    self._env_loader(self.env_files)
    self._call_hooks(self.hooks_startup)
    self._spawn_workers(spawn_target, target_loader)
    if self.workers_lifetime is not None:
        self._watch_workers_lifetime(self.workers_lifetime)
    if self.workers_rss is not None:
        self._watch_workers_rss()
    if self.metrics_enabled:
        self._start_metrics()
def shutdown(self, exit_code=0):
    """Stop metrics, workers and IPC, run the shutdown hooks and clean up the PID file.

    The process exits with `exit_code` when it is non-zero, or with 1 when a
    worker exited unexpectedly; otherwise the method simply returns.
    """
    logger.info('Shutting down granian')
    if self.metrics_enabled:
        self._stop_metrics()
    self._stop_workers()
    self._stop_ipc()
    self._call_hooks(self.hooks_shutdown)
    self._unlink_pidfile()

    if exit_code:
        sys.exit(exit_code)
    if self.interrupt_children:
        # No explicit code was given, but at least one worker exited unexpectedly.
        sys.exit(1)
def _reload(self, spawn_target, target_loader):
    """Handle a reload request by gracefully respawning every worker.

    The reload flag, the respawn timestamps and the main loop event are reset
    first; environment files are reloaded and reload hooks run before the
    workers are replaced one by one, `respawn_interval` seconds apart.
    """
    logger.info('HUP signal received, gracefully respawning workers..')
    every_slot = [*range(self.workers)]
    self.reload_signal = False
    self.respawned_wrks.clear()
    self.main_loop_interrupt.clear()
    self._env_loader(self.env_files)
    self._call_hooks(self.hooks_reload)
    outcome = self._respawn_workers(every_slot, spawn_target, target_loader, delay=self.respawn_interval)
    return outcome
def _handle_rss_signal(self, spawn_target, target_loader):
    """React to an RSS sampling tick; must be implemented by subclasses.

    Called from `_serve_loop` after the RSS flag has been reset and before the
    next RSS watcher is scheduled.
    """
    raise NotImplementedError
def _serve_loop(self, spawn_target, target_loader):
    """Block on `main_loop_interrupt` and act on whichever flag woke the loop.

    Returns when an interrupt was requested, when a worker exited unexpectedly
    while `respawn_failed_workers` is off, or when a crash loop is detected.
    """
    while True:
        self.main_loop_interrupt.wait()
        # Interrupt requested through `signal_handler_interrupt`.
        if self.interrupt_signal:
            break
        # Indexes appended by the workers' watcher threads on unexpected exit.
        if self.interrupt_children:
            if not self.respawn_failed_workers:
                break
            cycle = time.monotonic()
            # A crashed worker that was respawned at most 5.5 seconds ago counts as a crash loop.
            if any(cycle - self.respawned_wrks.get(idx, 0) <= 5.5 for idx in self.interrupt_children):
                logger.error('Worker crash loop detected, exiting')
                break
            # Snapshot the crashed indexes, then reset the shared state before respawning.
            workers = list(self.interrupt_children)
            self.interrupt_children.clear()
            self.respawned_wrks.clear()
            self.main_loop_interrupt.clear()
            self._respawn_workers(workers, spawn_target, target_loader)
            self._metrics.incr_respawn_err(1)
        # `_reload` resets the reload flag and clears the event itself.
        if self.reload_signal:
            self._reload(spawn_target, target_loader)
        if self.lifetime_signal or self.rss_signal:
            self.main_loop_interrupt.clear()
            if self.lifetime_signal:
                self.lifetime_signal = False
                # Workers are considered expired once they reach 95% of the configured lifetime.
                ttl = self.workers_lifetime * 0.95
                now = time.monotonic()
                # Candidate delays (seconds) before the next lifetime check.
                etas = [self.workers_lifetime]
                # Iterate over a copy: `_respawn_workers` replaces entries of `self.wrks`.
                for worker in list(self.wrks):
                    if (now - worker.birth) >= ttl:
                        logger.info(f'worker-{worker.idx + 1} lifetime expired, gracefully respawning..')
                        self._respawn_workers(
                            [worker.idx], spawn_target, target_loader, delay=self.respawn_interval
                        )
                        self._metrics.incr_respawn_ttl(1)
                    else:
                        # Remaining lifetime of this worker, never scheduled sooner than 60 seconds.
                        elapsed = now - worker.birth
                        remaining = self.workers_lifetime - elapsed
                        etas.append(max(60, int(remaining)))
                # Wake up again when the earliest worker is due.
                next_tick = min(etas)
                self._watch_workers_lifetime(next_tick)
            if self.rss_signal:
                self.rss_signal = False
                self._handle_rss_signal(spawn_target, target_loader)
                # Schedule the next RSS sample.
                self._watch_workers_rss()
def _serve(self, spawn_target, target_loader):
self.startup(spawn_target, target_loader)
self._serve_loop(spawn_target, target_loader)
self.shutdown()
def _serve_with_reloader(self, spawn_target, target_loader):
    """Run the server with the file changes reloader.

    Workers are stopped and spawned again every time `watchfiles` reports
    changes under `reload_paths`. Watching stops when `main_loop_interrupt` is
    set: a pending reload signal triggers a graceful respawn and watching
    resumes, anything else ends the loop and shuts the server down.

    Exits the process with code 1 when `watchfiles` is not installed.
    """
    if watchfiles is None:
        logger.error('Using --reload requires the granian[reload] extra')
        sys.exit(1)

    # Use given or default filter rules
    base_filter_cls = self.reload_filter or watchfiles.filters.DefaultFilter
    extra_dirs = self.reload_ignore_dirs
    extra_patterns = self.reload_ignore_patterns
    extra_paths = self.reload_ignore_paths

    # Extend the rules with the provided args on a private subclass: assigning on the
    # base class would alter `watchfiles`' default filter (or the caller's class) for
    # the whole process and pile up duplicated entries on every call.
    class _ReloadFilter(base_filter_cls):
        ignore_dirs = (*base_filter_cls.ignore_dirs, *extra_dirs)
        ignore_entity_patterns = (*base_filter_cls.ignore_entity_patterns, *extra_patterns)
        ignore_paths = (*base_filter_cls.ignore_paths, *extra_paths)

    # Construct new filter
    reload_filter = _ReloadFilter()

    self.startup(spawn_target, target_loader)

    serve_loop = True
    while serve_loop:
        try:
            for changes in watchfiles.watch(
                *self.reload_paths,
                watch_filter=reload_filter,
                stop_event=self.main_loop_interrupt,
                step=self.reload_tick,
            ):
                logger.info('Changes detected, reloading workers..')
                for change, file in changes:
                    logger.info(f'{change.raw_str().capitalize()}: {file}')
                self._env_loader(self.env_files)
                self._call_hooks(self.hooks_reload)
                self._stop_workers()
                self._spawn_workers(spawn_target, target_loader)
        except StopIteration:
            pass

        # The watcher returned because the event was set: either a reload was requested
        # (`_reload` clears the event, so watching can resume) or the server must stop.
        if self.reload_signal:
            self._reload(spawn_target, target_loader)
        else:
            serve_loop = False

    self.shutdown()
def serve(
    self,
    spawn_target: Callable[..., None] | None = None,
    target_loader: Callable[..., Callable[..., Any]] | None = None,
    wrap_loader: bool = True,
):
    """Validate the configuration, then run the server until it is stopped.

    Args:
        spawn_target: callable executed by the workers; defaults to the
            spawner matching the configured interface.
        target_loader: callable loading the application; defaults to
            `load_target` bound to the configured target, working directory
            and factory flag.
        wrap_loader: when a custom `target_loader` is given, bind the
            configured target as its first argument.

    Raises:
        ConfigurationError: for Unix sockets on Windows, more than one
            blocking thread outside WSGI, a process name without
            `setproctitle`, environment files without `dotenv`, a workers
            lifetime below 60 seconds, or a blocking threads idle timeout
            outside the 5-600 seconds range.
    """
    default_spawners = {
        Interfaces.ASGI: self._spawn_asgi_lifespan_worker,
        Interfaces.ASGINL: self._spawn_asgi_worker,
        Interfaces.RSGI: self._spawn_rsgi_worker,
        Interfaces.WSGI: self._spawn_wsgi_worker,
    }
    if target_loader:
        if wrap_loader:
            target_loader = partial(target_loader, self.target)
    else:
        target_loader = partial(load_target, self.target, wd=self.working_dir, factory=self.factory)

    if not spawn_target:
        spawn_target = default_spawners[self.interface]

    if sys.platform == 'win32' and self.workers > 1:
        self.workers = 1
        # `Logger.warn` is a deprecated alias: use `warning` like the rest of this module.
        logger.warning(
            'Due to a bug in Windows unblocking socket implementation '
            "granian can't support multiple workers on this platform. "
            'Number of workers will now fallback to 1.'
        )

    if self.bind_uds and sys.platform == 'win32':
        logger.error('Unix Domain sockets are not available on Windows')
        raise ConfigurationError('uds')

    if self.interface != Interfaces.WSGI and self.blocking_threads > 1:
        logger.error('Blocking threads > 1 is not supported on ASGI and RSGI')
        raise ConfigurationError('blocking_threads')

    if self.websockets:
        if self.interface == Interfaces.WSGI:
            self.websockets = False
            logger.info('Websockets are not supported on WSGI, ignoring')
        # NOTE(review): unlike the WSGI case above, `self.websockets` is left enabled here
        # even though the message says the option is ignored — confirm whether it should be reset.
        if self.http == HTTPModes.http2:
            logger.info('Websockets are not supported on HTTP/2 only, ignoring')

    if setproctitle is not None:
        self.process_name = self.process_name or (f'granian {self.interface} {self._bind_addr_fmt} {self.target}')
        setproctitle.setproctitle(self.process_name)
    elif self.process_name is not None:
        logger.error('Setting process name requires the granian[pname] extra')
        raise ConfigurationError('process_name')

    if self.env_files and dotenv is None:
        logger.error('Environment file(s) usage requires the granian[dotenv] extra')
        raise ConfigurationError('env_files')

    if self.workers_lifetime is not None:
        if self.workers_lifetime < 60:
            logger.error('Workers lifetime cannot be less than 60 seconds')
            raise ConfigurationError('workers_lifetime')
        if self.reload_on_changes:
            self.workers_lifetime = None
            logger.info('Workers lifetime is not available in combination with changes reloader, ignoring')

    if self.workers_rss is not None:
        if self.reload_on_changes:
            self.workers_rss = None
            logger.info('The resource monitor is not available in combination with changes reloader, ignoring')

    if self.metrics_enabled:
        if self.reload_on_changes:
            self.metrics_enabled = False
            logger.info('Metrics are not available in combination with changes reloader, ignoring')

    if self.blocking_threads_idle_timeout < 5 or self.blocking_threads_idle_timeout > 600:
        logger.error('Blocking threads idle timeout must be between 5 and 600 seconds')
        raise ConfigurationError('blocking_threads_idle_timeout')

    cpus = multiprocessing.cpu_count()
    if self.workers > cpus:
        logger.warning(
            'Configured number of workers appears to be higher than the amount of CPU cores available. '
            'Mind that such value might actually decrease the overall throughput of the server. '
            f'Consider using {cpus} workers and tune threads configuration instead'
        )

    if self.runtime_threads > cpus:
        logger.warning(
            'Configured number of Rust threads appears to be too high given the amount of CPU cores available. '
            'Mind that Rust threads are not involved in Python code execution, and they almost never be the '
            'limiting factor in scaling. Consider configuring the amount of blocking threads instead'
        )

    # Resolve the automatic runtime mode: single-threaded unless the interface is not RSGI,
    # more than one runtime thread is configured, or HTTP/2 only is requested.
    if self.runtime_mode == RuntimeModes.auto:
        self.runtime_mode = RuntimeModes.st
        if any(
            [
                self.interface != Interfaces.RSGI,
                self.runtime_threads > 1,
                self.http == HTTPModes.http2,
            ]
        ):
            self.runtime_mode = RuntimeModes.mt

    if self.task_impl == TaskImpl.rust:
        if _PYV >= _PY_312:
            self.task_impl = TaskImpl.asyncio
            logger.warning('Rust task implementation is not available on Python >= 3.12, falling back to asyncio')
        else:
            logger.warning('Rust task implementation is experimental!')

    serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
    serve_method(spawn_target, target_loader)