387 lines
14 KiB
Python
387 lines
14 KiB
Python
"""A state manager that stores states on disk."""
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import time
|
|
from collections.abc import AsyncIterator
|
|
from hashlib import md5
|
|
from pathlib import Path
|
|
from typing import Any, Generic, cast
|
|
|
|
from reflex_base.environment import environment
|
|
from typing_extensions import Unpack, override
|
|
|
|
from reflex.istate.manager import (
|
|
StateManager,
|
|
StateModificationContext,
|
|
_default_token_expiration,
|
|
)
|
|
from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken
|
|
from reflex.state import BaseState
|
|
from reflex.utils import console, path_ops, prerequisites
|
|
from reflex.utils.misc import run_in_thread
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class QueueItem(Generic[TOKEN_TYPE]):
|
|
"""An item in the write queue."""
|
|
|
|
token: StateToken[TOKEN_TYPE]
|
|
state: TOKEN_TYPE
|
|
timestamp: float
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class StateManagerDisk(StateManager):
|
|
"""A state manager that stores states on disk."""
|
|
|
|
# The mapping of client ids to states.
|
|
states: dict[str, Any] = dataclasses.field(default_factory=dict)
|
|
|
|
# The mutex ensures the dict of mutexes is updated exclusively
|
|
_state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
|
|
|
|
# The dict of mutexes for each client
|
|
_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
|
|
default_factory=dict,
|
|
init=False,
|
|
)
|
|
|
|
# The token expiration time (s).
|
|
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
|
|
|
|
# Last time a token was touched.
|
|
_token_last_touched: dict[str, float] = dataclasses.field(
|
|
default_factory=dict,
|
|
init=False,
|
|
)
|
|
|
|
# Pending writes
|
|
_write_queue: dict[StateToken, QueueItem] = dataclasses.field(
|
|
default_factory=dict,
|
|
init=False,
|
|
)
|
|
_write_queue_task: asyncio.Task | None = None
|
|
_write_debounce_seconds: float = dataclasses.field(
|
|
default=environment.REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS.get()
|
|
)
|
|
|
|
def __post_init__(self):
|
|
"""Create a new state manager."""
|
|
path_ops.mkdir(self.states_directory)
|
|
|
|
self._purge_expired_states()
|
|
|
|
@functools.cached_property
|
|
def states_directory(self) -> Path:
|
|
"""Get the states directory.
|
|
|
|
Returns:
|
|
The states directory.
|
|
"""
|
|
return prerequisites.get_states_dir()
|
|
|
|
def _purge_expired_states(self):
|
|
"""Purge expired states from the disk."""
|
|
for path in path_ops.ls(self.states_directory):
|
|
# check path is a pickle file
|
|
if path.suffix != ".pkl":
|
|
continue
|
|
|
|
# load last edited field from file
|
|
last_edited = path.stat().st_mtime
|
|
|
|
# check if the file is older than the token expiration time
|
|
if time.time() - last_edited > self.token_expiration:
|
|
# remove the file
|
|
path.unlink()
|
|
|
|
def token_path(self, token: StateToken) -> Path:
|
|
"""Get the path for a token.
|
|
|
|
Args:
|
|
token: The token to get the path for.
|
|
|
|
Returns:
|
|
The path for the token.
|
|
"""
|
|
return (
|
|
self.states_directory / f"{md5(str(token).encode()).hexdigest()}.pkl"
|
|
).absolute()
|
|
|
|
async def load_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE | None:
|
|
"""Load a state object based on the provided token.
|
|
|
|
Args:
|
|
token: The token used to identify the state object.
|
|
|
|
Returns:
|
|
The loaded state object or None.
|
|
"""
|
|
token_path = self.token_path(token)
|
|
|
|
if token_path.exists():
|
|
try:
|
|
with token_path.open(mode="rb") as file:
|
|
return token.deserialize(fp=file)
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
async def populate_substates(
|
|
self, token: BaseStateToken, state: BaseState, root_state: BaseState
|
|
):
|
|
"""Populate the substates of a state object.
|
|
|
|
Args:
|
|
token: The token used to identify the state object.
|
|
state: The state object to populate.
|
|
root_state: The root state object.
|
|
"""
|
|
for substate in state.get_substates():
|
|
substate_token = token.with_cls(substate)
|
|
|
|
fresh_instance = await root_state.get_state(substate)
|
|
instance = await self.load_state(substate_token)
|
|
if instance is not None:
|
|
# Ensure all substates exist, even if they weren't serialized previously.
|
|
instance.substates = fresh_instance.substates
|
|
else:
|
|
instance = fresh_instance
|
|
state.substates[substate.get_name()] = instance
|
|
instance.parent_state = state
|
|
|
|
await self.populate_substates(token, instance, root_state)
|
|
|
|
@override
|
|
async def get_state(
|
|
self,
|
|
token: StateToken[TOKEN_TYPE],
|
|
) -> TOKEN_TYPE:
|
|
"""Get the state for a token.
|
|
|
|
Args:
|
|
token: The token to get the state for.
|
|
|
|
Returns:
|
|
The state for the token.
|
|
"""
|
|
token = self._coerce_token(token)
|
|
root_state = self.states.get(token.cache_key)
|
|
self._token_last_touched[token.cache_key] = time.time()
|
|
if root_state is not None:
|
|
# Retrieved state from memory.
|
|
return root_state
|
|
|
|
# Deserialize root state from disk.
|
|
if isinstance(token, BaseStateToken):
|
|
# Find the root state
|
|
root_state_cls = token.cls.get_root_state()
|
|
root_state = await self.load_state(token.with_cls(root_state_cls))
|
|
# Create a new root state tree with all substates instantiated.
|
|
fresh_root_state = root_state_cls(_reflex_internal_init=True)
|
|
if root_state is None:
|
|
root_state = fresh_root_state
|
|
elif not isinstance(root_state, BaseState):
|
|
msg = "Deserialized state is not an instance of BaseState, cannot populate substates."
|
|
raise TypeError(msg)
|
|
else:
|
|
# Ensure all substates exist, even if they were not serialized previously.
|
|
root_state.substates = fresh_root_state.substates
|
|
await self.populate_substates(token, root_state, root_state)
|
|
self.states[token.cache_key] = root_state
|
|
return cast(TOKEN_TYPE, root_state)
|
|
# For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls.
|
|
state = await self.load_state(token)
|
|
if state is None:
|
|
state = token.cls()
|
|
self.states[token.cache_key] = state
|
|
return cast(TOKEN_TYPE, state)
|
|
|
|
async def set_state_for_substate(
|
|
self, token: StateToken[TOKEN_TYPE], substate: TOKEN_TYPE
|
|
):
|
|
"""Set the state for a substate.
|
|
|
|
Args:
|
|
token: The token used to identify the state object.
|
|
substate: The substate to set.
|
|
"""
|
|
substate_token = token.with_cls(type(substate))
|
|
|
|
if token.get_and_reset_touched_state(substate):
|
|
pickle_state = token.serialize(substate)
|
|
if pickle_state:
|
|
if not self.states_directory.exists():
|
|
self.states_directory.mkdir(parents=True, exist_ok=True)
|
|
await run_in_thread(
|
|
lambda: self.token_path(substate_token).write_bytes(pickle_state),
|
|
)
|
|
|
|
if isinstance(token, BaseStateToken) and isinstance(substate, BaseState):
|
|
for substate_substate in substate.substates.values():
|
|
await self.set_state_for_substate(token, substate_substate)
|
|
|
|
async def _process_write_queue_delay(self):
|
|
"""Wait for the debounce period before processing the write queue again."""
|
|
now = time.time()
|
|
if self._write_queue:
|
|
# There are still items in the queue, schedule another run.
|
|
next_write_in = max(
|
|
0,
|
|
min(
|
|
self._write_debounce_seconds - (now - item.timestamp)
|
|
for item in self._write_queue.values()
|
|
),
|
|
)
|
|
await asyncio.sleep(next_write_in)
|
|
elif self._write_debounce_seconds > 0:
|
|
# No items left, wait a bit before checking again.
|
|
await asyncio.sleep(self._write_debounce_seconds)
|
|
else:
|
|
# Debounce is disabled, so sleep until the next token expiration.
|
|
oldest_token_last_touch = min(
|
|
self._token_last_touched.values(), default=now
|
|
)
|
|
next_expiration_in = self.token_expiration - (now - oldest_token_last_touch)
|
|
await asyncio.sleep(next_expiration_in)
|
|
|
|
async def _process_write_queue(self):
|
|
"""Long running task that checks for states to write to disk.
|
|
|
|
Raises:
|
|
asyncio.CancelledError: When the task is cancelled.
|
|
"""
|
|
while True:
|
|
try:
|
|
now = time.time()
|
|
# sort the _write_queue by oldest timestamp and exclude items younger than debounce time
|
|
items_to_write = sorted(
|
|
(
|
|
item
|
|
for item in self._write_queue.values()
|
|
if now - item.timestamp >= self._write_debounce_seconds
|
|
),
|
|
key=lambda item: item.timestamp,
|
|
)
|
|
for item in items_to_write:
|
|
token = item.token
|
|
await self.set_state_for_substate(
|
|
token, self._write_queue.pop(token).state
|
|
)
|
|
# Check for expired states to purge.
|
|
for cache_key, last_touched in list(self._token_last_touched.items()):
|
|
if now - last_touched > self.token_expiration:
|
|
self._token_last_touched.pop(cache_key)
|
|
self.states.pop(cache_key, None)
|
|
await run_in_thread(self._purge_expired_states)
|
|
await self._process_write_queue_delay()
|
|
except asyncio.CancelledError: # noqa: PERF203
|
|
await self._flush_write_queue()
|
|
raise
|
|
except Exception as e:
|
|
console.error(f"Error processing write queue: {e!r}")
|
|
if e.args == ("cannot schedule new futures after shutdown",):
|
|
# Event loop is shutdown, nothing else we can really do...
|
|
return
|
|
await self._process_write_queue_delay()
|
|
|
|
async def _flush_write_queue(self):
|
|
"""Flush any remaining items in the write queue to disk."""
|
|
outstanding_items = list(self._write_queue.values())
|
|
n_outstanding_items = len(outstanding_items)
|
|
self._write_queue.clear()
|
|
# When the task is cancelled, write all remaining items to disk.
|
|
console.debug(
|
|
f"StateManagerDisk._flush_write_queue: writing {n_outstanding_items} remaining items to disk"
|
|
)
|
|
for item in outstanding_items:
|
|
await self.set_state_for_substate(
|
|
item.token,
|
|
item.state,
|
|
)
|
|
console.debug(
|
|
f"StateManagerDisk._flush_write_queue: Finished writing {n_outstanding_items} items"
|
|
)
|
|
|
|
async def _schedule_process_write_queue(self):
|
|
"""Schedule the write queue processing task if not already running."""
|
|
if self._write_queue_task is None or self._write_queue_task.done():
|
|
async with self._state_manager_lock:
|
|
if self._write_queue_task is None or self._write_queue_task.done():
|
|
self._write_queue_task = asyncio.create_task(
|
|
self._process_write_queue(),
|
|
name="StateManagerDisk|WriteQueueProcessor",
|
|
)
|
|
await asyncio.sleep(0) # Yield to allow the task to start.
|
|
|
|
@override
|
|
async def set_state(
|
|
self,
|
|
token: StateToken[TOKEN_TYPE],
|
|
state: TOKEN_TYPE,
|
|
**context: Unpack[StateModificationContext],
|
|
):
|
|
"""Set the state for a token.
|
|
|
|
Args:
|
|
token: The token to set the state for.
|
|
state: The state to set.
|
|
context: The state modification context.
|
|
"""
|
|
token = self._coerce_token(token)
|
|
if self._write_debounce_seconds > 0:
|
|
# Deferred write to reduce disk IO overhead.
|
|
if token not in self._write_queue:
|
|
self._write_queue[token] = QueueItem(
|
|
token=token,
|
|
state=state,
|
|
timestamp=time.time(),
|
|
)
|
|
else:
|
|
# Immediate write to disk.
|
|
await self.set_state_for_substate(token, state)
|
|
# Ensure the processing task is scheduled to handle expirations and any deferred writes.
|
|
await self._schedule_process_write_queue()
|
|
|
|
@override
|
|
@contextlib.asynccontextmanager
|
|
async def modify_state(
|
|
self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext]
|
|
) -> AsyncIterator[TOKEN_TYPE]:
|
|
"""Modify the state for a token while holding exclusive lock.
|
|
|
|
Args:
|
|
token: The token to modify the state for.
|
|
context: The state modification context.
|
|
|
|
Yields:
|
|
The state for the token.
|
|
"""
|
|
token = self._coerce_token(token)
|
|
# Disk state manager ignores the substate suffix and always returns the top-level state.
|
|
lock_key = token.lock_key
|
|
if lock_key not in self._states_locks:
|
|
async with self._state_manager_lock:
|
|
if lock_key not in self._states_locks:
|
|
self._states_locks[lock_key] = asyncio.Lock()
|
|
|
|
async with self._states_locks[lock_key]:
|
|
state = await self.get_state(token)
|
|
yield state
|
|
await self.set_state(token, state, **context)
|
|
|
|
async def close(self):
|
|
"""Close the state manager, flushing any pending writes to disk."""
|
|
async with self._state_manager_lock:
|
|
if self._write_queue_task:
|
|
self._write_queue_task.cancel()
|
|
with contextlib.suppress(asyncio.CancelledError):
|
|
await self._write_queue_task
|
|
self._write_queue_task = None
|
|
# Dump unlocked locks.
|
|
for token, lock in tuple(self._states_locks.items()):
|
|
if not lock.locked():
|
|
self._states_locks.pop(token)
|