eptm_dashboard/.venv/lib/python3.12/site-packages/reflex/app_mixins/lifespan.py

199 lines
7.5 KiB
Python

"""Mixin that allows tasks to run during the whole app lifespan."""
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import functools
import inspect
import time
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, TypeVar, overload
from reflex_base.utils import console
from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError
from starlette.applications import Starlette
from .mixin import AppMixin
if TYPE_CHECKING:
from typing_extensions import deprecated
# Type variable used by the register_lifespan_task overloads so that the
# decorated/registered object is returned with its original static type.
_LifespanTaskT = TypeVar(
    "_LifespanTaskT",
    bound="Callable | asyncio.Task",
)
def _get_task_name(task: asyncio.Task | Callable) -> str:
"""Get a display name for a lifespan task.
Args:
task: The task to get the name for.
Returns:
The name of the task.
"""
if isinstance(task, asyncio.Task):
return task.get_name()
return task.__name__ # pyright: ignore[reportAttributeAccessIssue]
@dataclasses.dataclass
class LifespanMixin(AppMixin):
    """A Mixin that allows tasks to run during the whole app lifespan.

    Attributes:
        lifespan_tasks: Set of lifespan tasks that are planned to run (deprecated).
    """

    # Registered tasks; the dict is used as an insertion-ordered set
    # (values are always None).
    _lifespan_tasks: dict[asyncio.Task | Callable, None] = dataclasses.field(
        default_factory=dict, init=False, repr=False
    )
    # Set to True when _run_lifespan_tasks starts; register_lifespan_task
    # refuses new tasks after that point.
    _lifespan_tasks_started: bool = dataclasses.field(
        default=False, init=False, repr=False
    )

    if TYPE_CHECKING:
        # Static deprecation warning for IDE/type checkers.
        @property
        @deprecated("Use get_lifespan_tasks method instead.")
        def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]:
            """Get a copy of registered lifespan tasks (deprecated)."""
            ...

    else:

        @property
        def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]:
            """Get a copy of registered lifespan tasks.

            Returns:
                A frozenset of registered lifespan tasks.
            """
            # Runtime deprecation warning prints to the console when accessed.
            console.deprecate(
                feature_name="LifespanMixin.lifespan_tasks",
                reason="Use get_lifespan_tasks method instead to get a copy of registered lifespan tasks.",
                deprecation_version="0.9.0",
                removal_version="1.0",
            )
            return frozenset(self._lifespan_tasks)

    def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]:
        """Get a copy of currently registered lifespan tasks.

        Returns:
            A tuple of registered lifespan tasks, in registration order.
        """
        return tuple(self._lifespan_tasks)

    @contextlib.asynccontextmanager
    async def _run_lifespan_tasks(self, app: Starlette):
        """Run every registered lifespan task for the duration of the context.

        For each registered task:
        - an ``asyncio.Task`` is tracked so it can be cancelled on exit;
        - a callable is invoked (with ``app=app`` if it declares an ``app``
          parameter). If it returns an async context manager, the context
          manager is entered on an exit stack. If it returns a coroutine, the
          coroutine is scheduled as an asyncio task. Any other return value
          is ignored.

        On exit, async context managers are exited first. Then the tracked
        tasks are cancelled (not awaited), token pairings are disconnected
        when an ``event_namespace`` exists, and the state manager is closed
        when one is available.

        Args:
            app: The Starlette application whose lifespan is being run.

        Yields:
            Control back to the server while the app is running.
        """
        self._lifespan_tasks_started = True
        running_tasks = []

        def _surface_exception(finished: asyncio.Task) -> None:
            # Cancellation is the normal shutdown path (see the finally clause
            # below). Calling result() on a cancelled task would raise
            # CancelledError inside this callback, and the event loop would log
            # it as an error on every shutdown. Genuine failures are still
            # re-raised so the loop's exception handler reports them.
            if not finished.cancelled():
                finished.result()

        try:
            async with contextlib.AsyncExitStack() as stack:
                for task in self._lifespan_tasks:
                    task_name = _get_task_name(task)
                    run_msg = f"Started lifespan task: {task_name} as {{type}}"
                    if isinstance(task, asyncio.Task):
                        running_tasks.append(task)
                    else:
                        signature = inspect.signature(task)
                        if "app" in signature.parameters:
                            task = functools.partial(task, app=app)
                        t_ = task()
                        # Public ABC: matches @asynccontextmanager results and any
                        # object implementing __aenter__/__aexit__.
                        if isinstance(t_, contextlib.AbstractAsyncContextManager):
                            await stack.enter_async_context(t_)
                            console.debug(run_msg.format(type="asynccontextmanager"))
                        elif isinstance(t_, Coroutine):
                            task_ = asyncio.create_task(
                                t_,
                                name=f"reflex_lifespan_task|{task_name}|{time.time()}",
                            )
                            task_.add_done_callback(_surface_exception)
                            running_tasks.append(task_)
                            console.debug(run_msg.format(type="coroutine"))
                        else:
                            console.debug(run_msg.format(type="function"))
                yield
        finally:
            for task in running_tasks:
                console.debug(f"Canceling lifespan task: {task}")
                task.cancel(msg="lifespan_cleanup")
            # Disassociate sid / token pairings so they can be reconnected properly.
            try:
                event_namespace = self.event_namespace  # pyright: ignore[reportAttributeAccessIssue]
            except AttributeError:
                pass
            else:
                try:
                    if event_namespace:
                        await event_namespace._token_manager.disconnect_all()
                except Exception as e:
                    console.error(f"Error during lifespan cleanup: {e}")
            # Flush any pending writes from the state manager.
            try:
                state_manager = self.state_manager  # pyright: ignore[reportAttributeAccessIssue]
            except (AttributeError, ValueError):
                pass
            else:
                await state_manager.close()

    @overload
    def register_lifespan_task(
        self, task: _LifespanTaskT, **task_kwargs
    ) -> _LifespanTaskT: ...

    @overload
    def register_lifespan_task(
        self, task: None = None, **task_kwargs
    ) -> Callable[[_LifespanTaskT], _LifespanTaskT]: ...

    def register_lifespan_task(
        self,
        task: Callable | asyncio.Task | None = None,
        **task_kwargs,
    ):
        """Register a task to run during the lifespan of the app.

        Supports three call shapes:
        - `app.register_lifespan_task(fn, **kwargs)` — direct call.
        - `@app.register_lifespan_task` — bare decorator.
        - `@app.register_lifespan_task(**kwargs)` — parameterized decorator.

        Args:
            task: The task to register, or None to return a decorator.
            **task_kwargs: The kwargs of the task.

        Returns:
            The original task when called with a task, or a decorator when
            called without one.

        Raises:
            InvalidLifespanTaskTypeError: If the task is a generator function,
                or if kwargs are given together with an asyncio.Task.
            RuntimeError: If lifespan tasks are already running.
        """
        if task is None:
            # Parameterized decorator: bind the kwargs and wait for the task.
            return functools.partial(self.register_lifespan_task, **task_kwargs)
        if self._lifespan_tasks_started:
            msg = (
                f"Cannot register lifespan task {_get_task_name(task)!r} after "
                "lifespan has started. Register all tasks before the app starts."
            )
            raise RuntimeError(msg)
        if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
            msg = f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
            raise InvalidLifespanTaskTypeError(msg)
        task_name = _get_task_name(task)
        registered_task = task
        if task_kwargs:
            if isinstance(task, asyncio.Task):
                msg = f"Task {task_name!r} of type asyncio.Task cannot be registered with kwargs."
                raise InvalidLifespanTaskTypeError(msg)
            registered_task = functools.partial(task, **task_kwargs)
            # Copy __name__ and related metadata so the partial displays like the original.
            functools.update_wrapper(registered_task, task)
        self._lifespan_tasks[registered_task] = None
        console.debug(f"Registered lifespan task: {task_name}")
        return task