# 464 lines | 16 KiB | Python
"""Token manager for handling client token to session ID mappings."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import dataclasses
|
|
import pickle
|
|
import uuid
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncIterator, Callable, Coroutine
|
|
from types import MappingProxyType
|
|
from typing import TYPE_CHECKING, ClassVar
|
|
|
|
from reflex.istate.manager.redis import StateManagerRedis
|
|
from reflex.state import StateUpdate
|
|
from reflex.utils import console, prerequisites
|
|
from reflex.utils.tasks import ensure_task
|
|
|
|
if TYPE_CHECKING:
|
|
from redis.asyncio import Redis
|
|
|
|
|
|
def _get_new_token() -> str:
    """Create a fresh, random token.

    Returns:
        The canonical string form of a newly generated UUID4.
    """
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
class SocketRecord:
|
|
"""Record for a connected socket client."""
|
|
|
|
instance_id: str
|
|
sid: str
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
class LostAndFoundRecord:
|
|
"""Record for a StateUpdate for a token with its socket on another instance."""
|
|
|
|
token: str
|
|
update: StateUpdate
|
|
|
|
|
|
class TokenManager(ABC):
    """Abstract base class for managing client token to session ID mappings.

    Subclasses implement ``link_token_to_sid`` and ``disconnect_token``; this
    base class owns the two in-memory lookup tables and the helpers built on
    top of them.
    """

    def __init__(self):
        """Initialize the token manager with local dictionaries."""
        # Each process has an instance_id to identify its own sockets.
        self.instance_id: str = _get_new_token()
        # Keep a mapping between client token and socket record.
        self.token_to_socket: dict[str, SocketRecord] = {}
        # Keep a mapping between socket ID and client token.
        self.sid_to_token: dict[str, str] = {}

    @property
    def token_to_sid(self) -> MappingProxyType[str, str]:
        """Read-only compatibility property for token_to_socket mapping.

        The mapping is rebuilt on every access, so it is a snapshot and does
        not track later changes to ``token_to_socket``.

        Returns:
            The token to session ID mapping.
        """
        return MappingProxyType({
            token: sr.sid for token, sr in self.token_to_socket.items()
        })

    async def enumerate_tokens(self) -> AsyncIterator[str]:
        """Iterate over all tokens in the system.

        Yields:
            All client tokens known to the TokenManager at the time iteration
            started.
        """
        # Iterate over a snapshot of the keys: the consumer may await between
        # yields, and a token linked or disconnected in the meantime would
        # otherwise raise "dictionary changed size during iteration".
        for token in tuple(self.token_to_socket):
            yield token

    @abstractmethod
    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """

    @abstractmethod
    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """

    @classmethod
    def create(cls) -> TokenManager:
        """Factory method to create appropriate TokenManager implementation.

        Returns:
            RedisTokenManager if Redis is in use and a client could be
            obtained, LocalTokenManager otherwise.
        """
        if prerequisites.check_redis_used():
            redis_client = prerequisites.get_redis()
            if redis_client is not None:
                return RedisTokenManager(redis_client)

        return LocalTokenManager()

    async def disconnect_all(self):
        """Disconnect all tracked tokens when the server is going down."""
        # Collect the pairs from both tables so that an entry present in only
        # one of them is still disconnected; the set removes duplicates.
        token_sid_pairs: set[tuple[str, str]] = {
            (token, sr.sid) for token, sr in self.token_to_socket.items()
        }
        token_sid_pairs.update(
            (token, sid) for sid, token in self.sid_to_token.items()
        )
        # The pairs were copied out above, so disconnect_token is free to
        # mutate the tables while this loop runs.
        for token, sid in token_sid_pairs:
            await self.disconnect_token(token, sid)
|
|
|
|
|
|
class LocalTokenManager(TokenManager):
    """Token manager using local in-memory dictionaries (single worker)."""

    def __init__(self):
        """Initialize the local token manager."""
        super().__init__()

    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID.

        When the token is already linked to a different session ID (for
        example a duplicated browser tab), a fresh token is generated and
        linked to ``sid`` instead; the existing link is left untouched.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """
        existing = self.token_to_socket.get(token)
        is_duplicate = existing is not None and existing.sid != sid

        # A duplicate gets a brand-new token; otherwise the given one is used.
        assigned_token = _get_new_token() if is_duplicate else token

        self.token_to_socket[assigned_token] = SocketRecord(
            instance_id=self.instance_id, sid=sid
        )
        self.sid_to_token[sid] = assigned_token

        if is_duplicate:
            return assigned_token
        return None

    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        Both lookups are removed independently; a missing entry is ignored.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """
        self.sid_to_token.pop(sid, None)
        self.token_to_socket.pop(token, None)
|
|
|
|
|
|
class RedisTokenManager(LocalTokenManager):
    """Token manager using Redis for distributed multi-worker support.

    Inherits local dict logic from LocalTokenManager and adds Redis layer
    for cross-worker duplicate detection.

    Socket records and lost-and-found payloads are serialized with pickle, so
    the Redis server must be trusted: unpickling data written by an untrusted
    party can execute arbitrary code.
    """

    _token_socket_record_prefix: ClassVar[str] = "token_manager_socket_record_"

    def __init__(self, redis: Redis):
        """Initialize the Redis token manager.

        Args:
            redis: The Redis client instance.
        """
        # Initialize parent's local dicts
        super().__init__()

        self.redis = redis

        # Expiration applied to every socket record key written to Redis.
        from reflex_base.config import get_config

        config = get_config()
        self.token_expiration = config.redis_token_expiration

        # Pub/sub tasks for handling sockets owned by other instances.
        self._socket_record_task: asyncio.Task | None = None
        self._lost_and_found_task: asyncio.Task | None = None

    def _get_redis_key(self, token: str) -> str:
        """Get Redis key for token mapping.

        Args:
            token: The client token.

        Returns:
            Redis key following Reflex conventions: token_manager_socket_record_{token}
        """
        return f"{self._token_socket_record_prefix}{token}"

    def _token_from_redis_key(self, redis_key: str) -> str:
        """Recover the client token from a socket record key.

        Only the leading prefix is stripped, so a token that itself contains
        the prefix text is returned intact.

        Args:
            redis_key: A key as produced by ``_get_redis_key``.

        Returns:
            The token portion of the key.
        """
        return redis_key.removeprefix(self._token_socket_record_prefix)

    async def enumerate_tokens(self) -> AsyncIterator[str]:
        """Iterate over all tokens in the system.

        Yields:
            All client tokens known to the RedisTokenManager.
        """
        cursor = 0
        while scan_result := await self.redis.scan(
            cursor=cursor, match=self._get_redis_key("*")
        ):
            cursor = int(scan_result[0])
            for key in scan_result[1]:
                yield self._token_from_redis_key(key.decode())
            # A zero cursor means the scan has covered the whole keyspace.
            if not cursor:
                break

    async def _handle_socket_record_del(
        self, token: str, expired: bool = False
    ) -> None:
        """Handle deletion of a socket record from Redis.

        The local record for the token is always dropped. When the record
        belonged to this instance its sid mapping is dropped as well, and if
        the key merely expired the link is re-established.

        Args:
            token: The client token whose record was deleted.
            expired: Whether the deletion was due to expiration.
        """
        if (
            socket_record := self.token_to_socket.pop(token, None)
        ) is not None and socket_record.instance_id == self.instance_id:
            self.sid_to_token.pop(socket_record.sid, None)
            if expired:
                # Keep the record alive as long as this process is alive and not deleted.
                # NOTE(review): link_token_to_sid returns a replacement token
                # when the key already exists in Redis again; that return
                # value is not propagated from here.
                await self.link_token_to_sid(token, socket_record.sid)

    async def _subscribe_socket_record_updates(self) -> None:
        """Subscribe to Redis keyspace notifications for socket record updates."""
        await StateManagerRedis(redis=self.redis)._enable_keyspace_notifications()
        redis_db = self.redis.get_connection_kwargs().get("db", 0)

        async with self.redis.pubsub() as pubsub:
            await pubsub.psubscribe(
                f"__keyspace@{redis_db}__:{self._get_redis_key('*')}"
            )
            async for message in pubsub.listen():
                if message["type"] == "pmessage":
                    # Channel looks like "__keyspace@<db>__:<key>"; keep
                    # everything after the first colon as the key.
                    key = message["channel"].split(b":", 1)[1].decode()
                    token = self._token_from_redis_key(key)

                    if token not in self.token_to_socket:
                        # We don't know about this token, skip
                        continue

                    event = message["data"].decode()
                    if event in ("del", "expired", "evicted"):
                        await self._handle_socket_record_del(
                            token,
                            expired=(event == "expired"),
                        )
                    elif event == "set":
                        # Reload the record so the local copy matches Redis.
                        await self._get_token_owner(token, refresh=True)

    def _ensure_socket_record_task(self) -> None:
        """Ensure the socket record updates subscriber task is running."""
        ensure_task(
            owner=self,
            task_attribute="_socket_record_task",
            coro_function=self._subscribe_socket_record_updates,
            suppress_exceptions=[Exception],
        )

    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID with Redis-based duplicate detection.

        If Redis cannot be queried for the token, the purely local linking
        logic of the parent class is used instead.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """
        # Fast local check first (handles reconnections)
        if (
            socket_record := self.token_to_socket.get(token)
        ) is not None and sid == socket_record.sid:
            return None  # Same token, same SID = reconnection, no Redis check needed

        # Make sure the update subscriber is running
        self._ensure_socket_record_task()

        # Check Redis for cross-worker duplicates
        redis_key = self._get_redis_key(token)

        try:
            token_exists_in_redis = await self.redis.exists(redis_key)
        except Exception as e:
            console.error(f"Redis error checking token existence: {e}")
            return await super().link_token_to_sid(token, sid)

        new_token = None
        if token_exists_in_redis:
            # Duplicate exists somewhere - generate new token
            token = new_token = _get_new_token()
            redis_key = self._get_redis_key(new_token)

        # Store in local dicts
        socket_record = self.token_to_socket[token] = SocketRecord(
            instance_id=self.instance_id, sid=sid
        )
        self.sid_to_token[sid] = token

        # Store in Redis if possible; a failure is logged and the local link
        # is kept.
        try:
            await self.redis.set(
                redis_key,
                pickle.dumps(socket_record),
                ex=self.token_expiration,
            )
        except Exception as e:
            console.error(f"Redis error storing token: {e}")
        # Return the new token if one was generated
        return new_token

    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        The Redis key is deleted only when the local record shows that this
        instance owns the token for the given sid; the local mappings are
        removed in every case.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """
        # Only clean up if we own it locally (fast ownership check)
        if (
            (socket_record := self.token_to_socket.get(token)) is not None
            and socket_record.sid == sid
            and socket_record.instance_id == self.instance_id
        ):
            # Clean up Redis
            redis_key = self._get_redis_key(token)
            try:
                await self.redis.delete(redis_key)
            except Exception as e:
                console.error(f"Redis error deleting token: {e}")

        # Clean up local dicts (always do this)
        await super().disconnect_token(token, sid)

    @staticmethod
    def _get_lost_and_found_key(instance_id: str) -> str:
        """Get the Redis key for lost and found deltas for an instance.

        Args:
            instance_id: The instance ID.

        Returns:
            The Redis key for lost and found deltas.
        """
        return f"token_manager_lost_and_found_{instance_id}"

    async def _subscribe_lost_and_found_updates(
        self,
        emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
    ) -> None:
        """Subscribe to Redis channel notifications for lost and found deltas.

        Args:
            emit_update: The function to emit state updates.
        """
        async with self.redis.pubsub() as pubsub:
            await pubsub.psubscribe(
                f"channel:{self._get_lost_and_found_key(self.instance_id)}"
            )
            async for message in pubsub.listen():
                if message["type"] == "pmessage":
                    # SECURITY: pickle.loads can execute arbitrary code; this
                    # is only safe while every publisher on the Redis server
                    # is trusted.
                    record = pickle.loads(message["data"])
                    await emit_update(record.update, record.token)

    def ensure_lost_and_found_task(
        self,
        emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
    ) -> None:
        """Ensure the lost and found subscriber task is running.

        Args:
            emit_update: The function to emit state updates.
        """
        ensure_task(
            owner=self,
            task_attribute="_lost_and_found_task",
            coro_function=self._subscribe_lost_and_found_updates,
            suppress_exceptions=[Exception],
            emit_update=emit_update,
        )

    async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None:
        """Get the instance ID of the owner of a token.

        A record fetched from Redis is also stored in the local dictionaries.

        Args:
            token: The client token.
            refresh: Whether to fetch the latest record from Redis.

        Returns:
            The instance ID of the owner, or None if not found.
        """
        if (
            not refresh
            and (socket_record := self.token_to_socket.get(token)) is not None
        ):
            return socket_record.instance_id

        redis_key = self._get_redis_key(token)
        try:
            record_pkl = await self.redis.get(redis_key)
            if record_pkl:
                # SECURITY: pickle.loads can execute arbitrary code; this is
                # only safe while the Redis server contents are trusted.
                socket_record = pickle.loads(record_pkl)
                self.token_to_socket[token] = socket_record
                self.sid_to_token[socket_record.sid] = token
                return socket_record.instance_id
            console.warn(f"Redis token owner not found for token {token}")
        except Exception as e:
            console.error(f"Redis error getting token owner: {e}")
        return None

    async def emit_lost_and_found(
        self,
        token: str,
        update: StateUpdate,
    ) -> bool:
        """Emit a lost and found delta to Redis.

        Args:
            token: The client token.
            update: The state update.

        Returns:
            True if the delta was published, False otherwise.
        """
        # See where this update belongs
        owner_instance_id = await self._get_token_owner(token)
        if owner_instance_id is None:
            return False
        record = LostAndFoundRecord(token=token, update=update)
        try:
            await self.redis.publish(
                f"channel:{self._get_lost_and_found_key(owner_instance_id)}",
                pickle.dumps(record),
            )
        except Exception as e:
            console.error(f"Redis error publishing lost and found delta: {e}")
        else:
            return True
        return False
|