eptm_dashboard/.venv/lib/python3.12/site-packages/reflex_base/registry.py

155 lines
5.3 KiB
Python

"""A contextual registry for state and event handlers."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING
from typing_extensions import Self
from reflex_base.context.base import BaseContext
from reflex_base.utils.exceptions import StateValueError
if TYPE_CHECKING:
from reflex.state import BaseState
from reflex_base.event import EventHandler
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class RegisteredEventHandler:
"""A registered event handler, which includes the handler and its full name."""
handler: EventHandler
states: tuple[type[BaseState], ...]
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False)
class RegistrationContext(BaseContext):
    """Context for registering event handlers and states.

    The context keeps three string-keyed registries:

    * ``event_handlers``: formatted event-handler name -> ``RegisteredEventHandler``.
    * ``base_states``: state full name -> state class.
    * ``base_state_substates``: parent state full name -> set of its
      registered substate classes.

    The dataclass itself is frozen, but the registry dicts are mutated in
    place as classes and handlers are registered.
    """

    event_handlers: dict[str, RegisteredEventHandler] = dataclasses.field(
        default_factory=dict,
        repr=False,
    )
    base_states: dict[str, type[BaseState]] = dataclasses.field(
        default_factory=dict,
        repr=False,
    )
    base_state_substates: dict[str, set[type[BaseState]]] = dataclasses.field(
        default_factory=dict,
        repr=False,
    )

    @classmethod
    def ensure_context(cls) -> Self:
        """Return the attached context, creating and attaching one if needed.

        Returns:
            The registration context instance.
        """
        try:
            return cls.get()
        except LookupError:
            # No context is attached yet: create a fresh one and attach it so
            # subsequent calls return the same instance.
            ctx = cls()
            cls._context_var.set(ctx)
            return ctx

    @classmethod
    def register_base_state(cls, state_cls: type[BaseState]) -> type[BaseState]:
        """Register a base state class with its full name in the current context.

        Also registers parent states, walking up until one that is already
        registered is found.

        Args:
            state_cls: The base state class to register.

        Returns:
            The registered base state class.

        Raises:
            StateValueError: If ``state_cls`` is already registered as a
                substate of its parent state.
        """
        return cls.ensure_context()._register_base_state(state_cls)

    def _register_base_state(self, state_cls: type[BaseState]) -> type[BaseState]:
        """Register a base state class with its full name.

        Also registers the class's event handlers and its parent states,
        walking up until one that is already registered is found, and records
        the class as a substate of its parent.

        Args:
            state_cls: The base state class to register.

        Returns:
            The registered base state class.

        Raises:
            StateValueError: If ``state_cls`` is already registered as a
                substate of its parent state. The registry is left unchanged
                in that case.
        """
        parent_state = state_cls.get_parent_state()

        # Validate before touching any registry. Previously this check ran
        # after base_states and event_handlers had already been overwritten,
        # which left a half-applied registration behind when it raised.
        if parent_state is not None and state_cls in self.base_state_substates.get(
            parent_state.get_full_name(), ()
        ):
            msg = (
                f"State class {state_cls.get_full_name()} is already registered as a substate of "
                f"{parent_state.get_full_name()}. This likely means there are multiple classes with the same name "
                "in the same module, which causes a conflict in the registry. Please rename one of the classes to avoid "
                "shadowing. Shadowing substate classes is not allowed."
            )
            raise StateValueError(msg)

        # Same mutation order as before: the state, its handlers, then the parent chain.
        self.base_states[state_cls.get_full_name()] = state_cls
        for event_handler in state_cls.event_handlers.values():
            self._register_event_handler(event_handler, states=(state_cls,))

        if parent_state is not None:
            parent_name = parent_state.get_full_name()
            if parent_name not in self.base_states:
                # Recursively register missing ancestors.
                self._register_base_state(parent_state)
            self.base_state_substates.setdefault(parent_name, set()).add(state_cls)
        return state_cls

    @classmethod
    def register_event_handler(
        cls, handler: EventHandler, states: tuple[type[BaseState], ...] = ()
    ) -> EventHandler:
        """Register an event handler in the current context.

        Args:
            handler: The event handler to register.
            states: The states associated with the event handler.

        Returns:
            The registered event handler (the same object that was passed in).
        """
        return cls.ensure_context()._register_event_handler(handler, states=states)

    def _register_event_handler(
        self,
        handler: EventHandler,
        states: tuple[type[BaseState], ...] = (),
    ) -> EventHandler:
        """Register an event handler under its formatted full name.

        An existing entry with the same formatted name is replaced.

        Args:
            handler: The event handler to register.
            states: The states associated with the event handler.

        Returns:
            The registered event handler (the same object that was passed in).
        """
        # Imported at call time; presumably to avoid an import cycle with the
        # ``reflex`` package at module load -- confirm before hoisting.
        from reflex.utils.format import format_event_handler

        full_name = format_event_handler(handler)
        self.event_handlers[full_name] = RegisteredEventHandler(
            handler=handler, states=states
        )
        return handler

    def get_substates(
        self, base_state_cls: type[BaseState] | str
    ) -> set[type[BaseState]]:
        """Get the substates for a base state class.

        If no entry exists yet, an empty set is stored for the name. The set
        returned is the registry's own set, so mutating it mutates the registry.

        Args:
            base_state_cls: The base state class, or its full name, to get
                substates for.

        Returns:
            A set of substate classes.
        """
        state_name = (
            base_state_cls
            if isinstance(base_state_cls, str)
            else base_state_cls.get_full_name()
        )
        return self.base_state_substates.setdefault(state_name, set())