"""A state manager that stores states on disk."""

import asyncio
import contextlib
import dataclasses
import functools
import time
from collections.abc import AsyncIterator
from hashlib import md5
from pathlib import Path
from typing import Any, Generic, cast

from reflex_base.environment import environment
from typing_extensions import Unpack, override

from reflex.istate.manager import (
    StateManager,
    StateModificationContext,
    _default_token_expiration,
)
from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken
from reflex.state import BaseState
from reflex.utils import console, path_ops, prerequisites
from reflex.utils.misc import run_in_thread


@dataclasses.dataclass(frozen=True)
class QueueItem(Generic[TOKEN_TYPE]):
    """An item in the write queue."""

    # The token identifying the state to write.
    token: StateToken[TOKEN_TYPE]
    # The state object that will be serialized when the item is processed.
    state: TOKEN_TYPE
    # The time (seconds since epoch) at which the item was queued.
    timestamp: float


@dataclasses.dataclass
class StateManagerDisk(StateManager):
    """A state manager that stores states on disk.

    States are cached in memory in ``states`` and written to ``.pkl`` files in
    the states directory, either immediately or after a debounce period via a
    background write-queue task.
    """

    # The mapping of client ids to states.
    states: dict[str, Any] = dataclasses.field(default_factory=dict)

    # The mutex ensures the dict of mutexes is updated exclusively.
    # Built per instance with a default_factory: a plain ``default=asyncio.Lock()``
    # would create a single lock when the class is defined and share it between
    # every manager instance.
    _state_manager_lock: asyncio.Lock = dataclasses.field(default_factory=asyncio.Lock)

    # The dict of mutexes for each client
    _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
        default_factory=dict,
        init=False,
    )

    # The token expiration time (s).
    token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)

    # Last time a token was touched.
    _token_last_touched: dict[str, float] = dataclasses.field(
        default_factory=dict,
        init=False,
    )

    # Pending writes
    _write_queue: dict[StateToken, QueueItem] = dataclasses.field(
        default_factory=dict,
        init=False,
    )
    # The background task processing the write queue (None until first scheduled).
    _write_queue_task: asyncio.Task | None = None
    # Minimum age (s) of a queued item before it is written; <= 0 means write immediately.
    _write_debounce_seconds: float = dataclasses.field(
        default=environment.REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS.get()
    )

    def __post_init__(self):
        """Create a new state manager.

        Ensures the states directory exists and removes expired state files.
        """
        path_ops.mkdir(self.states_directory)

        self._purge_expired_states()

    @functools.cached_property
    def states_directory(self) -> Path:
        """Get the states directory.

        Returns:
            The states directory.
        """
        return prerequisites.get_states_dir()

    def _purge_expired_states(self):
        """Purge expired states from the disk.

        A ``.pkl`` file is expired when its modification time is older than
        ``token_expiration`` seconds. Files that disappear while the directory
        is being scanned are skipped instead of raising.
        """
        now = time.time()
        for path in path_ops.ls(self.states_directory):
            # check path is a pickle file
            if path.suffix != ".pkl":
                continue

            # load last edited field from file
            try:
                last_edited = path.stat().st_mtime
            except FileNotFoundError:
                # The file was removed after the directory was listed.
                continue

            # check if the file is older than the token expiration time
            if now - last_edited > self.token_expiration:
                # remove the file; tolerate a concurrent removal
                path.unlink(missing_ok=True)

    def token_path(self, token: StateToken) -> Path:
        """Get the path for a token.

        The file name is the md5 hex digest of ``str(token)`` with a ``.pkl``
        suffix, located in the states directory.

        Args:
            token: The token to get the path for.

        Returns:
            The path for the token.
        """
        return (
            self.states_directory / f"{md5(str(token).encode()).hexdigest()}.pkl"
        ).absolute()

    async def load_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE | None:
        """Load a state object based on the provided token.

        Loading is best-effort: when the file is missing or cannot be
        deserialized, None is returned so the caller can create a fresh state.

        Args:
            token: The token used to identify the state object.

        Returns:
            The loaded state object or None.
        """
        token_path = self.token_path(token)

        if token_path.exists():
            # NOTE(review): the files use a .pkl suffix, so deserialization is
            # presumably pickle-based -- the states directory should only be
            # writable by trusted users. Confirm against token.deserialize.
            try:
                with token_path.open(mode="rb") as file:
                    return token.deserialize(fp=file)
            except Exception as err:
                # Deliberately non-fatal, but leave a trace instead of
                # swallowing the error silently.
                console.debug(
                    f"StateManagerDisk.load_state: failed to load {token_path}: {err!r}"
                )
        return None

    async def populate_substates(
        self, token: BaseStateToken, state: BaseState, root_state: BaseState
    ):
        """Populate the substates of a state object.

        Recursively walks the substate classes of ``state``; each substate is
        loaded from disk when available, otherwise the instance obtained from
        ``root_state`` is used.

        Args:
            token: The token used to identify the state object.
            state: The state object to populate.
            root_state: The root state object.
        """
        for substate in state.get_substates():
            substate_token = token.with_cls(substate)

            fresh_instance = await root_state.get_state(substate)
            instance = await self.load_state(substate_token)
            if instance is not None:
                # Ensure all substates exist, even if they weren't serialized previously.
                instance.substates = fresh_instance.substates
            else:
                instance = fresh_instance
            state.substates[substate.get_name()] = instance
            instance.parent_state = state

            await self.populate_substates(token, instance, root_state)

    @override
    async def get_state(
        self,
        token: StateToken[TOKEN_TYPE],
    ) -> TOKEN_TYPE:
        """Get the state for a token.

        The in-memory cache is consulted first; on a miss the state is loaded
        from disk, or created when no stored state can be loaded. Every call
        refreshes the token's last-touched time.

        Args:
            token: The token to get the state for.

        Returns:
            The state for the token.

        Raises:
            TypeError: If the root state loaded for a BaseStateToken is not a
                BaseState instance.
        """
        token = self._coerce_token(token)
        root_state = self.states.get(token.cache_key)
        self._token_last_touched[token.cache_key] = time.time()
        if root_state is not None:
            # Retrieved state from memory.
            return root_state

        # Deserialize root state from disk.
        if isinstance(token, BaseStateToken):
            # Find the root state
            root_state_cls = token.cls.get_root_state()
            root_state = await self.load_state(token.with_cls(root_state_cls))
            # Create a new root state tree with all substates instantiated.
            fresh_root_state = root_state_cls(_reflex_internal_init=True)
            if root_state is None:
                root_state = fresh_root_state
            elif not isinstance(root_state, BaseState):
                msg = "Deserialized state is not an instance of BaseState, cannot populate substates."
                raise TypeError(msg)
            else:
                # Ensure all substates exist, even if they were not serialized previously.
                root_state.substates = fresh_root_state.substates
            await self.populate_substates(token, root_state, root_state)
            self.states[token.cache_key] = root_state
            return cast(TOKEN_TYPE, root_state)

        # For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls.
        state = await self.load_state(token)
        if state is None:
            state = token.cls()
        self.states[token.cache_key] = state
        return cast(TOKEN_TYPE, state)

    async def set_state_for_substate(
        self, token: StateToken[TOKEN_TYPE], substate: TOKEN_TYPE
    ):
        """Set the state for a substate.

        Writes the substate to disk when the token reports it as touched, then
        recurses into its child substates for BaseState trees.

        Args:
            token: The token used to identify the state object.
            substate: The substate to set.
        """
        substate_token = token.with_cls(type(substate))

        if token.get_and_reset_touched_state(substate):
            pickle_state = token.serialize(substate)
            if pickle_state:
                if not self.states_directory.exists():
                    self.states_directory.mkdir(parents=True, exist_ok=True)
                # Perform the blocking file write off the event loop.
                await run_in_thread(
                    lambda: self.token_path(substate_token).write_bytes(pickle_state),
                )

        if isinstance(token, BaseStateToken) and isinstance(substate, BaseState):
            for substate_substate in substate.substates.values():
                await self.set_state_for_substate(token, substate_substate)

    async def _process_write_queue_delay(self):
        """Wait for the debounce period before processing the write queue again."""
        now = time.time()
        if self._write_queue:
            # There are still items in the queue, schedule another run.
            # Sleep until the earliest queued item reaches the debounce age.
            next_write_in = max(
                0,
                min(
                    self._write_debounce_seconds - (now - item.timestamp)
                    for item in self._write_queue.values()
                ),
            )
            await asyncio.sleep(next_write_in)
        elif self._write_debounce_seconds > 0:
            # No items left, wait a bit before checking again.
            await asyncio.sleep(self._write_debounce_seconds)
        else:
            # Debounce is disabled, so sleep until the next token expiration.
            oldest_token_last_touch = min(
                self._token_last_touched.values(), default=now
            )
            next_expiration_in = self.token_expiration - (now - oldest_token_last_touch)
            await asyncio.sleep(next_expiration_in)

    async def _process_write_queue(self):
        """Long running task that checks for states to write to disk.

        Each iteration writes queued items that have reached the debounce age,
        evicts expired tokens from memory, purges expired files from disk and
        then sleeps until the next run is due.

        Raises:
            asyncio.CancelledError: When the task is cancelled.
        """
        while True:
            try:
                now = time.time()
                # sort the _write_queue by oldest timestamp and exclude items younger than debounce time
                items_to_write = sorted(
                    (
                        item
                        for item in self._write_queue.values()
                        if now - item.timestamp >= self._write_debounce_seconds
                    ),
                    key=lambda item: item.timestamp,
                )
                for item in items_to_write:
                    token = item.token
                    # The entry is popped before the write is awaited, so a
                    # set_state call made during the write queues a new item.
                    await self.set_state_for_substate(
                        token, self._write_queue.pop(token).state
                    )
                # Check for expired states to purge.
                for cache_key, last_touched in list(self._token_last_touched.items()):
                    if now - last_touched > self.token_expiration:
                        self._token_last_touched.pop(cache_key)
                        self.states.pop(cache_key, None)
                await run_in_thread(self._purge_expired_states)
                await self._process_write_queue_delay()
            except asyncio.CancelledError:  # noqa: PERF203
                # Persist whatever is still queued before propagating the cancellation.
                await self._flush_write_queue()
                raise
            except Exception as e:
                console.error(f"Error processing write queue: {e!r}")
                if e.args == ("cannot schedule new futures after shutdown",):
                    # Event loop is shutdown, nothing else we can really do...
                    return
                await self._process_write_queue_delay()

    async def _flush_write_queue(self):
        """Flush any remaining items in the write queue to disk."""
        # Snapshot and clear the queue up front, then write the snapshot.
        outstanding_items = list(self._write_queue.values())
        n_outstanding_items = len(outstanding_items)
        self._write_queue.clear()
        # When the task is cancelled, write all remaining items to disk.
        console.debug(
            f"StateManagerDisk._flush_write_queue: writing {n_outstanding_items} remaining items to disk"
        )
        for item in outstanding_items:
            await self.set_state_for_substate(
                item.token,
                item.state,
            )
        console.debug(
            f"StateManagerDisk._flush_write_queue: Finished writing {n_outstanding_items} items"
        )

    async def _schedule_process_write_queue(self):
        """Schedule the write queue processing task if not already running."""
        if self._write_queue_task is None or self._write_queue_task.done():
            async with self._state_manager_lock:
                # Re-check under the lock: another coroutine may have created
                # the task while this one was waiting.
                if self._write_queue_task is None or self._write_queue_task.done():
                    self._write_queue_task = asyncio.create_task(
                        self._process_write_queue(),
                        name="StateManagerDisk|WriteQueueProcessor",
                    )
                    await asyncio.sleep(0)  # Yield to allow the task to start.

    @override
    async def set_state(
        self,
        token: StateToken[TOKEN_TYPE],
        state: TOKEN_TYPE,
        **context: Unpack[StateModificationContext],
    ):
        """Set the state for a token.

        With a positive debounce the write is queued for the background task;
        otherwise the state is written to disk before returning.

        Args:
            token: The token to set the state for.
            state: The state to set.
            context: The state modification context.
        """
        token = self._coerce_token(token)
        if self._write_debounce_seconds > 0:
            # Deferred write to reduce disk IO overhead.
            # An already-queued token keeps its existing item and timestamp.
            if token not in self._write_queue:
                self._write_queue[token] = QueueItem(
                    token=token,
                    state=state,
                    timestamp=time.time(),
                )
        else:
            # Immediate write to disk.
            await self.set_state_for_substate(token, state)
        # Ensure the processing task is scheduled to handle expirations and any deferred writes.
        await self._schedule_process_write_queue()

    @override
    @contextlib.asynccontextmanager
    async def modify_state(
        self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext]
    ) -> AsyncIterator[TOKEN_TYPE]:
        """Modify the state for a token while holding exclusive lock.

        The state is saved with ``set_state`` only when the managed block
        completes without raising.

        Args:
            token: The token to modify the state for.
            context: The state modification context.

        Yields:
            The state for the token.
        """
        token = self._coerce_token(token)
        # Disk state manager ignores the substate suffix and always returns the top-level state.
        lock_key = token.lock_key
        if lock_key not in self._states_locks:
            async with self._state_manager_lock:
                # Re-check under the lock so only one lock is created per key.
                if lock_key not in self._states_locks:
                    self._states_locks[lock_key] = asyncio.Lock()

        async with self._states_locks[lock_key]:
            state = await self.get_state(token)
            yield state
            await self.set_state(token, state, **context)

    async def close(self):
        """Close the state manager, flushing any pending writes to disk."""
        async with self._state_manager_lock:
            if self._write_queue_task:
                # Cancelling the task makes it flush the remaining queue
                # before it re-raises CancelledError.
                self._write_queue_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self._write_queue_task
                self._write_queue_task = None
            # Dump unlocked locks.
            for token, lock in tuple(self._states_locks.items()):
                if not lock.locked():
                    self._states_locks.pop(token)