521 lines
21 KiB
Python
521 lines
21 KiB
Python
"""Base classes for shared / linked states."""
|
|
|
|
import asyncio
|
|
import contextlib
|
|
from collections.abc import AsyncIterator
|
|
from typing import TypeVar
|
|
|
|
from reflex_base.constants import ROUTER_DATA
|
|
from reflex_base.event import Event, get_hydrate_event
|
|
from reflex_base.utils import console
|
|
from reflex_base.utils.exceptions import ReflexRuntimeError
|
|
from typing_extensions import Self
|
|
|
|
from reflex.istate.manager.token import BaseStateToken
|
|
from reflex.state import BaseState, State, _override_base_method
|
|
|
|
# Strong references to the fire-and-forget tasks created by
# _do_update_other_tokens; each task removes itself via the
# _log_update_client_errors done-callback. (The asyncio docs advise keeping a
# reference to tasks so they are not garbage-collected mid-execution.)
UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set()
# Type variable bound to SharedStateBaseInternal (forward reference by name,
# since that class is defined later in this module).
LINKED_STATE = TypeVar("LINKED_STATE", bound="SharedStateBaseInternal")
|
|
|
|
|
|
def _log_update_client_errors(task: asyncio.Task):
    """Log errors from updating other clients.

    Used as the done-callback of the tasks created by
    ``_do_update_other_tokens``. Failures are logged as warnings and the task
    is always dropped from ``UPDATE_OTHER_CLIENT_TASKS``.

    A cancelled task is not treated as an error. Calling ``task.result()`` on
    a cancelled task raises ``asyncio.CancelledError``, which is a
    ``BaseException`` and is not caught by ``except Exception``. It would
    otherwise escape this callback into the event loop's exception handler.

    Args:
        task: The asyncio task to check for errors.
    """
    try:
        if not task.cancelled():
            task.result()
    except Exception as e:
        console.warn(f"Error updating linked client: {e}")
    finally:
        # Release the strong reference taken when the task was created.
        UPDATE_OTHER_CLIENT_TASKS.discard(task)
|
|
|
|
|
|
def _do_update_other_tokens(
    affected_tokens: set[str],
    previous_dirty_vars: dict[str, set[str]],
    state_type: type[BaseState],
) -> list[asyncio.Task]:
    """Fan a shared-state update out to the other linked clients.

    Each client is updated in its own asyncio task, so the caller never waits
    on another client's state lock (avoiding deadlocks). Clients without a
    live socket are skipped.

    Args:
        affected_tokens: The tokens to update.
        previous_dirty_vars: The dirty vars to apply to other clients.
        state_type: The type of the shared state.

    Returns:
        The list of asyncio tasks created to perform the updates.
    """
    from reflex.utils.prerequisites import get_app

    app = get_app().app

    async def _apply_to_client(client_token: str):
        # Entering/exiting modify_state with previous_dirty_vars is what applies
        # the update; there is nothing to do inside the context itself.
        async with app.modify_state(
            BaseStateToken(ident=client_token, cls=state_type),
            previous_dirty_vars=previous_dirty_vars,
        ):
            pass

    created: list[asyncio.Task] = []
    for client_token in affected_tokens:
        # Only clients that still have a connected socket get an update.
        # TODO: remove disconnected clients after some time.
        if client_token in app.event_namespace._token_manager.token_to_socket:
            update_task = asyncio.create_task(_apply_to_client(client_token))
            UPDATE_OTHER_CLIENT_TASKS.add(update_task)
            update_task.add_done_callback(_log_update_client_errors)
            created.append(update_task)
    return created
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def _patch_state(
    original_state: BaseState, linked_state: BaseState, full_delta: bool = False
):
    """Temporarily swap ``linked_state`` into ``original_state``'s slot in the tree.

    While the context is active, the parent of ``original_state`` holds
    ``linked_state`` under the same substate name, and ``linked_state`` points
    back at that parent. Both pointers are restored when the context exits.

    Args:
        original_state: The original shared state.
        linked_state: The linked shared state.
        full_delta: If True, mark all Vars in linked_state dirty and resolve
            the delta from the root. This option is used when linking or unlinking
            to ensure that other computed vars in the tree pick up the newly
            linked/unlinked values.

    Raises:
        ReflexRuntimeError: If ``original_state`` has no parent (it is the root).
    """
    parent = original_state.parent_state
    if parent is None:
        msg = "Cannot patch root state as linked state."
        raise ReflexRuntimeError(msg)

    slot_name = original_state.get_name()
    parent.substates[slot_name] = linked_state
    saved_linked_parent = linked_state.parent_state
    linked_state.parent_state = parent
    try:
        if full_delta:
            for var_names in (
                linked_state.base_vars,
                linked_state._backend_vars,
                linked_state.computed_vars,
            ):
                linked_state.dirty_vars.update(var_names)
            linked_state._mark_dirty()
            # Resolve the delta from the root so other computed vars in the
            # existing tree pick up the swapped-in values before rehydrate.
            root = original_state._get_root_state()
            root.dirty_vars.update(("router", ROUTER_DATA))
            root._mark_dirty()
            await root._get_resolved_delta()
        yield
    finally:
        parent.substates[slot_name] = original_state
        linked_state.parent_state = saved_linked_parent
|
|
|
|
|
|
class SharedStateBaseInternal(State):
    """The private base state for all shared states.

    Besides being the common parent of every ``SharedState``, this state owns
    the per-event bookkeeping used while linked states are patched into a
    client's state tree (see ``_modify_linked_states``).
    """

    # Exit stack holding the modify_state contexts and _patch_state contexts
    # entered during one _modify_linked_states context; reset to None when that
    # context ends.
    _exit_stack: contextlib.AsyncExitStack | None = None
    # shared token -> {shared state class -> linked state instance} for the
    # linked states loaded during the current _modify_linked_states context.
    _held_locks: dict[str, dict[type[BaseState], BaseState]] | None = None
    # Serializes the "already holding this token?" check-and-acquire in
    # _internal_patch_linked_state.
    _held_locks_lock: asyncio.Lock = asyncio.Lock()

    def __getstate__(self):
        """Override redis serialization to remove temporary fields.

        Returns:
            The state dictionary without temporary fields.
        """
        s = super().__getstate__()
        s.pop("_previous_dirty_vars", None)
        s.pop("_exit_stack", None)
        s.pop("_held_locks", None)
        s.pop("_held_locks_lock", None)
        return s

    @_override_base_method
    def _clean(self):
        """Override BaseState._clean to track the last set of dirty vars.

        This is necessary for applying dirty vars from one event to other linked states.
        """
        # Only SharedState subclasses define _previous_dirty_vars; snapshot the
        # current dirty vars into it before delegating to BaseState._clean.
        if (
            previous_dirty_vars := getattr(self, "_previous_dirty_vars", None)
        ) is not None:
            previous_dirty_vars.clear()
            previous_dirty_vars.update(self.dirty_vars)
        super()._clean()

    @_override_base_method
    def _mark_dirty(self):
        """Override BaseState._mark_dirty to avoid marking certain vars as dirty.

        Since these internal fields are not persisted to redis, they shouldn't cause the
        state to be considered dirty either.
        """
        self.dirty_vars.discard("_previous_dirty_vars")
        self.dirty_vars.discard("_exit_stack")
        self.dirty_vars.discard("_held_locks")
        self.dirty_vars.discard("_held_locks_lock")
        # Only mark dirty if there are still dirty vars, or any substate is dirty
        if self.dirty_vars or any(
            substate.dirty_vars for substate in self.substates.values()
        ):
            super()._mark_dirty()

    def _rehydrate(self):
        """Get the events to rehydrate the state.

        Returns:
            The events to rehydrate the state (these should be returned/yielded).
        """
        return [
            Event(
                name=get_hydrate_event(self._get_root_state()),
            ),
            State.set_is_hydrated(True),
        ]

    async def _resolve_linked_state(
        self, state_cls: type["BaseState"], linked_token: str
    ) -> "BaseState":
        """Load and patch a linked state that was not pre-loaded in the tree.

        Called by State._get_state_from_redis when a state in
        _reflex_internal_links is not yet in the cache. This loads the
        private copy into the tree first, then patches the linked version
        on top of it via _internal_patch_linked_state.

        Args:
            state_cls: The shared state class to resolve.
            linked_token: The shared token the state is linked to.

        Returns:
            The linked state instance, patched into the current tree.

        Raises:
            ReflexRuntimeError: If the resolved state is not a SharedState.
        """
        root_state = self._get_root_state()

        # Load the private copy into the tree so _internal_patch_linked_state
        # has an original to swap out (needed for unlink / restore).
        original_state = await BaseState._get_state_from_redis(root_state, state_cls)

        if isinstance(original_state, SharedStateBaseInternal):
            return await original_state._internal_patch_linked_state(linked_token)

        msg = f"Failed to resolve linked state {state_cls.get_full_name()} for token {linked_token}: state does not inherit from rx.SharedState"
        raise ReflexRuntimeError(msg)

    async def _link_to(self, token: str) -> Self:
        """Link this shared state to a token.

        After linking, subsequent access to this shared state will affect the
        linked token's state, and cause changes to be propagated to all other
        clients linked to that token.

        Args:
            token: The token to link to (Cannot contain underscore characters).

        Returns:
            The newly linked state.

        Raises:
            ReflexRuntimeError: If linking fails or token is invalid.
        """
        if not token:
            msg = "Cannot link shared state to empty token."
            raise ReflexRuntimeError(msg)
        if not isinstance(self, SharedState):
            msg = "Can only link SharedState instances."
            raise ReflexRuntimeError(msg)
        if self._linked_to == token:
            return self  # already linked to this token
        # TODO: Change StateManager to accept token + class instead of combining them in a string.
        # Validate before mutating anything, so an invalid token leaves the
        # existing link (and its _linked_from bookkeeping) intact.
        if "_" in token:
            msg = f"Invalid token {token} for linking state {self.get_full_name()}, cannot use underscore (_) in the token name."
            raise ReflexRuntimeError(msg)
        if self._linked_to:
            # Switching tokens (the early return above rules out the same one):
            # disassociate from the previous linked token since unlink will not
            # be called.
            self._linked_from.discard(self.router.session.client_token)

        # Associate substate with the given link token.
        state_name = self.get_full_name()
        if self._reflex_internal_links is None:
            self._reflex_internal_links = {}
        self._reflex_internal_links[state_name] = token
        return await self._internal_patch_linked_state(token, full_delta=True)

    async def _unlink(self):
        """Unlink this shared state from its linked token.

        Returns:
            The events to rehydrate the state after unlinking (these should be returned/yielded).

        Raises:
            ReflexRuntimeError: If this is not a SharedState, or it is not
                currently linked.
        """
        from reflex.istate.manager import get_state_manager

        if not isinstance(self, SharedState):
            msg = "Can only unlink SharedState instances."
            raise ReflexRuntimeError(msg)

        state_name = self.get_full_name()
        if (
            not self._reflex_internal_links
            or state_name not in self._reflex_internal_links
        ):
            msg = f"State {state_name} is not linked and cannot be unlinked."
            raise ReflexRuntimeError(msg)

        # Break the linkage for future events.
        self._reflex_internal_links.pop(state_name)
        self._linked_from.discard(self.router.session.client_token)

        # Patch in the original state, apply updates, then rehydrate.
        private_root_state = await get_state_manager().get_state(
            BaseStateToken(
                ident=self.router.session.client_token,
                cls=type(self),
            )
        )
        private_state = await private_root_state.get_state(type(self))
        async with _patch_state(
            original_state=self,
            linked_state=private_state,
            full_delta=True,
        ):
            return self._rehydrate()

    async def _internal_patch_linked_state(
        self, token: str, full_delta: bool = False
    ) -> Self:
        """Load and replace this state with the linked state for a given token.

        Must be called inside a `_modify_linked_states` context, to ensure locks are
        released after the event is done processing.

        Args:
            token: The token of the linked state.
            full_delta: If True, mark all Vars in linked_state dirty and resolve
                delta to update cached computed vars

        Returns:
            The state that was linked into the tree.

        Raises:
            ReflexRuntimeError: If called outside of a `_modify_linked_states`
                context, or if the state loaded for the token is not a SharedState.
        """
        from reflex.istate.manager import get_state_manager

        if self._exit_stack is None or self._held_locks is None:
            msg = "Cannot link shared state outside of _modify_linked_states context."
            raise ReflexRuntimeError(msg)

        linked_root_state = None

        # Get the newly linked state and update pointers/delta for subsequent events.
        if token not in self._held_locks:
            async with self._held_locks_lock:
                # Re-check under the lock: another coroutine may have taken
                # this token while we were waiting.
                if token not in self._held_locks:
                    # The modify_state context is owned by the exit stack, so it
                    # is exited when _modify_linked_states finishes.
                    linked_root_state = await self._exit_stack.enter_async_context(
                        get_state_manager().modify_state(
                            BaseStateToken(ident=token, cls=type(self))
                        )
                    )
                    self._held_locks.setdefault(token, {})
                    # Set client_token on the linked root so that subsequent get_state
                    # calls when directly modifying a linked token will load the
                    # associated instance.
                    if linked_root_state.router.session.client_token != token:
                        import dataclasses as dc

                        linked_root_state.router = dc.replace(
                            linked_root_state.router,
                            session=dc.replace(
                                linked_root_state.router.session, client_token=token
                            ),
                        )
        if linked_root_state is None:
            # This token was already entered during the current context; fetch
            # it with get_state rather than entering modify_state again.
            linked_root_state = await get_state_manager().get_state(
                BaseStateToken(ident=token, cls=type(self))
            )
        linked_state = await linked_root_state.get_state(type(self))
        if not isinstance(linked_state, SharedState):
            msg = f"Linked state for token {token} is not a SharedState."
            raise ReflexRuntimeError(msg)
        # Avoid unnecessary dirtiness of shared state when there are no changes.
        if type(self) not in self._held_locks[token]:
            self._held_locks[token][type(self)] = linked_state
        if self.router.session.client_token not in linked_state._linked_from:
            linked_state._linked_from.add(self.router.session.client_token)
        if linked_state._linked_to != token:
            linked_state._linked_to = token
        await self._exit_stack.enter_async_context(
            _patch_state(
                original_state=self,
                linked_state=linked_state,
                full_delta=full_delta,
            )
        )
        return linked_state

    def _held_locks_linked_states(self) -> list["SharedState"]:
        """Get all linked states currently held by this state.

        Returns:
            The list of linked states currently held.
        """
        if self._held_locks is None:
            return []
        return [
            linked_state
            for linked_state_cls_to_instance in self._held_locks.values()
            for linked_state in linked_state_cls_to_instance.values()
            if isinstance(linked_state, SharedState)
        ]

    @contextlib.asynccontextmanager
    async def _modify_linked_states(
        self, previous_dirty_vars: dict[str, set[str]] | None = None
    ) -> AsyncIterator[None]:
        """Take lock, fetch all linked states, and patch them into the current state tree.

        If previous_dirty_vars is NOT provided, then any dirty vars after
        exiting the context will be applied to all other clients linked to this
        state's linked token.

        Args:
            previous_dirty_vars: When applying linked state changes to other
                tokens, provide mapping of state full_name to set of dirty vars.

        Yields:
            None.

        Raises:
            ReflexRuntimeError: If contexts are nested, if no link mapping is
                present, or if a linked/base shared state has an unexpected type.
        """
        if self._exit_stack is not None:
            msg = "Cannot nest _modify_linked_states contexts."
            raise ReflexRuntimeError(msg)
        if self._reflex_internal_links is None:
            msg = "No linked states to modify."
            raise ReflexRuntimeError(msg)
        self._exit_stack = contextlib.AsyncExitStack()
        self._held_locks = {}
        current_dirty_vars: dict[str, set[str]] = {}
        affected_tokens: set[str] = set()
        try:
            # Enter the exit stack BEFORE patching: each lock and tree patch
            # taken by _internal_patch_linked_state is pushed onto it. If loading
            # a later linked state raises, the ones already acquired are still
            # released and the tree restored, instead of being leaked when the
            # stack is dropped in `finally`.
            async with self._exit_stack:
                # Go through all linked states and patch them in if they are present in the tree
                for linked_state_name, linked_token in self._reflex_internal_links.items():
                    linked_state_cls: type[SharedState] = (
                        self.get_root_state().get_class_substate(  # pyright: ignore[reportAssignmentType]
                            linked_state_name
                        )
                    )
                    try:
                        original_state = self._get_state_from_cache(linked_state_cls)
                    except ValueError:
                        # This state wasn't required for processing the event.
                        continue
                    linked_state = await original_state._internal_patch_linked_state(
                        linked_token
                    )
                    if (
                        previous_dirty_vars
                        and (dv := previous_dirty_vars.get(linked_state_name)) is not None
                    ):
                        linked_state.dirty_vars.update(dv)
                        linked_state._mark_dirty()
                yield None

            # Collect dirty vars and other affected clients that need to be updated.
            for linked_state in self._held_locks_linked_states():
                # NOTE(review): _previous_dirty_vars is declared as a set on
                # SharedState, so these `is not None` checks are always true for
                # the SharedState instances returned above (every held state is
                # treated as changed). _collect_shared_token_updates tests
                # truthiness instead. Confirm whether unconditional notification
                # is intended.
                if linked_state._previous_dirty_vars is not None:
                    current_dirty_vars[linked_state.get_full_name()] = set(
                        linked_state._previous_dirty_vars
                    )
                if (
                    linked_state._get_was_touched()
                    or linked_state._previous_dirty_vars is not None
                ):
                    affected_tokens.update(
                        token
                        for token in linked_state._linked_from
                        if token != self.router.session.client_token
                    )
            # When modifying a shared token directly (empty _reflex_internal_links),
            # the held locks will be empty. Check SharedState substates for linked
            # clients that need to be notified.
            if not self._reflex_internal_links:
                shared_state_base_internal = await self.get_state(
                    SharedStateBaseInternal
                )
                if not isinstance(
                    shared_state_base_internal, SharedStateBaseInternal
                ):
                    msg = "Expected SharedStateBaseInternal in substates."
                    raise ReflexRuntimeError(msg)
                # Collect affected tokens from all potentially linked states.
                shared_state_base_internal._collect_shared_token_updates(
                    affected_tokens, current_dirty_vars
                )
        finally:
            self._exit_stack = None

        # Only propagate dirty vars when we are not already propagating from another state.
        if previous_dirty_vars is None:
            _do_update_other_tokens(
                affected_tokens=affected_tokens,
                previous_dirty_vars=current_dirty_vars,
                state_type=type(self),
            )

    def _collect_shared_token_updates(
        self,
        affected_tokens: set[str],
        current_dirty_vars: dict[str, set[str]],
    ) -> None:
        """Recursively collect dirty vars and linked clients from SharedState substates.

        When a shared state is modified directly by its shared token (rather than
        through a private client token), the held locks are empty, so the normal
        collection loop in _modify_linked_states finds nothing. This method
        recursively checks SharedState substates for linked clients that need to
        be notified.

        Args:
            affected_tokens: Set to update with client tokens that need notification.
            current_dirty_vars: Dict to update with dirty var mappings per state.
        """
        for substate in self.substates.values():
            if not isinstance(substate, SharedState):
                continue
            if substate._linked_from:
                if substate._previous_dirty_vars:
                    current_dirty_vars[substate.get_full_name()] = set(
                        substate._previous_dirty_vars
                    )
                if substate._get_was_touched() or substate._previous_dirty_vars:
                    affected_tokens.update(substate._linked_from)
            substate._collect_shared_token_updates(affected_tokens, current_dirty_vars)
|
|
|
|
|
|
class SharedState(SharedStateBaseInternal, mixin=True):
    """Mixin for defining new shared states.

    Subclasses are always concrete (non-mixin) states whose instances can be
    linked to a shared token.
    """

    # Client tokens currently linked to this shared instance.
    _linked_from: set[str] = set()
    # The shared token this instance is linked to ("" when unset).
    _linked_to: str = ""
    # Dirty vars snapshotted by _clean, used to propagate changes to other clients.
    _previous_dirty_vars: set[str] = set()

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """Register a concrete shared state subclass.

        Forces the subclass to be a non-mixin state, ensures the root state has
        a link mapping, and keeps SharedStateBaseInternal always loaded.

        Args:
            **kwargs: The kwargs to pass to the init_subclass method.
        """
        cls._mixin = False
        super().__init_subclass__(**{**kwargs, "mixin": False})

        root_cls = cls.get_root_state()
        if root_cls.backend_vars["_reflex_internal_links"] is None:
            root_cls.backend_vars["_reflex_internal_links"] = {}
        if root_cls is not State:
            return
        # Always fetch SharedStateBaseInternal to access
        # `_modify_linked_states` without having to use `.get_state()` which
        # pulls in all linked states and substates which may not actually be
        # accessed for this event.
        root_cls._always_dirty_substates.add(SharedStateBaseInternal.get_name())
|