eptm_dashboard/.venv/lib/python3.12/site-packages/reflex/istate/shared.py

521 lines
21 KiB
Python

"""Base classes for shared / linked states."""
import asyncio
import contextlib
from collections.abc import AsyncIterator
from typing import TypeVar
from reflex_base.constants import ROUTER_DATA
from reflex_base.event import Event, get_hydrate_event
from reflex_base.utils import console
from reflex_base.utils.exceptions import ReflexRuntimeError
from typing_extensions import Self
from reflex.istate.manager.token import BaseStateToken
from reflex.state import BaseState, State, _override_base_method
# Strong references to the in-flight tasks created by _do_update_other_tokens.
# The asyncio event loop only keeps weak references to tasks, so each task is
# held here until _log_update_client_errors discards it when the task is done.
UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set()
# Type variable bound to SharedStateBaseInternal. It is not used in this
# section; presumably it is referenced by annotations elsewhere (TODO confirm).
LINKED_STATE = TypeVar("LINKED_STATE", bound="SharedStateBaseInternal")
def _log_update_client_errors(task: asyncio.Task):
    """Log errors from updating other clients and release the task reference.

    Registered as a done-callback on every task scheduled by
    ``_do_update_other_tokens``. Whatever the outcome (success, failure or
    cancellation), the task is removed from ``UPDATE_OTHER_CLIENT_TASKS``.

    Args:
        task: The asyncio task to check for errors.
    """
    try:
        # On a cancelled task, .result() raises asyncio.CancelledError, which is
        # a BaseException and would escape `except Exception` below, leaking
        # out of this done-callback. Cancellation is not an update error.
        if task.cancelled():
            return
        task.result()
    except Exception as e:
        console.warn(f"Error updating linked client: {e}")
    finally:
        UPDATE_OTHER_CLIENT_TASKS.discard(task)
def _do_update_other_tokens(
    affected_tokens: set[str],
    previous_dirty_vars: dict[str, set[str]],
    state_type: type[BaseState],
) -> list[asyncio.Task]:
    """Schedule state updates for other clients after a shared state change.

    Each update runs in its own asyncio task rather than being awaited here,
    which avoids deadlocking on locks the caller may still be holding.

    Args:
        affected_tokens: Client tokens that should receive the update.
        previous_dirty_vars: Mapping of state full name to the dirty var names
            to apply to each client.
        state_type: The type of the shared state.

    Returns:
        The scheduled tasks, one per affected token that currently has an entry
        in the token manager's ``token_to_socket`` mapping.
    """
    from reflex.utils.prerequisites import get_app

    app = get_app().app

    async def _push_to_client(client_token: str):
        # The body is intentionally empty. modify_state is expected to apply
        # previous_dirty_vars to the client's state as it is entered/exited.
        async with app.modify_state(
            BaseStateToken(ident=client_token, cls=state_type),
            previous_dirty_vars=previous_dirty_vars,
        ):
            pass

    def _schedule(client_token: str) -> asyncio.Task:
        task = asyncio.create_task(_push_to_client(client_token))
        # Keep a strong reference until the done-callback discards it.
        UPDATE_OTHER_CLIENT_TASKS.add(task)
        task.add_done_callback(_log_update_client_errors)
        return task

    # Skip clients that are disconnected (no socket registered for the token).
    # TODO: remove disconnected clients after some time.
    return [
        _schedule(client_token)
        for client_token in affected_tokens
        if client_token in app.event_namespace._token_manager.token_to_socket
    ]
@contextlib.asynccontextmanager
async def _patch_state(
    original_state: BaseState, linked_state: BaseState, full_delta: bool = False
):
    """Temporarily put a linked state in the original state's place in the tree.

    While the context is active, ``linked_state`` occupies the original's slot
    in the parent's ``substates`` and its ``parent_state`` points at that
    parent. On exit, normal or exceptional, both links are restored.

    Args:
        original_state: The original shared state (must not be the root).
        linked_state: The linked shared state to patch in.
        full_delta: If True, mark all Vars in linked_state dirty and resolve
            the delta from the root. This option is used when linking or
            unlinking to ensure that other computed vars in the tree pick up
            the newly linked/unlinked values.

    Raises:
        ReflexRuntimeError: If original_state has no parent state.

    Yields:
        None.
    """
    parent = original_state.parent_state
    if parent is None:
        msg = "Cannot patch root state as linked state."
        raise ReflexRuntimeError(msg)

    slot_name = original_state.get_name()
    # Swap the linked state into the original's slot and re-parent it,
    # remembering its previous parent so it can be restored afterwards.
    parent.substates[slot_name] = linked_state
    saved_linked_parent = linked_state.parent_state
    linked_state.parent_state = parent
    try:
        if full_delta:
            linked_state.dirty_vars.update(linked_state.base_vars)
            linked_state.dirty_vars.update(linked_state._backend_vars)
            linked_state.dirty_vars.update(linked_state.computed_vars)
            linked_state._mark_dirty()
            # Resolve from the root so the existing tree is updated for rehydrate.
            tree_root = original_state._get_root_state()
            tree_root.dirty_vars.add("router")
            tree_root.dirty_vars.add(ROUTER_DATA)
            tree_root._mark_dirty()
            await tree_root._get_resolved_delta()
        yield
    finally:
        parent.substates[slot_name] = original_state
        linked_state.parent_state = saved_linked_parent
class SharedStateBaseInternal(State):
    """The private base state for all shared states.

    Implements linking a client's shared state to a shared token. The linked
    instance is loaded through the state manager's ``modify_state`` for that
    token, patched into the client's state tree for the duration of an event,
    and the resulting dirty vars are pushed to the other clients linked to the
    same token.
    """

    # Exit stack owning the modify_state contexts and _patch_state patches
    # entered while inside `_modify_linked_states`; None outside of it.
    _exit_stack: contextlib.AsyncExitStack | None = None
    # Shared token -> {state class -> linked state instance} for the linked
    # states loaded during the current `_modify_linked_states` block.
    _held_locks: dict[str, dict[type[BaseState], BaseState]] | None = None
    # Serializes the check-then-enter of a token's modify_state context in
    # `_internal_patch_linked_state`.
    _held_locks_lock: asyncio.Lock = asyncio.Lock()

    def __getstate__(self):
        """Override redis serialization to remove temporary fields.

        Returns:
            The state dictionary without temporary fields.
        """
        s = super().__getstate__()
        # These fields only matter while an event is being processed.
        s.pop("_previous_dirty_vars", None)
        s.pop("_exit_stack", None)
        s.pop("_held_locks", None)
        s.pop("_held_locks_lock", None)
        return s

    @_override_base_method
    def _clean(self):
        """Override BaseState._clean to track the last set of dirty vars.

        This is necessary for applying dirty vars from one event to other linked states.
        """
        if (
            previous_dirty_vars := getattr(self, "_previous_dirty_vars", None)
        ) is not None:
            # Snapshot (in place) the current dirty vars before super()._clean()
            # runs, which presumably resets them.
            previous_dirty_vars.clear()
            previous_dirty_vars.update(self.dirty_vars)
        super()._clean()

    @_override_base_method
    def _mark_dirty(self):
        """Override BaseState._mark_dirty to avoid marking certain vars as dirty.

        Since these internal fields are not persisted to redis, they shouldn't cause the
        state to be considered dirty either.
        """
        self.dirty_vars.discard("_previous_dirty_vars")
        self.dirty_vars.discard("_exit_stack")
        self.dirty_vars.discard("_held_locks")
        self.dirty_vars.discard("_held_locks_lock")
        # Only mark dirty if there are still dirty vars, or any substate is dirty
        if self.dirty_vars or any(
            substate.dirty_vars for substate in self.substates.values()
        ):
            super()._mark_dirty()

    def _rehydrate(self):
        """Get the events to rehydrate the state.

        Returns:
            The events to rehydrate the state (these should be returned/yielded).
        """
        return [
            Event(
                name=get_hydrate_event(self._get_root_state()),
            ),
            State.set_is_hydrated(True),
        ]

    async def _resolve_linked_state(
        self, state_cls: type["BaseState"], linked_token: str
    ) -> "BaseState":
        """Load and patch a linked state that was not pre-loaded in the tree.

        Called by State._get_state_from_redis when a state in
        _reflex_internal_links is not yet in the cache. This loads the
        private copy into the tree first, then patches the linked version
        on top of it via _internal_patch_linked_state.

        Args:
            state_cls: The shared state class to resolve.
            linked_token: The shared token the state is linked to.

        Returns:
            The linked state instance, patched into the current tree.

        Raises:
            ReflexRuntimeError: If the resolved state is not a SharedState.
        """
        root_state = self._get_root_state()
        # Load the private copy into the tree so _internal_patch_linked_state
        # has an original to swap out (needed for unlink / restore).
        original_state = await BaseState._get_state_from_redis(root_state, state_cls)
        if isinstance(original_state, SharedStateBaseInternal):
            return await original_state._internal_patch_linked_state(linked_token)
        msg = f"Failed to resolve linked state {state_cls.get_full_name()} for token {linked_token}: state does not inherit from rx.SharedState"
        raise ReflexRuntimeError(msg)

    async def _link_to(self, token: str) -> Self:
        """Link this shared state to a token.

        After linking, subsequent access to this shared state will affect the
        linked token's state, and cause changes to be propagated to all other
        clients linked to that token.

        Args:
            token: The token to link to (Cannot contain underscore characters).

        Returns:
            The newly linked state.

        Raises:
            ReflexRuntimeError: If linking fails or token is invalid.
        """
        if not token:
            msg = "Cannot link shared state to empty token."
            raise ReflexRuntimeError(msg)
        if not isinstance(self, SharedState):
            msg = "Can only link SharedState instances."
            raise ReflexRuntimeError(msg)
        # Validate the token before mutating any link bookkeeping. Otherwise a
        # rejected token would still drop this client from the currently
        # linked state's `_linked_from`.
        # TODO: Change StateManager to accept token + class instead of combining them in a string.
        if "_" in token:
            msg = f"Invalid token {token} for linking state {self.get_full_name()}, cannot use underscore (_) in the token name."
            raise ReflexRuntimeError(msg)
        if self._linked_to == token:
            return self  # already linked to this token
        if self._linked_to:
            # Disassociate from previous linked token since unlink will not be called.
            self._linked_from.discard(self.router.session.client_token)
        # Associate substate with the given link token.
        state_name = self.get_full_name()
        if self._reflex_internal_links is None:
            self._reflex_internal_links = {}
        self._reflex_internal_links[state_name] = token
        return await self._internal_patch_linked_state(token, full_delta=True)

    async def _unlink(self):
        """Unlink this shared state from its linked token.

        Returns:
            The events to rehydrate the state after unlinking (these should be returned/yielded).

        Raises:
            ReflexRuntimeError: If this is not a SharedState or it is not linked.
        """
        from reflex.istate.manager import get_state_manager

        if not isinstance(self, SharedState):
            msg = "Can only unlink SharedState instances."
            raise ReflexRuntimeError(msg)
        state_name = self.get_full_name()
        if (
            not self._reflex_internal_links
            or state_name not in self._reflex_internal_links
        ):
            msg = f"State {state_name} is not linked and cannot be unlinked."
            raise ReflexRuntimeError(msg)
        # Break the linkage for future events.
        self._reflex_internal_links.pop(state_name)
        self._linked_from.discard(self.router.session.client_token)
        # Patch in the original state, apply updates, then rehydrate.
        private_root_state = await get_state_manager().get_state(
            BaseStateToken(
                ident=self.router.session.client_token,
                cls=type(self),
            )
        )
        private_state = await private_root_state.get_state(type(self))
        async with _patch_state(
            original_state=self,
            linked_state=private_state,
            full_delta=True,
        ):
            return self._rehydrate()

    async def _internal_patch_linked_state(
        self, token: str, full_delta: bool = False
    ) -> Self:
        """Load and replace this state with the linked state for a given token.

        Must be called inside a `_modify_linked_states` context, to ensure locks are
        released after the event is done processing.

        Args:
            token: The token of the linked state.
            full_delta: If True, mark all Vars in linked_state dirty and resolve
                delta to update cached computed vars

        Returns:
            The state that was linked into the tree.

        Raises:
            ReflexRuntimeError: If called outside `_modify_linked_states`, or the
                loaded state for the token is not a SharedState.
        """
        from reflex.istate.manager import get_state_manager

        if self._exit_stack is None or self._held_locks is None:
            msg = "Cannot link shared state outside of _modify_linked_states context."
            raise ReflexRuntimeError(msg)
        linked_root_state = None
        # Get the newly linked state and update pointers/delta for subsequent events.
        if token not in self._held_locks:
            async with self._held_locks_lock:
                # Re-check under the lock so concurrent callers don't enter
                # modify_state twice for the same token.
                if token not in self._held_locks:
                    linked_root_state = await self._exit_stack.enter_async_context(
                        get_state_manager().modify_state(
                            BaseStateToken(ident=token, cls=type(self))
                        )
                    )
                    self._held_locks.setdefault(token, {})
                    # Set client_token on the linked root so that subsequent get_state
                    # calls when directly modifying a linked token will load the
                    # associated instance.
                    if linked_root_state.router.session.client_token != token:
                        import dataclasses as dc

                        linked_root_state.router = dc.replace(
                            linked_root_state.router,
                            session=dc.replace(
                                linked_root_state.router.session, client_token=token
                            ),
                        )
        if linked_root_state is None:
            # The token's modify_state context was already entered earlier in
            # this event; fetch the root via get_state instead.
            linked_root_state = await get_state_manager().get_state(
                BaseStateToken(ident=token, cls=type(self))
            )
        linked_state = await linked_root_state.get_state(type(self))
        if not isinstance(linked_state, SharedState):
            msg = f"Linked state for token {token} is not a SharedState."
            raise ReflexRuntimeError(msg)
        # Avoid unnecessary dirtiness of shared state when there are no changes.
        if type(self) not in self._held_locks[token]:
            self._held_locks[token][type(self)] = linked_state
        if self.router.session.client_token not in linked_state._linked_from:
            linked_state._linked_from.add(self.router.session.client_token)
        if linked_state._linked_to != token:
            linked_state._linked_to = token
        await self._exit_stack.enter_async_context(
            _patch_state(
                original_state=self,
                linked_state=linked_state,
                full_delta=full_delta,
            )
        )
        return linked_state

    def _held_locks_linked_states(self) -> list["SharedState"]:
        """Get all linked states currently held by this state.

        Returns:
            The list of linked states currently held.
        """
        if self._held_locks is None:
            return []
        return [
            linked_state
            for linked_state_cls_to_instance in self._held_locks.values()
            for linked_state in linked_state_cls_to_instance.values()
            if isinstance(linked_state, SharedState)
        ]

    @contextlib.asynccontextmanager
    async def _modify_linked_states(
        self, previous_dirty_vars: dict[str, set[str]] | None = None
    ) -> AsyncIterator[None]:
        """Take lock, fetch all linked states, and patch them into the current state tree.

        If previous_dirty_vars is NOT provided, then any dirty vars after
        exiting the context will be applied to all other clients linked to this
        state's linked token.

        Args:
            previous_dirty_vars: When applying linked state changes to other
                tokens, provide mapping of state full_name to set of dirty vars.

        Yields:
            None.

        Raises:
            ReflexRuntimeError: If contexts are nested or there are no links.
        """
        if self._exit_stack is not None:
            msg = "Cannot nest _modify_linked_states contexts."
            raise ReflexRuntimeError(msg)
        if self._reflex_internal_links is None:
            msg = "No linked states to modify."
            raise ReflexRuntimeError(msg)
        self._exit_stack = contextlib.AsyncExitStack()
        self._held_locks = {}
        current_dirty_vars: dict[str, set[str]] = {}
        affected_tokens: set[str] = set()
        try:
            # Enter the exit stack BEFORE patching anything. The contexts pushed
            # by `_internal_patch_linked_state` (modify_state locks and tree
            # patches) are then unwound even if patching a later linked state
            # raises, instead of leaving those locks held forever.
            async with self._exit_stack:
                # Go through all linked states and patch them in if they are present in the tree
                for linked_state_name, linked_token in self._reflex_internal_links.items():
                    linked_state_cls: type[SharedState] = (
                        self.get_root_state().get_class_substate(  # pyright: ignore[reportAssignmentType]
                            linked_state_name
                        )
                    )
                    try:
                        original_state = self._get_state_from_cache(linked_state_cls)
                    except ValueError:
                        # This state wasn't required for processing the event.
                        continue
                    linked_state = await original_state._internal_patch_linked_state(
                        linked_token
                    )
                    if (
                        previous_dirty_vars
                        and (dv := previous_dirty_vars.get(linked_state_name)) is not None
                    ):
                        # Re-apply dirty vars passed in when propagating another
                        # client's changes.
                        linked_state.dirty_vars.update(dv)
                        linked_state._mark_dirty()
                yield None
            # Collect dirty vars and other affected clients that need to be updated.
            # This runs after the exit stack has unwound (patches restored and
            # modify_state contexts exited).
            # NOTE(review): SharedState declares `_previous_dirty_vars` with a
            # default of set(), so these `is not None` checks appear to always
            # pass for held SharedState instances (unlike the truthiness checks
            # in `_collect_shared_token_updates`). Confirm this is intended.
            for linked_state in self._held_locks_linked_states():
                if linked_state._previous_dirty_vars is not None:
                    current_dirty_vars[linked_state.get_full_name()] = set(
                        linked_state._previous_dirty_vars
                    )
                if (
                    linked_state._get_was_touched()
                    or linked_state._previous_dirty_vars is not None
                ):
                    affected_tokens.update(
                        token
                        for token in linked_state._linked_from
                        if token != self.router.session.client_token
                    )
            # When modifying a shared token directly (empty _reflex_internal_links),
            # the held locks will be empty. Check SharedState substates for linked
            # clients that need to be notified.
            if not self._reflex_internal_links:
                shared_state_base_internal = await self.get_state(
                    SharedStateBaseInternal
                )
                if not isinstance(
                    shared_state_base_internal, SharedStateBaseInternal
                ):
                    msg = "Expected SharedStateBaseInternal in substates."
                    raise ReflexRuntimeError(msg)
                # Collect affected tokens from all potentially linked states.
                shared_state_base_internal._collect_shared_token_updates(
                    affected_tokens, current_dirty_vars
                )
        finally:
            self._exit_stack = None
            # Only propagate dirty vars when we are not already propagating from another state.
            if previous_dirty_vars is None:
                _do_update_other_tokens(
                    affected_tokens=affected_tokens,
                    previous_dirty_vars=current_dirty_vars,
                    state_type=type(self),
                )

    def _collect_shared_token_updates(
        self,
        affected_tokens: set[str],
        current_dirty_vars: dict[str, set[str]],
    ) -> None:
        """Recursively collect dirty vars and linked clients from SharedState substates.

        When a shared state is modified directly by its shared token (rather than
        through a private client token), the held locks are empty so the normal
        collection loop in `_modify_linked_states` finds nothing. This method
        recursively checks SharedState substates for linked clients that need to
        be notified.

        Args:
            affected_tokens: Set to update with client tokens that need notification.
            current_dirty_vars: Dict to update with dirty var mappings per state.
        """
        for substate in self.substates.values():
            if not isinstance(substate, SharedState):
                continue
            if substate._linked_from:
                if substate._previous_dirty_vars:
                    current_dirty_vars[substate.get_full_name()] = set(
                        substate._previous_dirty_vars
                    )
                if substate._get_was_touched() or substate._previous_dirty_vars:
                    affected_tokens.update(substate._linked_from)
            substate._collect_shared_token_updates(affected_tokens, current_dirty_vars)
class SharedState(SharedStateBaseInternal, mixin=True):
    """Mixin for defining new shared states.

    Subclasses can be linked to a shared token (via ``_link_to``) so that
    their values are shared with every other client linked to the same token.
    """

    # Client tokens that have linked to this shared instance.
    _linked_from: set[str] = set()
    # The shared token this instance is linked to ("" when not linked).
    _linked_to: str = ""
    # Dirty var names captured by `_clean`, used to update other clients.
    _previous_dirty_vars: set[str] = set()

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """Register a concrete shared state subclass.

        Forces the subclass to be a non-mixin state and makes sure the root
        state's ``_reflex_internal_links`` backend var defaults to an empty
        dict. When the root is ``State``, it also adds
        ``SharedStateBaseInternal`` to the root's always-dirty substates.

        Args:
            **kwargs: The kwargs to pass to the init_subclass method.
        """
        # Subclasses of the SharedState mixin are concrete states themselves.
        cls._mixin = False
        super().__init_subclass__(**{**kwargs, "mixin": False})

        root = cls.get_root_state()
        if root.backend_vars["_reflex_internal_links"] is None:
            root.backend_vars["_reflex_internal_links"] = {}
        if root is not State:
            return
        # Always fetch SharedStateBaseInternal to access
        # `_modify_linked_states` without having to use `.get_state()` which
        # pulls in all linked states and substates which may not actually be
        # accessed for this event.
        root._always_dirty_substates.add(SharedStateBaseInternal.get_name())