493 lines
17 KiB
Python
493 lines
17 KiB
Python
import asyncio
|
|
import multiprocessing
|
|
import sys
|
|
import time
|
|
from collections.abc import Callable, Sequence
|
|
from functools import wraps
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from .._futures import _future_watcher_wrapper, _new_cbscheduler
|
|
from .._granian import ASGIWorker, RSGIWorker, WorkerSignal
|
|
from .._imports import dotenv
|
|
from .._internal import load_env
|
|
from .._types import SSLCtx
|
|
from ..asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
|
|
from ..errors import ConfigurationError, FatalError
|
|
from ..rsgi import _callback_wrapper as _rsgi_call_wrap, _callbacks_from_target as _rsgi_cbs_from_target
|
|
from .common import (
|
|
_PY_312,
|
|
_PYV,
|
|
AbstractServer,
|
|
AbstractWorker,
|
|
HTTP1Settings,
|
|
HTTP2Settings,
|
|
HTTPModes,
|
|
Interfaces,
|
|
LogLevels,
|
|
SSLProtocols,
|
|
TaskImpl,
|
|
logger,
|
|
)
|
|
|
|
|
|
class AsyncWorker(AbstractWorker):
    """A Granian worker running as an `asyncio` task on the embedding loop.

    Unlike process-based workers, liveness maps to the task state, graceful
    stop is requested through a `WorkerSignal`, and hard kill is task
    cancellation.
    """

    def __init__(self, parent, idx, target, args, sig):
        # `sig` is the WorkerSignal handed to the Rust worker; setting it
        # requests a graceful shutdown (see `terminate`).
        self._sig = sig
        self._loop = asyncio.get_event_loop()
        self._task = None
        self._wtask = None
        super().__init__(parent, idx, target, args)

    @staticmethod
    def wrap_target(target):
        """Decorator injecting the running event loop into the target's args.

        The loop is inserted right after `sock`, matching the spawner
        coroutines' signatures.
        """

        @wraps(target)
        def wrapped(worker_id, sig, callback, sock, *args, **kwargs):
            loop = asyncio.get_event_loop()
            return target(worker_id, sig, callback, sock, loop, *args, **kwargs)

        return wrapped

    def _spawn(self, target, args):
        self._task = self._loop.create_task(target(*args))
        self._alive = True

    def _id(self):
        # Task id stands in for a process PID in log messages.
        return id(self._task)

    async def _watcher(self):
        # Swallow any outcome of the worker task (including cancellation):
        # the watcher only exists to report *unexpected* exits to the parent.
        try:
            await self._task
        except BaseException:
            pass
        if not self.interrupt_by_parent:
            logger.error(f'Unexpected exit from worker-{self.idx + 1}')
            self.parent.interrupt_children.append(self.idx)
            self.parent.main_loop_interrupt.set()

    def _watch(self):
        self._wtask = self._loop.create_task(self._watcher())

    def start(self):
        logger.info(f'Spawning worker-{self.idx + 1} with {self._idl}: {self._id()}')
        self._watch()

    def is_alive(self):
        if not self._alive:
            return False
        return not self._task.done()

    def terminate(self):
        # Graceful stop: flag intent and let the Rust worker drain on signal.
        self._alive = False
        self.interrupt_by_parent = True
        self._sig.set()

    def kill(self):
        # Hard stop: cancel the task outright.
        self._alive = False
        self.interrupt_by_parent = True
        self._task.cancel()

    async def join(self, timeout=None):
        """Wait for the worker task to finish, at most `timeout` seconds.

        Never raises: timeout and cancellation are swallowed so callers
        (`_stop_workers`, `_respawn_workers`) can inspect `is_alive()` and
        escalate to `kill()` themselves.
        """
        # NOTE(fix): previously this returned `asyncio.wait_for(self._task, ...)`
        # directly. On timeout that raised TimeoutError into the caller
        # (breaking the graceful-stop escalation), and `wait_for` cancelled the
        # task before the caller could log and kill it; awaiting a killed
        # (cancelled) task also leaked CancelledError. Shielding keeps the
        # task intact across a timeout.
        try:
            await asyncio.wait_for(asyncio.shield(self._task), timeout=timeout)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            pass
|
|
|
|
|
|
class Server(AbstractServer[AsyncWorker]):
|
|
    def __init__(
        self,
        target: Any,
        address: str = '127.0.0.1',
        port: int = 8000,
        uds: Path | None = None,
        interface: Interfaces = Interfaces.RSGI,
        blocking_threads: int | None = None,
        blocking_threads_idle_timeout: int = 30,
        runtime_threads: int = 1,
        runtime_blocking_threads: int | None = None,
        task_impl: TaskImpl = TaskImpl.asyncio,
        http: HTTPModes = HTTPModes.auto,
        websockets: bool = True,
        backlog: int = 128,
        backpressure: int | None = None,
        http1_settings: HTTP1Settings | None = None,
        http2_settings: HTTP2Settings | None = None,
        log_enabled: bool = True,
        log_level: LogLevels = LogLevels.info,
        log_dictconfig: dict[str, Any] | None = None,
        log_access: bool = False,
        log_access_format: str | None = None,
        ssl_cert: Path | None = None,
        ssl_key: Path | None = None,
        ssl_key_password: str | None = None,
        ssl_protocol_min: SSLProtocols = SSLProtocols.tls13,
        ssl_ca: Path | None = None,
        ssl_crl: list[Path] | None = None,
        ssl_client_verify: bool = False,
        url_path_prefix: str | None = None,
        factory: bool = False,
        static_path_route: Sequence[str] | None = None,
        static_path_mount: Sequence[Path] | None = None,
        static_path_dir_to_file: str | None = None,
        static_path_expires: int = 86400,
    ):
        """Embedded (in-loop) Granian server.

        All options are forwarded verbatim to `AbstractServer.__init__`;
        the only addition here is the event used to wake the serve loop.
        Workers run as asyncio tasks on the embedding event loop rather
        than as separate processes.
        """
        super().__init__(
            target=target,
            address=address,
            port=port,
            uds=uds,
            interface=interface,
            blocking_threads=blocking_threads,
            blocking_threads_idle_timeout=blocking_threads_idle_timeout,
            runtime_threads=runtime_threads,
            runtime_blocking_threads=runtime_blocking_threads,
            task_impl=task_impl,
            http=http,
            websockets=websockets,
            backlog=backlog,
            backpressure=backpressure,
            http1_settings=http1_settings,
            http2_settings=http2_settings,
            log_enabled=log_enabled,
            log_level=log_level,
            log_dictconfig=log_dictconfig,
            log_access=log_access,
            log_access_format=log_access_format,
            ssl_cert=ssl_cert,
            ssl_key=ssl_key,
            ssl_key_password=ssl_key_password,
            ssl_protocol_min=ssl_protocol_min,
            ssl_ca=ssl_ca,
            ssl_crl=ssl_crl,
            ssl_client_verify=ssl_client_verify,
            url_path_prefix=url_path_prefix,
            factory=factory,
            static_path_route=static_path_route,
            static_path_mount=static_path_mount,
            static_path_dir_to_file=static_path_dir_to_file,
            static_path_expires=static_path_expires,
        )
        # Set by workers (on unexpected exit) and by signal handlers to wake
        # up `_serve_loop`.
        self.main_loop_interrupt = asyncio.Event()
|
|
|
|
def _spawn_worker(self, idx, target, callback_loader) -> AsyncWorker:
|
|
sig = WorkerSignal()
|
|
|
|
return AsyncWorker(
|
|
parent=self,
|
|
idx=idx,
|
|
target=target,
|
|
args=(
|
|
idx + 1,
|
|
sig,
|
|
callback_loader,
|
|
self._shd,
|
|
self.runtime_threads,
|
|
self.runtime_blocking_threads,
|
|
self.blocking_threads,
|
|
self.blocking_threads_idle_timeout,
|
|
self.backpressure,
|
|
self.task_impl,
|
|
self.http,
|
|
self.http1_settings,
|
|
self.http2_settings,
|
|
self.websockets,
|
|
self.static_path,
|
|
self.log_access_format if self.log_access else None,
|
|
self.ssl_ctx,
|
|
{'url_path_prefix': self.url_path_prefix},
|
|
),
|
|
sig=sig,
|
|
)
|
|
|
|
    @staticmethod
    @AsyncWorker.wrap_target
    async def _spawn_asgi_worker(
        worker_id: int,
        shutdown_event: Any,
        callback: Any,
        sock: Any,
        loop: Any,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
    ):
        """Serve an ASGI application without lifespan protocol handling.

        `wrap_target` injects `loop` right after `sock`; the remaining
        parameters mirror the args tuple built in `_spawn_worker`.
        """
        # Wrap the app callable for future-watching; the `{}` argument is the
        # state mapping (the lifespan variant passes the handler's state here).
        wcallback = _future_watcher_wrapper(_asgi_call_wrap(callback, scope_opts, {}, log_access_fmt))

        # Arguments are positional and must follow the Rust constructor's order.
        worker = ASGIWorker(
            worker_id,
            sock,
            None,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            (None, None),
        )
        # Pick the UDS variant of the serve entrypoint when bound to a socket file.
        serve = worker.serve_async_uds if sock.is_uds() else worker.serve_async
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        await serve(scheduler, loop, shutdown_event)
|
|
|
|
    @staticmethod
    @AsyncWorker.wrap_target
    async def _spawn_asgi_lifespan_worker(
        worker_id: int,
        shutdown_event: Any,
        callback: Any,
        sock: Any,
        loop: Any,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
    ):
        """Serve an ASGI application with lifespan protocol handling.

        Runs the app's lifespan startup before serving and its shutdown after
        the serve loop returns; a failed startup aborts with `FatalError`.
        """
        lifespan_handler = LifespanProtocol(callback)
        # The lifespan handler's state mapping is handed to the request wrapper.
        wcallback = _future_watcher_wrapper(
            _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt)
        )

        await lifespan_handler.startup()
        if lifespan_handler.interrupt:
            logger.error('ASGI lifespan startup failed', exc_info=lifespan_handler.exc)
            raise FatalError('ASGI lifespan startup')

        # Arguments are positional and must follow the Rust constructor's order.
        worker = ASGIWorker(
            worker_id,
            sock,
            None,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            (None, None),
        )
        # Pick the UDS variant of the serve entrypoint when bound to a socket file.
        serve = worker.serve_async_uds if sock.is_uds() else worker.serve_async
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        await serve(scheduler, loop, shutdown_event)
        await lifespan_handler.shutdown()
|
|
|
|
    @staticmethod
    @AsyncWorker.wrap_target
    async def _spawn_rsgi_worker(
        worker_id: int,
        shutdown_event: Any,
        callback: Any,
        sock: Any,
        loop: Any,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
    ):
        """Serve an RSGI application.

        `wrap_target` injects `loop` right after `sock`; the remaining
        parameters mirror the args tuple built in `_spawn_worker`.
        """
        # Resolve the request callback plus the target's init/teardown hooks,
        # which bracket the serve loop below.
        callback, callback_init, callback_del = _rsgi_cbs_from_target(callback)
        wcallback = _future_watcher_wrapper(_rsgi_call_wrap(callback, log_access_fmt))
        callback_init(loop)

        # Arguments are positional and must follow the Rust constructor's order.
        worker = RSGIWorker(
            worker_id,
            sock,
            None,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            (None, None),
        )
        # Pick the UDS variant of the serve entrypoint when bound to a socket file.
        serve = worker.serve_async_uds if sock.is_uds() else worker.serve_async
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        await serve(scheduler, loop, shutdown_event)
        callback_del(loop)
|
|
|
|
    async def _respawn_workers(self, workers, spawn_target, target_loader, delay: float = 0):
        """Replace the given workers one by one, starting each replacement
        before stopping its predecessor.

        Args:
            workers: iterable of 0-based worker indexes to respawn.
            delay: seconds to wait between starting the new worker and
                stopping the old one.
        """
        for idx in workers:
            # Record the respawn time; presumably consumed by respawn
            # throttling logic elsewhere — not visible from here.
            self.respawned_wrks[idx] = time.monotonic()
            logger.info(f'Respawning worker-{idx + 1}')
            old_wrk = self.wrks.pop(idx)
            # Start the replacement at the same index while the old worker is
            # still serving, so the socket is never left unattended.
            wrk = self._spawn_worker(idx=idx, target=spawn_target, callback_loader=target_loader)
            wrk.start()
            self.wrks.insert(idx, wrk)
            await asyncio.sleep(delay)
            logger.info(f'Stopping old worker-{idx + 1}')
            old_wrk.terminate()
            await old_wrk.join(self.workers_kill_timeout)
            if self.workers_kill_timeout:
                # the worker might still be reported alive after `join`, let's context switch
                if old_wrk.is_alive():
                    await asyncio.sleep(0.001)
                if old_wrk.is_alive():
                    logger.warning(f'Killing old worker-{idx + 1} after it refused to gracefully stop')
                    old_wrk.kill()
                    await old_wrk.join()
|
|
|
|
async def _stop_workers(self):
|
|
for wrk in self.wrks:
|
|
wrk.terminate()
|
|
|
|
for wrk in self.wrks:
|
|
await wrk.join(self.workers_kill_timeout)
|
|
if self.workers_kill_timeout:
|
|
if wrk.is_alive():
|
|
logger.warning(f'Killing worker-{wrk.idx} after it refused to gracefully stop')
|
|
wrk.kill()
|
|
|
|
self.wrks.clear()
|
|
|
|
def startup(self, spawn_target, target_loader):
|
|
logger.info('Starting granian (embedded)')
|
|
self._init_shared_socket()
|
|
proto = 'https' if self.ssl_ctx[0] else 'http'
|
|
logger.info(f'Listening at: {proto}://{self._bind_addr_fmt}')
|
|
|
|
load_env(self.env_files)
|
|
self._call_hooks(self.hooks_startup)
|
|
self._spawn_workers(spawn_target, target_loader)
|
|
|
|
async def _serve_loop(self, spawn_target, target_loader):
|
|
while True:
|
|
await self.main_loop_interrupt.wait()
|
|
if self.interrupt_signal:
|
|
break
|
|
|
|
if self.interrupt_children:
|
|
break
|
|
|
|
if self.reload_signal:
|
|
await self._reload(spawn_target, target_loader)
|
|
|
|
async def shutdown(self, exit_code=0):
|
|
logger.info('Shutting down granian')
|
|
await self._stop_workers()
|
|
self._call_hooks(self.hooks_shutdown)
|
|
if self.bind_uds and self.bind_uds.exists():
|
|
self.bind_uds.unlink()
|
|
|
|
async def _serve(self, spawn_target, target_loader):
|
|
target = target_loader()
|
|
self.startup(spawn_target, target)
|
|
await self._serve_loop(spawn_target, target)
|
|
await self.shutdown()
|
|
|
|
async def serve(self, spawn_target: Callable[..., None] | None = None):
|
|
def target_loader(*args, **kwargs):
|
|
if self.factory:
|
|
return self.target()
|
|
return self.target
|
|
|
|
default_spawners = {
|
|
Interfaces.ASGI: self._spawn_asgi_lifespan_worker,
|
|
Interfaces.ASGINL: self._spawn_asgi_worker,
|
|
Interfaces.RSGI: self._spawn_rsgi_worker,
|
|
}
|
|
|
|
logger.warning('Embedded server is experimental!')
|
|
|
|
if self.interface == Interfaces.WSGI:
|
|
logger.error('WSGI is not supported in embedded mode')
|
|
raise ConfigurationError('interface')
|
|
|
|
if self.reload_on_changes:
|
|
logger.error('The changes reloader is not supported in embedded mode')
|
|
raise ConfigurationError('reload')
|
|
|
|
if self.workers_rss:
|
|
logger.error('The resource monitor is not supported in embedded mode')
|
|
raise ConfigurationError('workers_max_rss')
|
|
|
|
if self.metrics_enabled:
|
|
logger.error('Metrics are not available in embedded mode')
|
|
raise ConfigurationError('metrics_enabled')
|
|
|
|
if not spawn_target:
|
|
spawn_target = default_spawners[self.interface]
|
|
|
|
if self.bind_uds and sys.platform == 'win32':
|
|
logger.error('Unix Domain sockets are not available on Windows')
|
|
raise ConfigurationError('uds')
|
|
|
|
if self.blocking_threads > 1:
|
|
logger.error('Blocking threads > 1 is not supported on ASGI and RSGI')
|
|
raise ConfigurationError('blocking_threads')
|
|
|
|
if self.websockets:
|
|
if self.http == HTTPModes.http2:
|
|
logger.info('Websockets are not supported on HTTP/2 only, ignoring')
|
|
|
|
if self.env_files and dotenv is None:
|
|
logger.error('Environment file(s) usage requires the granian[dotenv] extra')
|
|
raise ConfigurationError('env_files')
|
|
|
|
if self.blocking_threads_idle_timeout < 5 or self.blocking_threads_idle_timeout > 600:
|
|
logger.error('Blocking threads idle timeout must be between 10 and 600 seconds')
|
|
raise ConfigurationError('blocking_threads_idle_timeout')
|
|
|
|
cpus = multiprocessing.cpu_count()
|
|
if self.runtime_threads > cpus:
|
|
logger.warning(
|
|
'Configured number of Rust threads appears to be too high given the amount of CPU cores available. '
|
|
'Mind that Rust threads are not involved in Python code execution, and they almost never be the '
|
|
'limiting factor in scaling. Consider configuring the amount of blocking threads instead'
|
|
)
|
|
|
|
if self.task_impl == TaskImpl.rust:
|
|
if _PYV >= _PY_312:
|
|
self.task_impl = TaskImpl.asyncio
|
|
logger.warning('Rust task implementation is not available on Python >= 3.12, falling back to asyncio')
|
|
else:
|
|
logger.warning('Rust task implementation is experimental!')
|
|
|
|
await self._serve(spawn_target, target_loader)
|
|
|
|
    def stop(self):
        """Request the embedded server to shut down (wakes the serve loop)."""
        self.signal_handler_interrupt()
|
|
|
|
    def reload(self):
        """Signal the serve loop to perform a reload (workers respawn)."""
        self.signal_handler_reload()
|