eptm_dashboard/.venv/lib/python3.12/site-packages/reflex/utils/token_manager.py

464 lines
16 KiB
Python

"""Token manager for handling client token to session ID mappings."""
from __future__ import annotations
import asyncio
import dataclasses
import pickle
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable, Coroutine
from types import MappingProxyType
from typing import TYPE_CHECKING, ClassVar
from reflex.istate.manager.redis import StateManagerRedis
from reflex.state import StateUpdate
from reflex.utils import console, prerequisites
from reflex.utils.tasks import ensure_task
if TYPE_CHECKING:
from redis.asyncio import Redis
def _get_new_token() -> str:
    """Return a fresh, random client token.

    Returns:
        The canonical (hyphenated, lowercase) string form of a newly
        generated UUID4.
    """
    fresh_id = uuid.uuid4()
    return f"{fresh_id}"
@dataclasses.dataclass(frozen=True, kw_only=True)
class SocketRecord:
"""Record for a connected socket client."""
instance_id: str
sid: str
@dataclasses.dataclass(frozen=True, kw_only=True)
class LostAndFoundRecord:
"""Record for a StateUpdate for a token with its socket on another instance."""
token: str
update: StateUpdate
class TokenManager(ABC):
    """Abstract base class for managing client token to session ID mappings."""

    def __init__(self):
        """Initialize the token manager with local dictionaries."""
        # Each process has an instance_id to identify its own sockets.
        self.instance_id: str = _get_new_token()
        # Keep a mapping between client token and socket ID.
        self.token_to_socket: dict[str, SocketRecord] = {}
        # Keep a mapping between socket ID and client token.
        self.sid_to_token: dict[str, str] = {}

    @property
    def token_to_sid(self) -> MappingProxyType[str, str]:
        """Read-only compatibility property for token_to_socket mapping.

        A new dict is built on every access, so the returned view is a
        snapshot: later changes to token_to_socket are not reflected in it.

        Returns:
            The token to session ID mapping.
        """
        return MappingProxyType({
            token: sr.sid for token, sr in self.token_to_socket.items()
        })

    async def enumerate_tokens(self) -> AsyncIterator[str]:
        """Iterate over all tokens in the system.

        The keys are snapshotted before the first yield. The consumer may
        await between items, which lets other coroutines link or disconnect
        tokens. Iterating the live dict across those awaits would raise
        ``RuntimeError: dictionary changed size during iteration``.

        Yields:
            All client tokens known to the TokenManager.
        """
        for token in list(self.token_to_socket):
            yield token

    @abstractmethod
    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """

    @abstractmethod
    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """

    @classmethod
    def create(cls) -> TokenManager:
        """Factory method to create appropriate TokenManager implementation.

        Returns:
            RedisTokenManager if Redis is available, LocalTokenManager otherwise.
        """
        if prerequisites.check_redis_used():
            redis_client = prerequisites.get_redis()
            if redis_client is not None:
                return RedisTokenManager(redis_client)

        return LocalTokenManager()

    async def disconnect_all(self):
        """Disconnect all tracked tokens when the server is going down."""
        # Collect pairs from both directions before awaiting anything:
        # disconnect_token mutates both dicts. A set de-duplicates pairs that
        # appear in both mappings.
        token_sid_pairs: set[tuple[str, str]] = {
            (token, sr.sid) for token, sr in self.token_to_socket.items()
        }
        token_sid_pairs.update(
            ((token, sid) for sid, token in self.sid_to_token.items())
        )
        # Perform the disconnection logic here
        for token, sid in token_sid_pairs:
            await self.disconnect_token(token, sid)
class LocalTokenManager(TokenManager):
    """In-process token manager backed only by local dictionaries (single worker)."""

    def __init__(self):
        """Set up the local token manager; all state comes from TokenManager."""
        super().__init__()

    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Bind a client token to a Socket.IO session ID.

        If the token is already bound to a *different* sid (for example a
        duplicated browser tab reusing the same token), a fresh token is
        generated and bound to ``sid`` instead.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            The newly generated token when a duplicate was detected, else None.
        """
        existing = self.token_to_socket.get(token)
        is_duplicate = existing is not None and existing.sid != sid

        linked_token = _get_new_token() if is_duplicate else token
        self.token_to_socket[linked_token] = SocketRecord(
            instance_id=self.instance_id, sid=sid
        )
        self.sid_to_token[sid] = linked_token

        return linked_token if is_duplicate else None

    async def disconnect_token(self, token: str, sid: str) -> None:
        """Forget the token/sid mappings of a disconnecting client.

        Both removals are unconditional and tolerate missing keys.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """
        self.sid_to_token.pop(sid, None)
        self.token_to_socket.pop(token, None)
class RedisTokenManager(LocalTokenManager):
    """Token manager using Redis for distributed multi-worker support.

    Inherits local dict logic from LocalTokenManager and adds Redis layer
    for cross-worker duplicate detection.

    Each linked token is stored in Redis as a pickled SocketRecord under
    ``token_manager_socket_record_{token}``, with an expiry of
    ``config.redis_token_expiration``. Keyspace notifications on those keys
    keep the local dicts in sync. A per-instance pub/sub channel
    ("lost and found") forwards StateUpdates to the instance that owns a
    token's socket.
    """

    # Prefix of the Redis keys holding pickled SocketRecords.
    _token_socket_record_prefix: ClassVar[str] = "token_manager_socket_record_"

    def __init__(self, redis: Redis):
        """Initialize the Redis token manager.

        Args:
            redis: The Redis client instance.
        """
        # Initialize parent's local dicts
        super().__init__()
        self.redis = redis

        # Get token expiration from config (default 1 hour)
        from reflex_base.config import get_config

        config = get_config()
        self.token_expiration = config.redis_token_expiration

        # Pub/sub tasks for handling sockets owned by other instances.
        self._socket_record_task: asyncio.Task | None = None
        self._lost_and_found_task: asyncio.Task | None = None

    def _get_redis_key(self, token: str) -> str:
        """Get Redis key for token mapping.

        Args:
            token: The client token.

        Returns:
            Redis key following Reflex conventions: token_manager_socket_record_{token}
        """
        return f"{self._token_socket_record_prefix}{token}"

    def _get_token_from_redis_key(self, key: str) -> str:
        """Extract the client token from a socket record Redis key.

        This is the inverse of ``_get_redis_key``.

        Args:
            key: A decoded Redis key for a socket record.

        Returns:
            The key with the leading socket record prefix removed.
        """
        # removeprefix strips only the leading prefix. str.replace would also
        # remove any later occurrence of the prefix text.
        return key.removeprefix(self._token_socket_record_prefix)

    async def enumerate_tokens(self) -> AsyncIterator[str]:
        """Iterate over all tokens in the system.

        Walks the socket record keys with Redis SCAN, so tokens linked by any
        instance sharing this Redis are included.

        Yields:
            All client tokens known to the RedisTokenManager.
        """
        cursor = 0
        while scan_result := await self.redis.scan(
            cursor=cursor, match=self._get_redis_key("*")
        ):
            cursor = int(scan_result[0])
            for key in scan_result[1]:
                yield self._get_token_from_redis_key(key.decode())
            # SCAN returns cursor 0 once the full iteration is complete.
            if not cursor:
                break

    async def _handle_socket_record_del(
        self, token: str, expired: bool = False
    ) -> None:
        """Handle deletion of a socket record from Redis.

        Drops the token from both local dicts. If the record belonged to this
        instance and was removed by expiration, it is re-linked (and so
        re-stored in Redis) because the local socket is still considered alive.

        Args:
            token: The client token whose record was deleted.
            expired: Whether the deletion was due to expiration.
        """
        socket_record = self.token_to_socket.pop(token, None)
        if socket_record is None:
            return
        # Remove the reverse mapping regardless of which instance owns the
        # socket. _get_token_owner caches other instances' records into
        # sid_to_token too, so skipping this for remote records would leak one
        # entry per remote token. Only remove it if it still points at this
        # token, so a sid that was re-linked elsewhere is left intact.
        if self.sid_to_token.get(socket_record.sid) == token:
            self.sid_to_token.pop(socket_record.sid, None)
        if expired and socket_record.instance_id == self.instance_id:
            # Keep the record alive as long as this process is alive and not deleted.
            # NOTE(review): the return value (a replacement token if the key was
            # re-created by another instance meanwhile) is ignored here.
            await self.link_token_to_sid(token, socket_record.sid)

    async def _subscribe_socket_record_updates(self) -> None:
        """Subscribe to Redis keyspace notifications for socket record updates.

        Runs until cancelled. Deletion, expiry and eviction events drop the
        local mapping. "set" events refresh the local copy from Redis. Events
        for tokens this instance does not track are ignored.
        """
        await StateManagerRedis(redis=self.redis)._enable_keyspace_notifications()
        redis_db = self.redis.get_connection_kwargs().get("db", 0)
        async with self.redis.pubsub() as pubsub:
            await pubsub.psubscribe(
                f"__keyspace@{redis_db}__:{self._get_redis_key('*')}"
            )
            async for message in pubsub.listen():
                if message["type"] != "pmessage":
                    continue
                # Channel is b"__keyspace@<db>__:<key>"; keep only <key>.
                key = message["channel"].split(b":", 1)[1].decode()
                token = self._get_token_from_redis_key(key)
                if token not in self.token_to_socket:
                    # We don't know about this token, skip
                    continue
                # For keyspace notifications the payload is the event name.
                event = message["data"].decode()
                if event in ("del", "expired", "evicted"):
                    await self._handle_socket_record_del(
                        token,
                        expired=(event == "expired"),
                    )
                elif event == "set":
                    await self._get_token_owner(token, refresh=True)

    def _ensure_socket_record_task(self) -> None:
        """Ensure the socket record updates subscriber task is running."""
        ensure_task(
            owner=self,
            task_attribute="_socket_record_task",
            coro_function=self._subscribe_socket_record_updates,
            suppress_exceptions=[Exception],
        )

    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID with Redis-based duplicate detection.

        If a record for ``token`` already exists in Redis, written by this or
        any other instance, a new token is generated and linked instead. If the
        Redis existence check fails, this falls back to local-only linking. If
        storing the record in Redis fails, the local mapping is kept and the
        error is logged.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """
        # Fast local check first (handles reconnections)
        if (
            socket_record := self.token_to_socket.get(token)
        ) is not None and sid == socket_record.sid:
            return None  # Same token, same SID = reconnection, no Redis check needed

        # Make sure the update subscriber is running
        self._ensure_socket_record_task()

        # Check Redis for cross-worker duplicates
        redis_key = self._get_redis_key(token)
        try:
            token_exists_in_redis = await self.redis.exists(redis_key)
        except Exception as e:
            console.error(f"Redis error checking token existence: {e}")
            return await super().link_token_to_sid(token, sid)

        new_token = None
        if token_exists_in_redis:
            # Duplicate exists somewhere - generate new token
            token = new_token = _get_new_token()
            redis_key = self._get_redis_key(new_token)

        # Store in local dicts
        socket_record = self.token_to_socket[token] = SocketRecord(
            instance_id=self.instance_id, sid=sid
        )
        self.sid_to_token[sid] = token

        # Store in Redis if possible
        try:
            await self.redis.set(
                redis_key,
                pickle.dumps(socket_record),
                ex=self.token_expiration,
            )
        except Exception as e:
            console.error(f"Redis error storing token: {e}")

        # Return the new token if one was generated
        return new_token

    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        The Redis record is deleted only when the local record matches ``sid``
        and belongs to this instance. The local dicts are always cleaned.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """
        # Only clean up if we own it locally (fast ownership check)
        if (
            (socket_record := self.token_to_socket.get(token)) is not None
            and socket_record.sid == sid
            and socket_record.instance_id == self.instance_id
        ):
            # Clean up Redis
            redis_key = self._get_redis_key(token)
            try:
                await self.redis.delete(redis_key)
            except Exception as e:
                console.error(f"Redis error deleting token: {e}")

        # Clean up local dicts (always do this)
        await super().disconnect_token(token, sid)

    @staticmethod
    def _get_lost_and_found_key(instance_id: str) -> str:
        """Get the Redis key for lost and found deltas for an instance.

        Args:
            instance_id: The instance ID.

        Returns:
            The Redis key for lost and found deltas.
        """
        return f"token_manager_lost_and_found_{instance_id}"

    async def _subscribe_lost_and_found_updates(
        self,
        emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
    ) -> None:
        """Subscribe to Redis channel notifications for lost and found deltas.

        Each message on this instance's channel is unpickled into a
        LostAndFoundRecord and delivered through ``emit_update``.

        Args:
            emit_update: The function to emit state updates.
        """
        async with self.redis.pubsub() as pubsub:
            await pubsub.psubscribe(
                f"channel:{self._get_lost_and_found_key(self.instance_id)}"
            )
            async for message in pubsub.listen():
                if message["type"] == "pmessage":
                    # NOTE(review): pickle.loads can execute arbitrary code for
                    # crafted payloads. This trusts everything published to
                    # this channel, i.e. the Redis server must be trusted.
                    record = pickle.loads(message["data"])
                    await emit_update(record.update, record.token)

    def ensure_lost_and_found_task(
        self,
        emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
    ) -> None:
        """Ensure the lost and found subscriber task is running.

        Args:
            emit_update: The function to emit state updates.
        """
        ensure_task(
            owner=self,
            task_attribute="_lost_and_found_task",
            coro_function=self._subscribe_lost_and_found_updates,
            suppress_exceptions=[Exception],
            emit_update=emit_update,
        )

    async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None:
        """Get the instance ID of the owner of a token.

        A record fetched from Redis is cached into both local dicts.

        Args:
            token: The client token.
            refresh: Whether to fetch the latest record from Redis.

        Returns:
            The instance ID of the owner, or None if not found.
        """
        if (
            not refresh
            and (socket_record := self.token_to_socket.get(token)) is not None
        ):
            return socket_record.instance_id

        redis_key = self._get_redis_key(token)
        try:
            record_pkl = await self.redis.get(redis_key)
            if record_pkl:
                socket_record = pickle.loads(record_pkl)
                self.token_to_socket[token] = socket_record
                self.sid_to_token[socket_record.sid] = token
                return socket_record.instance_id
            console.warn(f"Redis token owner not found for token {token}")
        except Exception as e:
            console.error(f"Redis error getting token owner: {e}")
        return None

    async def emit_lost_and_found(
        self,
        token: str,
        update: StateUpdate,
    ) -> bool:
        """Emit a lost and found delta to Redis.

        Publishes the update on the lost-and-found channel of the instance
        that owns ``token``.

        Args:
            token: The client token.
            update: The state update.

        Returns:
            True if the delta was published, False otherwise.
        """
        # See where this update belongs
        owner_instance_id = await self._get_token_owner(token)
        if owner_instance_id is None:
            return False
        record = LostAndFoundRecord(token=token, update=update)
        try:
            await self.redis.publish(
                f"channel:{self._get_lost_and_found_key(owner_instance_id)}",
                pickle.dumps(record),
            )
        except Exception as e:
            console.error(f"Redis error publishing lost and found delta: {e}")
        else:
            return True
        return False