"""A state manager that stores states in redis."""

import asyncio
import contextlib
import dataclasses
import inspect
import os
import sys
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, TypedDict, cast

from redis import ResponseError
from redis.asyncio import Redis
from reflex_base.config import get_config
from reflex_base.environment import environment
from reflex_base.utils import console
from reflex_base.utils.exceptions import (
    InvalidLockWarningThresholdError,
    LockExpiredError,
    StateSchemaMismatchError,
)
from typing_extensions import Unpack, override

from reflex.istate.manager import (
    StateManager,
    StateModificationContext,
    _default_token_expiration,
)
from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken
from reflex.state import BaseState
from reflex.utils.tasks import ensure_task


def _default_lock_expiration() -> int:
    """Read the redis lock expiration from the app config.

    Returns:
        The configured lock expiration time (ms).
    """
    config = get_config()
    return config.redis_lock_expiration


def _default_lock_warning_threshold() -> int:
    """Read the redis lock warning threshold from the app config.

    Returns:
        The configured lock warning threshold (ms).
    """
    config = get_config()
    return config.redis_lock_warning_threshold


def _default_oplock_hold_time_ms() -> int:
    """Determine the default opportunistic lock hold time.

    Falls back to half the lock expiration when the environment variable is
    unset (or set to a falsy value such as 0).

    Returns:
        The opportunistic lock hold time (ms).
    """
    configured = environment.REFLEX_OPLOCK_HOLD_TIME_MS.get()
    if configured:
        return configured
    return _default_lock_expiration() // 2


# The lock waiter task should subscribe to lock channel updates within this period.
LOCK_SUBSCRIBE_TASK_TIMEOUT = 2  # seconds

# Prefix for debug log lines from this state manager (identifies the process).
SMR = f"[SMR:{os.getpid()}]"
# Reference point for relative timestamps in debug log lines.
start = time.monotonic()


class RedisPubSubMessage(TypedDict):
    """A Redis Pub/Sub message."""

    type: str
    pattern: bytes | None
    channel: bytes
    data: bytes | int


class OplockFound(Exception):  # noqa: N818
    """Indicates that an opportunistic lock was found."""


@dataclasses.dataclass
class StateManagerRedis(StateManager):
    """A state manager that stores states in redis."""

    # The redis client to use.
    redis: Redis

    # The token expiration time (s).
    token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)

    # The maximum time to hold a lock (ms).
    lock_expiration: int = dataclasses.field(default_factory=_default_lock_expiration)

    # The maximum time to hold a lock (ms) before warning.
    lock_warning_threshold: int = dataclasses.field(
        default_factory=_default_lock_warning_threshold
    )

    # How long to opportunistically hold the redis lock in milliseconds (must be less than the token expiration).
    oplock_hold_time_ms: int = dataclasses.field(
        default_factory=_default_oplock_hold_time_ms
    )

    # The keyspace subscription string when redis is waiting for lock to be released.
    _redis_notify_keyspace_events: str = dataclasses.field(
        default="K"  # Enable keyspace notifications (target a particular key)
        "$"  # For String commands (like setting keys)
        "s"  # For Set commands (SADD, SREM, etc)
        "g"  # For generic commands (DEL, EXPIRE, etc)
        "x"  # For expired events
        "e"  # For evicted events (i.e. maxmemory exceeded)
    )

    # These events indicate that a lock is no longer held.
    _redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
        default_factory=lambda: {
            b"del",
            b"expired",
            b"evicted",
        }
    )

    # Whether keyspace notifications have been enabled.
    _redis_notify_keyspace_events_enabled: bool = dataclasses.field(default=False)

    # The mutex ensures the dict of mutexes is updated exclusively.
    # BUGFIX: use default_factory so each instance gets its own Lock; a plain
    # `default=asyncio.Lock()` creates one Lock at class-definition time that is
    # shared by every StateManagerRedis instance (dataclasses only reject
    # list/dict/set mutable defaults, so this was not caught).
    _state_manager_lock: asyncio.Lock = dataclasses.field(
        default_factory=asyncio.Lock, init=False
    )

    # Whether to opportunistically hold locks for fast in-memory access.
    _oplock_enabled: bool = dataclasses.field(
        default_factory=environment.REFLEX_OPLOCK_ENABLED.get, init=False
    )

    # Cached states (lock_key -> cached root state instance).
    _cached_states: dict[str, Any] = dataclasses.field(default_factory=dict, init=False)
    # Per-lock_key mutexes guarding the cached state entries.
    _cached_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
        default_factory=dict, init=False
    )

    # Local Leases (token -> flush task)
    _local_leases: dict[str, asyncio.Task] = dataclasses.field(
        default_factory=dict, init=False
    )

    # The unique ID for this state manager, the domain for _local_leases.
    _instance_id: str = dataclasses.field(default_factory=lambda: str(uuid.uuid4()))

    # Lock waiters for redis per-token lock.
    _lock_waiters: dict[bytes, list[asyncio.Event]] = dataclasses.field(
        default_factory=dict,
        init=False,
    )
    # Set while the pubsub subscriber task is actively subscribed.
    _lock_updates_subscribed: asyncio.Event = dataclasses.field(
        default_factory=asyncio.Event,
        init=False,
    )
    # Background task listening for lock release/contention notifications.
    _lock_task: asyncio.Task | None = dataclasses.field(default=None, init=False)

    # Whether debug prints are enabled.
    # BUGFIX: use default_factory (consistent with _oplock_enabled) so the
    # environment is read at instance creation rather than once at
    # class-definition time.
    _debug_enabled: bool = dataclasses.field(
        default_factory=environment.REFLEX_STATE_MANAGER_REDIS_DEBUG.get,
        init=False,
    )

    def __post_init__(self):
        """Validate the lock warning threshold and oplock hold time.

        Raises:
            InvalidLockWarningThresholdError: If the lock warning threshold or
                the oplock hold time is not less than the lock expiration time.
        """
        if self.lock_warning_threshold >= (lock_expiration := self.lock_expiration):
            msg = f"The lock warning threshold({self.lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
            raise InvalidLockWarningThresholdError(msg)
        if self._oplock_enabled and self.oplock_hold_time_ms >= lock_expiration:
            # NOTE: reuses InvalidLockWarningThresholdError (callers catch this
            # type) even though the message concerns the oplock hold time.
            msg = f"The opportunistic lock hold time({self.oplock_hold_time_ms}) must be less than the lock expiration time({lock_expiration})."
            raise InvalidLockWarningThresholdError(msg)
        with contextlib.suppress(RuntimeError):
            asyncio.get_running_loop()  # Check if we're in an event loop.
            self._ensure_lock_task()

    def _get_required_state_classes(
        self,
        target_state_cls: type[BaseState],
        subclasses: bool = False,
        required_state_classes: set[type[BaseState]] | None = None,
    ) -> set[type[BaseState]]:
        """Recursively determine which states are required to fetch the target state.

        This will always include potentially dirty substates that depend on vars
        in the target_state_cls.

        Args:
            target_state_cls: The target state class being fetched.
            subclasses: Whether to include subclasses of the target state.
            required_state_classes: Recursive argument tracking state classes that have already been seen.

        Returns:
            The set of state classes required to fetch the target state.
        """
        if required_state_classes is None:
            required_state_classes = set()
        # Get the substates if requested.
        if subclasses:
            for substate in target_state_cls.get_substates():
                self._get_required_state_classes(
                    substate,
                    subclasses=True,
                    required_state_classes=required_state_classes,
                )
        if target_state_cls in required_state_classes:
            return required_state_classes
        required_state_classes.add(target_state_cls)

        # Get dependent substates.
        for pd_substates in target_state_cls._get_potentially_dirty_states():
            self._get_required_state_classes(
                pd_substates,
                subclasses=False,
                required_state_classes=required_state_classes,
            )

        # Get the parent state if it exists.
        if parent_state := target_state_cls.get_parent_state():
            self._get_required_state_classes(
                parent_state,
                subclasses=False,
                required_state_classes=required_state_classes,
            )
        return required_state_classes

    def _get_populated_states(
        self,
        target_state: BaseState,
        populated_states: dict[str, BaseState] | None = None,
    ) -> dict[str, BaseState]:
        """Recursively determine which states from target_state are already fetched.

        Args:
            target_state: The state to check for populated states.
            populated_states: Recursive argument tracking states seen in previous calls.

        Returns:
            A dictionary of state full name to state instance.
        """
        if populated_states is None:
            populated_states = {}
        if target_state.get_full_name() in populated_states:
            return populated_states
        populated_states[target_state.get_full_name()] = target_state
        for substate in target_state.substates.values():
            self._get_populated_states(substate, populated_states=populated_states)
        if target_state.parent_state is not None:
            self._get_populated_states(
                target_state.parent_state, populated_states=populated_states
            )
        return populated_states

    @override
    async def get_state(
        self,
        token: StateToken[TOKEN_TYPE],
        top_level: bool = True,
        for_state_instance: BaseState | None = None,
    ) -> TOKEN_TYPE:
        """Get the state for a token.

        Args:
            token: The token to get the state for.
            top_level: If true, return the top-level root state.
            for_state_instance: If provided, attach the requested states to this existing state tree.

        Returns:
            The state for the token.

        Raises:
            RuntimeError: when the parent state for a requested state was not fetched.
        """
        token = self._coerce_token(token)
        if not isinstance(token, BaseStateToken):
            # Non-BaseState token: simple single-key fetch.
            redis_data = await self.redis.get(str(token))
            if redis_data is not None:
                return token.deserialize(data=redis_data)
            return token.cls()

        requested_state_cls = token.cls

        # Determine which states we already have.
        flat_state_tree: dict[str, BaseState] = (
            self._get_populated_states(for_state_instance) if for_state_instance else {}
        )
        # Determine which states from the tree need to be fetched.
        required_state_classes = sorted(
            self._get_required_state_classes(requested_state_cls, subclasses=True)
            - {type(s) for s in flat_state_tree.values()},
            key=lambda x: x.get_full_name(),
        )

        # Fetch all missing keys in a single round-trip via pipeline.
        redis_pipeline = self.redis.pipeline()
        for state_cls in required_state_classes:
            redis_pipeline.get(str(token.with_cls(state_cls)))

        for state_cls, redis_state in zip(
            required_state_classes,
            await redis_pipeline.execute(),
            strict=False,
        ):
            state = None
            if redis_state is not None:
                # Deserialize the substate.
                with contextlib.suppress(StateSchemaMismatchError):
                    state = BaseState._deserialize(data=redis_state)
            if state is None:
                # Key didn't exist or schema mismatch so create a new instance for this token.
                state = state_cls(
                    init_substates=False,
                    _reflex_internal_init=True,
                )
            flat_state_tree[state.get_full_name()] = state
            if state.get_parent_state() is not None:
                parent_state_name, _dot, state_name = state.get_full_name().rpartition(
                    "."
                )
                parent_state = flat_state_tree.get(parent_state_name)
                if parent_state is None:
                    msg = (
                        f"Parent state for {state.get_full_name()} was not found "
                        "in the state tree, but should have already been fetched. "
                        "This is a bug"
                    )
                    raise RuntimeError(msg)
                parent_state.substates[state_name] = state
                state.parent_state = parent_state

        # To retain compatibility with previous implementation, by default, we return
        # the top-level state which should always be fetched or already cached.
        if top_level:
            return cast(
                TOKEN_TYPE,
                flat_state_tree[requested_state_cls.get_root_state().get_full_name()],
            )
        return cast(TOKEN_TYPE, flat_state_tree[requested_state_cls.get_full_name()])

    @override
    async def set_state(
        self,
        token: StateToken[TOKEN_TYPE],
        state: TOKEN_TYPE,
        *,
        lock_id: bytes | None = None,
        **context: Unpack[StateModificationContext],
    ):
        """Set the state for a token.

        Args:
            token: The token to set the state for.
            state: The state to set.
            lock_id: If provided, the lock must be held with this value to set the state.
            context: The event context.

        Raises:
            LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
            RuntimeError: If the state instance doesn't match the state name in the token.
        """
        token = self._coerce_token(token)
        # Check that we're holding the lock.
        if (
            lock_id is not None
            and (existing_lock_id := await self.redis.get(self._lock_key(token)))
            != lock_id
        ):
            msg = (
                f"Lock expired for token {token} while processing. Consider increasing "
                f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                "or use `@rx.event(background=True)` decorator for long-running tasks. "
                f"Current lock id: {existing_lock_id!r}, expected lock id: {lock_id!r}."
                + (
                    f" Happened in event: {event.name}"
                    if (event := context.get("event")) is not None
                    else ""
                )
            )
            raise LockExpiredError(msg)

        if not isinstance(token, BaseStateToken):
            # Non-BaseState token: simple single-key write.
            pickle_state = token.serialize(state)
            if pickle_state:
                await self.redis.set(str(token), pickle_state, ex=self.token_expiration)
            return

        base_state = cast(BaseState, state)
        lock_key = token.lock_key
        # Warn if the lock was held longer than the threshold (skip when this
        # process holds a local lease, since the lease intentionally holds the lock).
        if lock_id is not None and lock_key not in self._local_leases:
            time_taken = (
                self.lock_expiration - (await self.redis.pttl(self._lock_key(token)))
            ) / 1000
            if time_taken > self.lock_warning_threshold / 1000:
                console.warn(
                    f"Lock for token {token} was held too long {time_taken=}s, "
                    f"use `@rx.event(background=True)` decorator for long-running tasks."
                    + (
                        f" Happened in event: {event.name}"
                        if (event := context.get("event")) is not None
                        else ""
                    ),
                    dedupe=True,
                )

        # Recursively set_state on all known substates.
        tasks = [
            asyncio.create_task(
                self.set_state(
                    token,
                    substate,
                    lock_id=lock_id,
                    **context,
                ),
                name=f"reflex_set_state|{lock_key}|{substate.get_full_name()}",
            )
            for substate in base_state.substates.values()
        ]
        # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
        if base_state._get_was_touched():
            pickle_state = base_state._serialize()
            if pickle_state:
                await self.redis.set(
                    str(token.with_cls(type(base_state))),
                    pickle_state,
                    ex=self.token_expiration,
                )
        # Wait for substates to be persisted.
        for t in tasks:
            await t

    @contextlib.asynccontextmanager
    async def _try_modify_state(
        self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext]
    ) -> AsyncIterator[TOKEN_TYPE | None]:
        """Modify the state for a token while holding exclusive lock.

        Args:
            token: The token to modify the state for.
            context: The state modification context.

        Yields:
            The state for the token or None if we couldn't get the lock.
        """
        event_name = event.name if (event := context.get("event")) is not None else None
        if not self._oplock_enabled:
            # OpLock is disabled, get a fresh lock, write, and release.
            async with self._lock(token, event_name=event_name) as lock_id:
                state = await self.get_state(token)
                yield state
                await self.set_state(token, state, lock_id=lock_id, **context)
            return

        # Opportunistically reuse existing lock.
        async with self._get_state_cached(token) as cached_state:
            if cached_state is not None:
                yield cached_state
                self._notify_next_waiter(self._lock_key(token))
                return

        # Opportunistic locking is enabled, so try to hold the lock across multiple calls.
        lock_key = token.lock_key
        lock_held_ctx = contextlib.AsyncExitStack()
        try:
            lock_id = await lock_held_ctx.enter_async_context(
                self._lock(token, event_name=event_name)
            )
        except OplockFound:
            # While waiting for the lock, another process has acquired it, but we can piggy back.
            pass
        else:
            # Do not create a lease break task when multiple instances are waiting.
            if (
                not await self._get_local_lease(lock_key)
                and await self._n_lock_contenders(self._lock_key(token)) > 0
            ):
                if self._debug_enabled:
                    console.debug(
                        f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} has contention, not leasing"
                    )
                async with lock_held_ctx:
                    state = await self.get_state(token)
                    yield state
                    await self.set_state(token, state, lock_id=lock_id, **context)
                return
            # Create the lease break task since we got the lock.
            if (
                new_lease_task := await self._create_lease_break_task(
                    token, lock_id, cleanup_ctx=lock_held_ctx, **context
                )
            ) is (
                current_lease_task := await self._get_local_lease(lock_key)
            ) and new_lease_task is not None:
                if self._debug_enabled:
                    console.debug(
                        f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} obtained lock {lock_id.decode()}."
                    )
            elif current_lease_task is None:
                # Check if we still have the redis lock, then just try to send this one update and release it.
                await self._try_extend_lock(self._lock_key(token))
                if await self.redis.get(self._lock_key(token)) == lock_id:
                    if self._debug_enabled:
                        console.debug(
                            f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} holding lock {lock_id.decode()}, {new_lease_task=} already exited, doing single update..."
                        )
                    async with lock_held_ctx:
                        state = await self.get_state(token)
                        yield state
                        await self.set_state(token, state, lock_id=lock_id, **context)
                    return
                elif self._debug_enabled:
                    console.debug(
                        f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lock {lock_id.decode()} expired while waiting for lease task to exit..."
                    )
        # Have to retry getting the state, but now it's probably cached.
        yield None

    @override
    @contextlib.asynccontextmanager
    async def modify_state(
        self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext]
    ) -> AsyncIterator[TOKEN_TYPE]:
        """Modify the state for a token while holding exclusive lock.

        Args:
            token: The token to modify the state for.
            context: The state modification context.

        Yields:
            The state for the token.
        """
        token = self._coerce_token(token)
        # Retry until _try_modify_state succeeds (it yields None when the lock
        # could not be obtained this round).
        while True:
            async with self._try_modify_state(token, **context) as state_instance:
                if state_instance is not None:
                    yield cast(TOKEN_TYPE, state_instance)
                    return

    @contextlib.asynccontextmanager
    async def _get_state_cached(
        self, token: StateToken[TOKEN_TYPE]
    ) -> AsyncIterator[TOKEN_TYPE | None]:
        """Get the cached state for a token, while holding the local lease lock.

        Args:
            token: The token to get the cached state for.

        Yields:
            The cached state for the token, or None if not cached/uncachable.
        """
        lock_key = token.lock_key
        # Opportunistically reuse existing lock.
        if (
            lock_key in self._local_leases
            and (state_lock := self._cached_states_locks.get(lock_key)) is not None
        ):
            async with state_lock:
                if await self._get_local_lease(lock_key) is not None:
                    if (cached_state := self._cached_states.get(lock_key)) is not None:
                        if isinstance(token, BaseStateToken):
                            # Make sure we have the substate cached (or fetch it from redis).
                            state_path = token.cls.get_full_name()
                            try:
                                substate = cached_state.get_substate(
                                    state_path.split(".")
                                )
                                if len(substate.substates) != len(
                                    type(substate).get_substates()
                                ):
                                    # If the substate is missing substates, we need to refetch it.
                                    raise ValueError  # noqa: TRY301
                            except ValueError:
                                await self.get_state(
                                    token, for_state_instance=cached_state
                                )
                        yield cast(TOKEN_TYPE, cached_state)
                        return
                    elif self._debug_enabled:
                        console.debug(
                            f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease task found, lock held, but no cached state"
                        )
        elif self._debug_enabled:
            console.debug(
                f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} no active lease task found"
            )
        yield None

    def _notify_next_waiter(self, key: bytes):
        """Notify the next waiter for a given lock key.

        Args:
            key: The redis lock key.
        """
        # Notify the next un-notified waiter, if any.
        for event in self._lock_waiters.get(key, ()):
            if not event.is_set():
                event.set()
                if self._debug_enabled:
                    console.debug(
                        f"{SMR} [{time.monotonic() - start:.3f}] {key.decode()} NOTIFY 1 / {len(self._lock_waiters[key])} waiters {event=}"
                    )
                break

    async def _create_lease_break_task(
        self,
        token: StateToken[TOKEN_TYPE],
        lock_id: bytes,
        cleanup_ctx: contextlib.AsyncExitStack,
        **context: Unpack[StateModificationContext],
    ) -> asyncio.Task | None:
        """Create a background task to break the local lease after lock expiration.

        Args:
            token: The token to create the lease break task for.
            lock_id: The ID of the lock.
            cleanup_ctx: Enter this context while running the lease break task.
            context: The state modification context.

        Returns:
            The lease break task, or None when there is contention.
        """
        self._ensure_lock_task()
        lock_key = token.lock_key

        async def do_flush() -> None:
            # Flush the cached state for this lease to redis under the state lock.
            if (state_lock := self._cached_states_locks.get(lock_key)) is None:
                # If we lost the lock, we can't write the state, something went wrong.
                console.warn(
                    f"State lock for {lock_key} missing while finalizing lease."
                )
                return
            async with state_lock:
                # Write the state to redis while no one else can modify the cached copy.
                state = self._cached_states.pop(lock_key, None)
                try:
                    if state:
                        if self._debug_enabled:
                            console.debug(
                                f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} flushing state"
                            )
                        await self.set_state(token, state, lock_id=lock_id, **context)
                finally:
                    if (current_lease := self._local_leases.get(lock_key)) is task:
                        self._local_leases.pop(lock_key, None)
                        # TODO: clean up the cached states locks periodically
                    elif self._debug_enabled:
                        console.debug(
                            f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}."
                        )

        async def lease_breaker():
            # Sleep for the oplock hold time, then flush the cached state and
            # release the redis lock (via cleanup_ctx). Cancellation triggers an
            # early flush; the flush itself is shielded from cancellation.
            cancelled_error: asyncio.CancelledError | None = None
            async with cleanup_ctx:
                lease_break_time = self.oplock_hold_time_ms / 1000
                if self._debug_enabled:
                    console.debug(
                        f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} started, sleeping for {lease_break_time}s"
                    )
                try:
                    await asyncio.sleep(lease_break_time)
                except asyncio.CancelledError as err:
                    cancelled_error = err
                    # We got cancelled so if someone is holding the lock,
                    # extend the timeout so they get the full time to complete.
                    # BUGFIX: use .get() — the `is not None` check implies the
                    # entry may be missing, but subscripting would raise KeyError.
                    if (
                        state_lock := self._cached_states_locks.get(lock_key)
                    ) is not None and state_lock.locked():
                        await self._try_extend_lock(self._lock_key(token))
                try:
                    # Shield the flush from cancellation to ensure it always runs to completion.
                    await asyncio.shield(do_flush())
                except Exception as e:
                    # Propagate exception to the main loop, since we have nowhere to catch it.
                    if not isinstance(e, asyncio.CancelledError):
                        asyncio.get_running_loop().call_exception_handler({
                            "message": "Exception in Redis State Manager lease breaker",
                            "exception": e,
                        })
                    raise
                finally:
                    # Re-raise any cancellation error after cleaning up.
                    if cancelled_error is not None:
                        raise cancelled_error

        if (state_lock := self._cached_states_locks.get(lock_key)) is not None:
            # We have an existing lock, so lets see if we have an existing lease to cancel.
            async with state_lock:
                if (existing_task := self._local_leases.get(lock_key)) is not None:
                    # There's already a lease break task, so cancel it to clear it out.
                    existing_task.cancel()
            # Await outside state_lock, since do_flush re-acquires it.
            if existing_task is not None:
                with contextlib.suppress(asyncio.CancelledError):
                    await existing_task
        # Now we might need to create a new lock.
        if (state_lock := self._cached_states_locks.get(lock_key)) is None:
            async with self._state_manager_lock:
                if (state_lock := self._cached_states_locks.get(lock_key)) is None:
                    state_lock = self._cached_states_locks[lock_key] = asyncio.Lock()
        async with state_lock:
            # Create the task now if one didn't sneak past us.
            if (
                lock_key not in self._local_leases
                and await self._n_lock_contenders(self._lock_key(token)) == 0
            ):
                self._local_leases[lock_key] = task = asyncio.create_task(
                    lease_breaker(),
                    name=f"reflex_lease_breaker|{lock_key}|{lock_id.decode()}",
                )
                # Fetch the requested state into the cache.
                self._cached_states[lock_key] = await self.get_state(token)
                return task
        return None

    @staticmethod
    def _lock_key(token: StateToken[Any]) -> bytes:
        """Get the redis key for a token's lock.

        Args:
            token: The token to get the lock key for.

        Returns:
            The redis lock key for the token.
        """
        return f"{token.lock_key}_lock".encode()

    async def _try_extend_lock(self, lock_key: bytes) -> bool | None:
        """Extends the current lock for another lock_expiration period.

        Does not change ownership of the lock!

        Args:
            lock_key: The redis key for the lock.

        Returns:
            True if the lock was extended.
        """
        # xx=True: only refresh the TTL when the key already exists.
        return await self.redis.pexpire(lock_key, self.lock_expiration, xx=True)

    async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
        """Try to get a redis lock for a token.

        Args:
            lock_key: The redis key for the lock.
            lock_id: The ID of the lock.

        Returns:
            True if the lock was obtained.
        """
        return await self.redis.set(
            lock_key,
            lock_id,
            px=self.lock_expiration,
            nx=True,  # only set if it doesn't exist
        )

    async def _handle_lock_release(self, message: RedisPubSubMessage) -> None:
        """Handle a lock release message from redis.

        Args:
            message: The redis message.
        """
        if message["data"] in self._redis_keyspace_lock_release_events:
            # Keyspace channel is "__keyspace@<db>__:<key>"; recover the key.
            key = message["channel"].split(b":", 1)[1]
            if key in self._lock_waiters:
                self._notify_next_waiter(key)

    async def _handle_lock_contention(self, message: RedisPubSubMessage) -> None:
        """Handle a lock contention message from redis.

        Args:
            message: The redis message.
        """
        # Opportunistic lock contention notification.
        token = message["channel"].rsplit(b":", 1)[1][: -len(b"_lock_waiters")].decode()
        if (
            message["data"] == b"sadd"
            and (state_lock := self._cached_states_locks.get(token)) is not None
        ):
            # Cancel the lease break task to force a lock reacquisition.
            async with state_lock:
                if (lease_task := await self._get_local_lease(token)) is not None:
                    lease_task.cancel()
                    if self._debug_enabled:
                        console.debug(
                            f"{SMR} [{time.monotonic() - start:.3f}] {token} OPLOCK CONTEND - lease break task cancelled {lease_task=}"
                        )

    async def _subscribe_lock_updates(self):
        """Subscribe to redis keyspace notifications for lock updates."""
        await self._enable_keyspace_notifications()
        redis_db = self.redis.get_connection_kwargs().get("db", 0)
        lock_key_pattern = f"__keyspace@{redis_db}__:*_lock"
        lock_waiter_key_pattern = f"__keyspace@{redis_db}__:*_lock_waiters"
        handlers = {
            lock_key_pattern: self._handle_lock_release,
            lock_waiter_key_pattern: self._handle_lock_contention,
        }
        async with self.redis.pubsub() as pubsub:
            await pubsub.psubscribe(**handlers)  # pyright: ignore[reportArgumentType]
            self._lock_updates_subscribed.set()
            try:
                # Handlers run via psubscribe callbacks; just drain the stream.
                async for _ in pubsub.listen():
                    pass
            finally:
                self._lock_updates_subscribed.clear()

    def _ensure_lock_task(self) -> None:
        """Ensure the lock updates subscriber task is running."""
        ensure_task(
            owner=self,
            task_attribute="_lock_task",
            coro_function=self._subscribe_lock_updates,
            suppress_exceptions=[Exception],
        )

    async def _ensure_lock_task_subscribed(self, timeout: float | None = None) -> None:
        """Ensure the lock updates subscriber task is running and subscribed to avoid missing notifications.

        Args:
            timeout: How long to wait for the subscriber to be subscribed before raising an error.
                If None, defaults to min(LOCK_SUBSCRIBE_TASK_TIMEOUT, lock_expiration).

        Raises:
            TimeoutError: If the lock updates subscriber task fails to subscribe in time.
        """
        if timeout is None:
            timeout = min(
                LOCK_SUBSCRIBE_TASK_TIMEOUT,
                max(self.lock_expiration / 1000, 0),
            )
        # Make sure lock waiter task is running.
        self._ensure_lock_task()
        # Make sure the lock waiter is subscribed to avoid missing notifications.
        await asyncio.wait_for(
            self._lock_updates_subscribed.wait(),
            timeout=timeout,
        )

    async def _enable_keyspace_notifications(self):
        """Enable keyspace notifications for the redis server.

        Raises:
            ResponseError: when the keyspace config cannot be set.
        """
        if self._redis_notify_keyspace_events_enabled:
            return
        try:
            await self.redis.config_set(
                "notify-keyspace-events",
                self._redis_notify_keyspace_events,
            )
        except ResponseError:
            # Some redis servers only allow out-of-band configuration, so ignore errors here.
            if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
                raise
        self._redis_notify_keyspace_events_enabled = True

    @contextlib.asynccontextmanager
    async def _lock_waiter(self, lock_key: bytes) -> AsyncIterator[asyncio.Event]:
        """Create a lock waiter for a given lock key.

        Args:
            lock_key: The redis key for the lock.

        Yields:
            The event that will be set when the lock is released.
        """
        lock_released_events = self._lock_waiters.get(lock_key)
        if lock_released_events is None:
            # Create a new or get existing set of waiters in manager lock.
            async with self._state_manager_lock:
                lock_released_events = self._lock_waiters.setdefault(lock_key, [])
        lock_released_event = asyncio.Event()
        lock_released_events.append(lock_released_event)
        try:
            yield lock_released_event
        finally:
            # Set before removing to signal that we don't care about it anymore.
            lock_released_event.set()
            # Clean up the waiter
            lock_released_events.remove(lock_released_event)
            if not lock_released_events:
                # Try to clean up the whole set if empty.
                async with self._state_manager_lock:
                    if not lock_released_events:
                        self._lock_waiters.pop(lock_key, None)

    def _n_lock_waiters(self, lock_key: bytes) -> int:
        """Get the number of local waiters for a given lock key.

        Args:
            lock_key: The redis key for the lock.

        Returns:
            The number of waiters for the lock key on this instance.
        """
        lock_released_events = self._lock_waiters.get(lock_key)
        if lock_released_events is None:
            return 0
        return len(lock_released_events)

    async def _n_lock_contenders(self, lock_key: bytes) -> int:
        """Get the number of contenders for a given lock key.

        Args:
            lock_key: The redis key for the lock.

        Returns:
            The number of contenders for the lock key across all instances.
        """
        res = self.redis.scard(lock_key + b"_waiters")
        # Some client variants return the value directly instead of a coroutine.
        if inspect.isawaitable(res):
            res = await res
        return res

    @contextlib.asynccontextmanager
    async def _request_lock_release(
        self, lock_key: bytes, lock_id: bytes
    ) -> AsyncIterator[None]:
        """Request the release of a redis lock.

        Args:
            lock_key: The redis key for the lock.
            lock_id: The ID of the lock.
        """
        if not self._oplock_enabled:
            yield
            return
        lock_waiter_key = lock_key + b"_waiters"
        pipeline = self.redis.pipeline()
        # Signal intention to request oplock for this process.
        pipeline.sadd(lock_waiter_key, self._instance_id)
        pipeline.pexpire(lock_waiter_key, self.lock_expiration)
        await pipeline.execute()
        try:
            yield  # Waiting for redis/oplock to be acquired.
        finally:
            res = self.redis.srem(lock_waiter_key, self._instance_id)
            if inspect.isawaitable(res):
                await res

    async def _get_local_lease(
        self, token: str, raise_when_found: bool = False
    ) -> asyncio.Task | None:
        """Check if there is a local lease for a token.

        Args:
            token: The token to check for a local lease.
            raise_when_found: If true, raise OplockFound when a local lease is found.

        Returns:
            The local lease task if found, None otherwise.

        Raises:
            OplockFound: If there is a local lease for the token and raise_when_found is True.
        """
        if (
            self._oplock_enabled
            and (lease_task := self._local_leases.get(token)) is not None
            and not lease_task.done()
            and not lease_task.cancelled()
            # Task.cancelling() is only available on Python 3.11+.
            and (sys.version_info < (3, 11) or not lease_task.cancelling())
        ):
            if raise_when_found:
                raise OplockFound
            return lease_task
        return None

    async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
        """Wait for a redis lock to be released via pubsub.

        Coroutine will not return until the lock is obtained.

        It _might_ raise OplockFound if another coroutine in this process did get the
        lock and Oplock is enabled.

        Args:
            lock_key: The redis key for the lock.
            lock_id: The ID of the lock.
        """
        token = lock_key.decode().rsplit("_lock", 1)[0]
        if (
            # If there's not a line, try to get the lock immediately.
            not self._n_lock_waiters(lock_key)
            and await self._try_get_lock(lock_key, lock_id)
        ):
            if self._debug_enabled:
                console.debug(
                    f"{SMR} [{time.monotonic() - start:.3f}] {lock_key.decode()} instaque by {lock_id.decode()}"
                )
            return
        # Make sure lock waiter task is running.
        with contextlib.suppress(TimeoutError, asyncio.TimeoutError):
            await self._ensure_lock_task_subscribed()
        async with (
            self._lock_waiter(lock_key) as lock_released_event,
            self._request_lock_release(lock_key, lock_id),
        ):
            while (
                self._n_lock_waiters(lock_key) > 1 and not lock_released_event.is_set()
            ) or (
                # We didn't get the lock so wait for the next release event.
                lock_released_event.clear() is None
                and not await self._try_get_lock(lock_key, lock_id)
            ):
                # Check if this process got a lease, then we can abandon waiting on the redis lock.
                await self._get_local_lease(token, raise_when_found=True)
                if self._debug_enabled:
                    console.debug(
                        f"{SMR} [{time.monotonic() - start:.3f}] {lock_key.decode()} waiting for {lock_id.decode()}"
                    )
                try:
                    await asyncio.wait_for(
                        lock_released_event.wait(),
                        timeout=max(self.lock_expiration / 1000, 0),
                    )
                except (TimeoutError, asyncio.TimeoutError):
                    if self._debug_enabled:
                        console.debug(
                            f"{SMR} [{time.monotonic() - start:.3f}] {lock_key.decode()} wait timeout for {lock_id.decode()}"
                        )
                    lock_released_event.set()  # to re-check the lock
        if self._debug_enabled:
            console.debug(
                f"{SMR} [{time.monotonic() - start:.3f}] {lock_key.decode()} acquired by {lock_id.decode()} event={lock_released_event}"
            )

    @contextlib.asynccontextmanager
    async def _lock(
        self, token: StateToken[Any], event_name: str | None = None
    ) -> AsyncIterator[bytes]:
        """Obtain a redis lock for a token.

        Args:
            token: The token to obtain a lock for.
            event_name: The name of the event associated with the lock.

        Yields:
            The ID of the lock (to be passed to set_state).

        Raises:
            LockExpiredError: If the lock has expired while processing the event.
        """
        lock_key = self._lock_key(token)
        lock_id = (
            f"{event_name}:{uuid.uuid4().hex}" if event_name else uuid.uuid4().hex
        ).encode()

        await self._wait_lock(lock_key, lock_id)
        state_is_locked = True
        try:
            yield lock_id
        except LockExpiredError:
            state_is_locked = False
            raise
        finally:
            if state_is_locked:
                # only delete our lock
                deleted_lock_id = await self.redis.getdel(lock_key)
                if deleted_lock_id == lock_id:
                    if self._debug_enabled:
                        console.debug(
                            f"{SMR} [{time.monotonic() - start:.3f}] {lock_key.decode()} released by {lock_id.decode()}"
                        )
                elif deleted_lock_id is not None:
                    # This can happen if the caller never tried to `set_state` before the lock expired and is a pretty bad bug.
                    console.warn(
                        f"{lock_key.decode()} was released by {lock_id.decode()}, but it belonged to {deleted_lock_id.decode()}. This is a bug."
                    )
            # To avoid race when a waiter is registered after the del message is processed.
            self._notify_next_waiter(lock_key)

    async def close(self):
        """Explicitly close the redis connection and connection_pool.

        It is necessary in testing scenarios to close between asyncio test cases
        to avoid having lingering redis connections associated with event loops
        that will be closed (each test case uses its own event loop).

        Note: Connections will be automatically reopened when needed.
        """
        try:
            # Kill the lock task first so waiters don't get lock notifications.
            if self._lock_task is not None:
                self._lock_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self._lock_task
                self._lock_task = None
            # Then cancel all outstanding leases and write the cached states to redis.
            for lease_task in self._local_leases.values():
                lease_task.cancel()
            await asyncio.gather(*self._local_leases.values(), return_exceptions=True)
        finally:
            await self.redis.aclose(close_connection_pool=True)