199 lines
7.5 KiB
Python
199 lines
7.5 KiB
Python
"""Mixin that allow tasks to run during the whole app lifespan."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import inspect
|
|
import time
|
|
from collections.abc import Callable, Coroutine
|
|
from typing import TYPE_CHECKING, TypeVar, overload
|
|
|
|
from reflex_base.utils import console
|
|
from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError
|
|
from starlette.applications import Starlette
|
|
|
|
from .mixin import AppMixin
|
|
|
|
if TYPE_CHECKING:
|
|
from typing_extensions import deprecated
|
|
|
|
_LifespanTaskT = TypeVar("_LifespanTaskT", bound="Callable | asyncio.Task")
|
|
|
|
|
|
def _get_task_name(task: asyncio.Task | Callable) -> str:
|
|
"""Get a display name for a lifespan task.
|
|
|
|
Args:
|
|
task: The task to get the name for.
|
|
|
|
Returns:
|
|
The name of the task.
|
|
"""
|
|
if isinstance(task, asyncio.Task):
|
|
return task.get_name()
|
|
return task.__name__ # pyright: ignore[reportAttributeAccessIssue]
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class LifespanMixin(AppMixin):
|
|
"""A Mixin that allow tasks to run during the whole app lifespan.
|
|
|
|
Attributes:
|
|
lifespan_tasks: Set of lifespan tasks that are planned to run (deprecated).
|
|
"""
|
|
|
|
_lifespan_tasks: dict[asyncio.Task | Callable, None] = dataclasses.field(
|
|
default_factory=dict, init=False, repr=False
|
|
)
|
|
_lifespan_tasks_started: bool = dataclasses.field(
|
|
default=False, init=False, repr=False
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
# Static deprecation warning for IDE/type checkers.
|
|
@property
|
|
@deprecated("Use get_lifespan_tasks method instead.")
|
|
def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]:
|
|
"""Get a copy of registered lifespan tasks (deprecated)."""
|
|
...
|
|
|
|
else:
|
|
|
|
@property
|
|
def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]:
|
|
"""Get a copy of registered lifespan tasks.
|
|
|
|
Returns:
|
|
A frozenset of registered lifespan tasks.
|
|
"""
|
|
# Runtime deprecation warning prints to the console when accessed.
|
|
console.deprecate(
|
|
feature_name="LifespanMixin.lifespan_tasks",
|
|
reason="Use get_lifespan_tasks method instead to get a copy of registered lifespan tasks.",
|
|
deprecation_version="0.9.0",
|
|
removal_version="1.0",
|
|
)
|
|
return frozenset(self._lifespan_tasks)
|
|
|
|
def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]:
|
|
"""Get a copy of currently registered lifespan tasks.
|
|
|
|
Returns:
|
|
A tuple of registered lifespan tasks.
|
|
"""
|
|
return tuple(self._lifespan_tasks)
|
|
|
|
@contextlib.asynccontextmanager
|
|
async def _run_lifespan_tasks(self, app: Starlette):
|
|
self._lifespan_tasks_started = True
|
|
running_tasks = []
|
|
try:
|
|
async with contextlib.AsyncExitStack() as stack:
|
|
for task in self._lifespan_tasks:
|
|
task_name = _get_task_name(task)
|
|
run_msg = f"Started lifespan task: {task_name} as {{type}}"
|
|
if isinstance(task, asyncio.Task):
|
|
running_tasks.append(task)
|
|
else:
|
|
signature = inspect.signature(task)
|
|
if "app" in signature.parameters:
|
|
task = functools.partial(task, app=app)
|
|
t_ = task()
|
|
if isinstance(t_, contextlib._AsyncGeneratorContextManager):
|
|
await stack.enter_async_context(t_)
|
|
console.debug(run_msg.format(type="asynccontextmanager"))
|
|
elif isinstance(t_, Coroutine):
|
|
task_ = asyncio.create_task(
|
|
t_,
|
|
name=f"reflex_lifespan_task|{task_name}|{time.time()}",
|
|
)
|
|
task_.add_done_callback(lambda t: t.result())
|
|
running_tasks.append(task_)
|
|
console.debug(run_msg.format(type="coroutine"))
|
|
else:
|
|
console.debug(run_msg.format(type="function"))
|
|
yield
|
|
finally:
|
|
for task in running_tasks:
|
|
console.debug(f"Canceling lifespan task: {task}")
|
|
task.cancel(msg="lifespan_cleanup")
|
|
# Disassociate sid / token pairings so they can be reconnected properly.
|
|
try:
|
|
event_namespace = self.event_namespace # pyright: ignore[reportAttributeAccessIssue]
|
|
except AttributeError:
|
|
pass
|
|
else:
|
|
try:
|
|
if event_namespace:
|
|
await event_namespace._token_manager.disconnect_all()
|
|
except Exception as e:
|
|
console.error(f"Error during lifespan cleanup: {e}")
|
|
# Flush any pending writes from the state manager.
|
|
try:
|
|
state_manager = self.state_manager # pyright: ignore[reportAttributeAccessIssue]
|
|
except (AttributeError, ValueError):
|
|
pass
|
|
else:
|
|
await state_manager.close()
|
|
|
|
@overload
|
|
def register_lifespan_task(
|
|
self, task: _LifespanTaskT, **task_kwargs
|
|
) -> _LifespanTaskT: ...
|
|
|
|
@overload
|
|
def register_lifespan_task(
|
|
self, task: None = None, **task_kwargs
|
|
) -> Callable[[_LifespanTaskT], _LifespanTaskT]: ...
|
|
|
|
def register_lifespan_task(
|
|
self,
|
|
task: Callable | asyncio.Task | None = None,
|
|
**task_kwargs,
|
|
):
|
|
"""Register a task to run during the lifespan of the app.
|
|
|
|
Supports three call shapes:
|
|
- `app.register_lifespan_task(fn, **kwargs)` — direct call.
|
|
- `@app.register_lifespan_task` — bare decorator.
|
|
- `@app.register_lifespan_task(**kwargs)` — parameterized decorator.
|
|
|
|
Args:
|
|
task: The task to register, or None to return a decorator.
|
|
**task_kwargs: The kwargs of the task.
|
|
|
|
Returns:
|
|
The original task when called with a task, or a decorator when
|
|
called without one.
|
|
|
|
Raises:
|
|
InvalidLifespanTaskTypeError: If the task is a generator function.
|
|
RuntimeError: If lifespan tasks are already running.
|
|
"""
|
|
if task is None:
|
|
return functools.partial(self.register_lifespan_task, **task_kwargs)
|
|
|
|
if self._lifespan_tasks_started:
|
|
msg = (
|
|
f"Cannot register lifespan task {_get_task_name(task)!r} after "
|
|
"lifespan has started. Register all tasks before the app starts."
|
|
)
|
|
raise RuntimeError(msg)
|
|
if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
|
|
msg = f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
|
|
raise InvalidLifespanTaskTypeError(msg)
|
|
|
|
task_name = _get_task_name(task)
|
|
registered_task = task
|
|
if task_kwargs:
|
|
if isinstance(task, asyncio.Task):
|
|
msg = f"Task {task_name!r} of type asyncio.Task cannot be registered with kwargs."
|
|
raise InvalidLifespanTaskTypeError(msg)
|
|
registered_task = functools.partial(task, **task_kwargs)
|
|
functools.update_wrapper(registered_task, task)
|
|
self._lifespan_tasks[registered_task] = None
|
|
console.debug(f"Registered lifespan task: {task_name}")
|
|
return task
|