eptm_dashboard/.venv/lib/python3.12/site-packages/reflex_base/vars/base.py

3693 lines
110 KiB
Python

"""Collection of base classes."""
from __future__ import annotations
import contextlib
import copy
import dataclasses
import datetime
import functools
import inspect
import json
import re
import string
import uuid
import warnings
from abc import ABCMeta
from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
from dataclasses import _MISSING_TYPE, MISSING
from decimal import Decimal
from types import CodeType, FunctionType
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Generic,
Literal,
NoReturn,
ParamSpec,
Protocol,
TypeGuard,
TypeVar,
cast,
get_args,
get_type_hints,
overload,
)
from rich.markup import escape
from typing_extensions import LiteralString, dataclass_transform, override
from reflex_base import constants
from reflex_base.constants.compiler import Hooks
from reflex_base.constants.state import FIELD_MARKER
from reflex_base.utils import console, exceptions, imports, serializers, types
from reflex_base.utils.compat import annotations_from_namespace
from reflex_base.utils.decorator import once
from reflex_base.utils.exceptions import (
ComputedVarSignatureError,
UntypedComputedVarError,
VarAttributeError,
VarDependencyError,
VarTypeError,
)
from reflex_base.utils.format import format_state_name
from reflex_base.utils.imports import (
ImmutableImportDict,
ImmutableParsedImportDict,
ImportDict,
ImportVar,
ParsedImportTuple,
parse_imports,
)
from reflex_base.utils.types import (
GenericType,
Self,
_isinstance,
get_origin,
has_args,
safe_issubclass,
unionize,
)
if TYPE_CHECKING:
from reflex.state import BaseState
from reflex_base.components.component import BaseComponent
from reflex_base.constants.colors import Color
from .color import LiteralColorVar
from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar
from .object import LiteralObjectVar, ObjectVar
from .sequence import ArrayVar, LiteralArrayVar, LiteralStringVar, StringVar
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
STRING_T = TypeVar("STRING_T", bound=str)
LITERAL_STRING_T = TypeVar("LITERAL_STRING_T", bound=LiteralString)
SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence)
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
_PYDANTIC_VALIDATE_VALUES = "__pydantic_validate_values__"
def _pydantic_validator(*args, **kwargs):
return None
@dataclasses.dataclass(
eq=False,
frozen=True,
)
class VarSubclassEntry:
"""Entry for a Var subclass."""
var_subclass: type[Var]
to_var_subclass: type[ToOperation]
python_types: tuple[GenericType, ...]
_var_subclasses: list[VarSubclassEntry] = []
_var_literal_subclasses: list[tuple[type[LiteralVar], VarSubclassEntry]] = []
@dataclasses.dataclass(
eq=True,
frozen=True,
)
class VarData:
"""Metadata associated with a x."""
# The name of the enclosing state.
state: str = dataclasses.field(default="")
# The name of the field in the state.
field_name: str = dataclasses.field(default="")
# Imports needed to render this var
imports: ParsedImportTuple = dataclasses.field(default_factory=tuple)
# Hooks that need to be present in the component to render this var
hooks: tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Dependencies of the var
deps: tuple[Var, ...] = dataclasses.field(default_factory=tuple)
# Position of the hook in the component
position: Hooks.HookPosition | None = None
# Components that are part of this var
components: tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImmutableImportDict | ImmutableParsedImportDict | None = None,
hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
components: Iterable[BaseComponent] | None = None,
):
"""Initialize the var data.
Args:
state: The name of the enclosing state.
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback.
position: Position of the hook in the component.
components: Components that are part of this var.
"""
if isinstance(hooks, str):
hooks = [hooks]
if not isinstance(hooks, dict):
hooks = dict.fromkeys(hooks or [])
immutable_imports: ParsedImportTuple = tuple(
(k, tuple(v)) for k, v in parse_imports(imports or {}).items()
)
object.__setattr__(self, "state", state)
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None)
object.__setattr__(self, "components", tuple(components or []))
if hooks and any(hooks.values()):
# Merge our dependencies first, so they can be referenced.
merged_var_data = VarData.merge(*hooks.values(), self)
if merged_var_data is not None:
object.__setattr__(self, "state", merged_var_data.state)
object.__setattr__(self, "field_name", merged_var_data.field_name)
object.__setattr__(self, "imports", merged_var_data.imports)
object.__setattr__(self, "hooks", merged_var_data.hooks)
object.__setattr__(self, "deps", merged_var_data.deps)
object.__setattr__(self, "position", merged_var_data.position)
object.__setattr__(self, "components", merged_var_data.components)
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
Returns:
The imports as a mutable dict.
"""
return {k: list(v) for k, v in self.imports}
def merge(*all: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
Args:
*all: The var data objects to merge.
Returns:
The merged var data object.
Raises:
ReflexError: If trying to merge VarData with different positions.
# noqa: DAR102 *all
"""
all_var_datas = list(filter(None, all))
if not all_var_datas:
return None
if len(all_var_datas) == 1:
return all_var_datas[0]
# Get the first non-empty field name or default to empty string.
field_name = next(
(var_data.field_name for var_data in all_var_datas if var_data.field_name),
"",
)
# Get the first non-empty state or default to empty string.
state = next(
(var_data.state for var_data in all_var_datas if var_data.state), ""
)
hooks: dict[str, VarData | None] = {
hook: None for var_data in all_var_datas for hook in var_data.hooks
}
imports_ = imports.merge_imports(
*(var_data.imports for var_data in all_var_datas)
)
deps = [dep for var_data in all_var_datas for dep in var_data.deps]
positions = list(
dict.fromkeys(
var_data.position
for var_data in all_var_datas
if var_data.position is not None
)
)
if positions:
if len(positions) > 1:
msg = f"Cannot merge var data with different positions: {positions}"
raise exceptions.ReflexError(msg)
position = positions[0]
else:
position = None
components = tuple(
component for var_data in all_var_datas for component in var_data.components
)
return VarData(
state=state,
field_name=field_name,
imports=imports_,
hooks=hooks,
deps=deps,
position=position,
components=components,
)
def __bool__(self) -> bool:
"""Check if the var data is non-empty.
Returns:
True if any field is set to a non-default value.
"""
return bool(
self.state
or self.imports
or self.hooks
or self.field_name
or self.deps
or self.position
or self.components
)
@classmethod
def from_state(cls, state: type[BaseState] | str, field_name: str = "") -> VarData:
"""Set the state of the var.
Args:
state: The state to set or the full name of the state.
field_name: The name of the field in the state. Optional.
Returns:
The var with the set state.
"""
from reflex_base.utils import format
state_name = state if isinstance(state, str) else state.get_full_name()
return VarData(
state=state_name,
field_name=field_name,
hooks={
"const {0} = useContext(StateContexts.{0})".format(
format.format_state_name(state_name)
): None
},
imports={
f"$/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
"react": [ImportVar(tag="useContext")],
},
)
def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
"""Decode the state name from a formatted var.
Args:
value: The value to extract the state name from.
Returns:
The extracted state name and the value without the state name.
"""
var_datas = []
if isinstance(value, str):
# fast path if there is no encoded VarData
if constants.REFLEX_VAR_OPENING_TAG not in value:
return None, value
offset = 0
# Find all tags.
while m := _decode_var_pattern.search(value):
start, end = m.span()
value = value[:start] + value[end:]
serialized_data = m.group(1)
if serialized_data.isnumeric() or (
serialized_data[0] == "-" and serialized_data[1:].isnumeric()
):
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._get_all_var_data()
if var_data is not None:
var_datas.append(var_data)
offset += end - start
return VarData.merge(*var_datas) if var_datas else None, value
def can_use_in_object_var(cls: GenericType) -> bool:
"""Check if the class can be used in an ObjectVar.
Args:
cls: The class to check.
Returns:
Whether the class can be used in an ObjectVar.
"""
if types.is_union(cls):
return all(can_use_in_object_var(t) for t in types.get_args(cls))
return (
isinstance(cls, type)
and not safe_issubclass(cls, Var)
and serializers.can_serialize(cls, dict)
)
class MetaclassVar(type):
"""Metaclass for the Var class."""
def __setattr__(cls, name: str, value: Any):
"""Set an attribute on the class.
Args:
name: The name of the attribute.
value: The value of the attribute.
"""
super().__setattr__(
name, value if name != _PYDANTIC_VALIDATE_VALUES else _pydantic_validator
)
@dataclasses.dataclass(
eq=False,
frozen=True,
)
class Var(Generic[VAR_TYPE], metaclass=MetaclassVar):
"""Base class for immutable vars."""
# The name of the var.
_js_expr: str = dataclasses.field()
# The type of the var.
_var_type: types.GenericType = dataclasses.field(default=Any)
# Extra metadata associated with the Var
_var_data: VarData | None = dataclasses.field(default=None)
def __str__(self) -> str:
"""String representation of the var. Guaranteed to be a valid Javascript expression.
Returns:
The name of the var.
"""
return self._js_expr
@property
def _var_is_local(self) -> bool:
"""Whether this is a local javascript variable.
Returns:
False
"""
return False
@property
def _var_is_string(self) -> bool:
"""Whether the var is a string literal.
Returns:
False
"""
return False
def __init_subclass__(
cls,
python_types: tuple[GenericType, ...] | GenericType = types.Unset(),
default_type: GenericType = types.Unset(),
**kwargs,
):
"""Initialize the subclass.
Args:
python_types: The python types that the var represents.
default_type: The default type of the var. Defaults to the first python type.
**kwargs: Additional keyword arguments.
"""
super().__init_subclass__(**kwargs)
if python_types or default_type:
python_types = (
(python_types if isinstance(python_types, tuple) else (python_types,))
if python_types
else ()
)
default_type = default_type or (python_types[0] if python_types else Any)
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class ToVarOperation(ToOperation, cls):
"""Base class of converting a var to another var type."""
_original: Var = dataclasses.field(
default=Var(_js_expr="null", _var_type=None),
)
_default_var_type: ClassVar[GenericType] = default_type
new_to_var_operation_name = f"{cls.__name__.removesuffix('Var')}CastedVar"
ToVarOperation.__qualname__ = (
ToVarOperation.__qualname__.removesuffix(ToVarOperation.__name__)
+ new_to_var_operation_name
)
ToVarOperation.__name__ = new_to_var_operation_name
_var_subclasses.append(VarSubclassEntry(cls, ToVarOperation, python_types))
def __post_init__(self):
"""Post-initialize the var.
Raises:
TypeError: If _js_expr is not a string.
"""
if not isinstance(self._js_expr, str):
msg = f"Expected _js_expr to be a string, got value {self._js_expr!r} of type {type(self._js_expr).__name__}"
raise TypeError(msg)
if self._var_data is not None and not isinstance(self._var_data, VarData):
msg = f"Expected _var_data to be a VarData, got value {self._var_data!r} of type {type(self._var_data).__name__}"
raise TypeError(msg)
# Decode any inline Var markup and apply it to the instance
var_data_, js_expr_ = _decode_var_immutable(self._js_expr)
if var_data_ or js_expr_ != self._js_expr:
self.__init__(
_js_expr=js_expr_,
_var_type=self._var_type,
_var_data=VarData.merge(self._var_data, var_data_),
)
def __hash__(self) -> int:
"""Define a hash function for the var.
Returns:
The hash of the var.
"""
return hash((self._js_expr, self._var_type, self._var_data))
def _get_all_var_data(self) -> VarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return self._var_data
def __deepcopy__(self, memo: dict[int, Any]) -> Self:
"""Deepcopy the var.
Args:
memo: The memo dictionary to use for the deepcopy.
Returns:
A deepcopy of the var.
"""
return self
def equals(self, other: Var) -> bool:
"""Check if two vars are equal.
Args:
other: The other var to compare.
Returns:
Whether the vars are equal.
"""
return (
self._js_expr == other._js_expr
and self._var_type == other._var_type
and self._get_all_var_data() == other._get_all_var_data()
)
@overload
def _replace(
self,
_var_type: type[OTHER_VAR_TYPE],
merge_var_data: VarData | None = None,
**kwargs: Any,
) -> Var[OTHER_VAR_TYPE]: ...
@overload
def _replace(
self,
_var_type: GenericType | None = None,
merge_var_data: VarData | None = None,
**kwargs: Any,
) -> Self: ...
def _replace(
self,
_var_type: GenericType | None = None,
merge_var_data: VarData | None = None,
**kwargs: Any,
) -> Self | Var:
"""Make a copy of this Var with updated fields.
Args:
_var_type: The new type of the Var.
merge_var_data: VarData to merge into the existing VarData.
**kwargs: Var fields to update.
Returns:
A new Var with the updated fields overwriting the corresponding fields in this Var.
Raises:
TypeError: If _var_is_local, _var_is_string, or _var_full_name_needs_state_prefix is not None.
"""
if kwargs.get("_var_is_local", False) is not False:
msg = "The _var_is_local argument is not supported for Var."
raise TypeError(msg)
if kwargs.get("_var_is_string", False) is not False:
msg = "The _var_is_string argument is not supported for Var."
raise TypeError(msg)
if kwargs.get("_var_full_name_needs_state_prefix", False) is not False:
msg = "The _var_full_name_needs_state_prefix argument is not supported for Var."
raise TypeError(msg)
value_with_replaced = dataclasses.replace(
self,
_var_type=_var_type or self._var_type,
_var_data=VarData.merge(
kwargs.get("_var_data", self._var_data), merge_var_data
),
**kwargs,
)
if (js_expr := kwargs.get("_js_expr")) is not None:
object.__setattr__(value_with_replaced, "_js_expr", js_expr)
return value_with_replaced
@overload
@classmethod
def create( # pyright: ignore[reportOverlappingOverload]
cls,
value: NoReturn,
_var_data: VarData | None = None,
) -> Var[Any]: ...
@overload
@classmethod
def create( # pyright: ignore[reportOverlappingOverload]
cls,
value: bool,
_var_data: VarData | None = None,
) -> LiteralBooleanVar: ...
@overload
@classmethod
def create(
cls,
value: int,
_var_data: VarData | None = None,
) -> LiteralNumberVar[int]: ...
@overload
@classmethod
def create(
cls,
value: float,
_var_data: VarData | None = None,
) -> LiteralNumberVar[float]: ...
@overload
@classmethod
def create(
cls,
value: Decimal,
_var_data: VarData | None = None,
) -> LiteralNumberVar[Decimal]: ...
@overload
@classmethod
def create( # pyright: ignore [reportOverlappingOverload]
cls,
value: Color,
_var_data: VarData | None = None,
) -> LiteralColorVar: ...
@overload
@classmethod
def create( # pyright: ignore [reportOverlappingOverload]
cls,
value: LITERAL_STRING_T,
_var_data: VarData | None = None,
) -> LiteralStringVar[LITERAL_STRING_T]: ...
@overload
@classmethod
def create( # pyright: ignore [reportOverlappingOverload]
cls,
value: STRING_T,
_var_data: VarData | None = None,
) -> StringVar[STRING_T]: ...
@overload
@classmethod
def create( # pyright: ignore[reportOverlappingOverload]
cls,
value: None,
_var_data: VarData | None = None,
) -> LiteralNoneVar: ...
@overload
@classmethod
def create(
cls,
value: MAPPING_TYPE,
_var_data: VarData | None = None,
) -> LiteralObjectVar[MAPPING_TYPE]: ...
@overload
@classmethod
def create(
cls,
value: SEQUENCE_TYPE,
_var_data: VarData | None = None,
) -> LiteralArrayVar[SEQUENCE_TYPE]: ...
@overload
@classmethod
def create(
cls,
value: OTHER_VAR_TYPE,
_var_data: VarData | None = None,
) -> Var[OTHER_VAR_TYPE]: ...
@classmethod
def create(
cls,
value: OTHER_VAR_TYPE,
_var_data: VarData | None = None,
) -> Var[OTHER_VAR_TYPE]:
"""Create a var from a value.
Args:
value: The value to create the var from.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
# If the value is already a var, do nothing.
if isinstance(value, Var):
return value
return LiteralVar.create(value, _var_data=_var_data)
def __format__(self, format_spec: str) -> str:
"""Format the var into a Javascript equivalent to an f-string.
Args:
format_spec: The format specifier (Ignored for now).
Returns:
The formatted var.
"""
hashed_var = hash(self)
_global_vars[hashed_var] = self
# Encode the _var_data into the formatted output for tracking purposes.
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}"
@overload
def to(self, output: type[str]) -> StringVar: ... # pyright: ignore[reportOverlappingOverload]
@overload
def to(self, output: type[bool]) -> BooleanVar: ...
@overload
def to(self, output: type[int]) -> NumberVar[int]: ...
@overload
def to(self, output: type[float]) -> NumberVar[float]: ...
@overload
def to(self, output: type[Decimal]) -> NumberVar[Decimal]: ...
@overload
def to(
self,
output: type[SEQUENCE_TYPE],
) -> ArrayVar[SEQUENCE_TYPE]: ...
@overload
def to(
self,
output: type[MAPPING_TYPE],
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def to(
self, output: type[ObjectVar], var_type: type[VAR_INSIDE]
) -> ObjectVar[VAR_INSIDE]: ...
@overload
def to(
self, output: type[ObjectVar], var_type: None = None
) -> ObjectVar[VAR_TYPE]: ...
@overload
def to(self, output: VAR_SUBCLASS, var_type: None = None) -> VAR_SUBCLASS: ...
@overload
def to(
self,
output: type[OUTPUT] | types.GenericType,
var_type: types.GenericType | None = None,
) -> OUTPUT: ...
def to(
self,
output: type[OUTPUT] | types.GenericType,
var_type: types.GenericType | None = None,
) -> Var:
"""Convert the var to a different type.
Args:
output: The output type.
var_type: The type of the var.
Returns:
The converted var.
"""
from .object import ObjectVar
fixed_output_type = get_origin(output) or output
# If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types or safe_issubclass(
fixed_output_type, var_subclass.python_types
):
return self.to(var_subclass.var_subclass, output)
if fixed_output_type is None:
return get_to_operation(NoneVar).create(self) # pyright: ignore [reportReturnType]
# Handle fixed_output_type being Base or a dataclass.
if can_use_in_object_var(output):
return self.to(ObjectVar, output)
if isinstance(output, type):
for var_subclass in _var_subclasses[::-1]:
if safe_issubclass(output, var_subclass.var_subclass):
current_var_type = self._var_type
if current_var_type is Any:
new_var_type = var_type
else:
new_var_type = var_type or current_var_type
return var_subclass.to_var_subclass.create( # pyright: ignore [reportReturnType]
value=self, _var_type=new_var_type
)
# If we can't determine the first argument, we just replace the _var_type.
if not safe_issubclass(output, Var) or var_type is None:
return dataclasses.replace(
self,
_var_type=output,
)
# We couldn't determine the output type to be any other Var type, so we replace the _var_type.
if var_type is not None:
return dataclasses.replace(
self,
_var_type=var_type,
)
return self
@overload
def guess_type(self: Var[NoReturn]) -> Var[Any]: ... # pyright: ignore [reportOverlappingOverload]
@overload
def guess_type(self: Var[str]) -> StringVar: ...
@overload
def guess_type(self: Var[bool]) -> BooleanVar: ...
@overload
def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
@overload
def guess_type(self) -> Self: ...
def guess_type(self) -> Var:
"""Guesses the type of the variable based on its `_var_type` attribute.
Returns:
Var: The guessed type of the variable.
Raises:
TypeError: If the type is not supported for guessing.
"""
from .object import ObjectVar
var_type = self._var_type
if var_type is None:
return self.to(None)
if var_type is NoReturn:
return self.to(Any)
var_type = types.value_inside_optional(var_type)
if var_type is Any:
return self
fixed_type = get_origin(var_type) or var_type
if fixed_type in types.UnionTypes:
inner_types = get_args(var_type)
non_optional_inner_types = [
types.value_inside_optional(inner_type) for inner_type in inner_types
]
fixed_inner_types = [
get_origin(inner_type) or inner_type
for inner_type in non_optional_inner_types
]
for var_subclass in _var_subclasses[::-1]:
if all(
safe_issubclass(t, var_subclass.python_types)
for t in fixed_inner_types
):
return self.to(var_subclass.var_subclass, self._var_type)
if can_use_in_object_var(var_type):
return self.to(ObjectVar, self._var_type)
return self
if fixed_type is Literal:
args = get_args(var_type)
fixed_type = unionize(*(type(arg) for arg in args))
if not isinstance(fixed_type, type):
msg = f"Unsupported type {var_type} for guess_type."
raise TypeError(msg)
if fixed_type is None:
return self.to(None)
for var_subclass in _var_subclasses[::-1]:
if safe_issubclass(fixed_type, var_subclass.python_types):
return self.to(var_subclass.var_subclass, self._var_type)
if can_use_in_object_var(fixed_type):
return self.to(ObjectVar, self._var_type)
return self
@staticmethod
def _get_setter_name_for_name(
name: str,
) -> str:
"""Get the name of the var's generated setter function.
Args:
name: The name of the var.
Returns:
The name of the setter function.
"""
return constants.SETTER_PREFIX + name
def _get_setter(self, name: str) -> Callable[[BaseState, Any], None]:
"""Get the var's setter function.
Args:
name: The name of the var.
Returns:
A function that that creates a setter for the var.
"""
setter_name = Var._get_setter_name_for_name(name)
def setter(state: Any, value: Any):
"""Get the setter for the var.
Args:
state: The state within which we add the setter function.
value: The value to set.
"""
if self._var_type in [int, float]:
try:
value = self._var_type(value)
setattr(state, name, value)
except ValueError:
console.debug(
f"{type(state).__name__}.{self._js_expr}: Failed conversion of {value!s} to '{self._var_type.__name__}'. Value not set.",
)
else:
setattr(state, name, value)
setter.__annotations__["value"] = self._var_type
setter.__qualname__ = setter_name
return setter
def _var_set_state(self, state: type[BaseState] | str) -> Self:
"""Set the state of the var.
Args:
state: The state to set.
Returns:
The var with the state set.
"""
formatted_state_name = (
state
if isinstance(state, str)
else format_state_name(state.get_full_name())
)
return StateOperation.create( # pyright: ignore [reportReturnType]
formatted_state_name,
self,
_var_data=VarData.merge(
VarData.from_state(state, self._js_expr), self._var_data
),
).guess_type()
def __eq__(self, other: Var | Any) -> BooleanVar:
"""Check if the current variable is equal to the given variable.
Args:
other (Var | Any): The variable to compare with.
Returns:
BooleanVar: A BooleanVar object representing the result of the equality check.
"""
from .number import equal_operation
return equal_operation(self, other)
def __ne__(self, other: Var | Any) -> BooleanVar:
"""Check if the current object is not equal to the given object.
Parameters:
other (Var | Any): The object to compare with.
Returns:
BooleanVar: A BooleanVar object representing the result of the comparison.
"""
from .number import equal_operation
return ~equal_operation(self, other)
def bool(self) -> BooleanVar:
"""Convert the var to a boolean.
Returns:
The boolean var.
"""
from .number import boolify
return boolify(self)
def is_none(self) -> BooleanVar:
"""Check if the var is None.
Returns:
A BooleanVar object representing the result of the check.
"""
from .number import is_not_none_operation
return ~is_not_none_operation(self)
def is_not_none(self) -> BooleanVar:
"""Check if the var is not None.
Returns:
A BooleanVar object representing the result of the check.
"""
from .number import is_not_none_operation
return is_not_none_operation(self)
def __and__(
self, other: Var[OTHER_VAR_TYPE] | Any
) -> Var[VAR_TYPE | OTHER_VAR_TYPE]:
"""Perform a logical AND operation on the current instance and another variable.
Args:
other: The variable to perform the logical AND operation with.
Returns:
A `BooleanVar` object representing the result of the logical AND operation.
"""
return and_operation(self, other)
def __rand__(
self, other: Var[OTHER_VAR_TYPE] | Any
) -> Var[VAR_TYPE | OTHER_VAR_TYPE]:
"""Perform a logical AND operation on the current instance and another variable.
Args:
other: The variable to perform the logical AND operation with.
Returns:
A `BooleanVar` object representing the result of the logical AND operation.
"""
return and_operation(other, self)
def __or__(
self, other: Var[OTHER_VAR_TYPE] | Any
) -> Var[VAR_TYPE | OTHER_VAR_TYPE]:
"""Perform a logical OR operation on the current instance and another variable.
Args:
other: The variable to perform the logical OR operation with.
Returns:
A `BooleanVar` object representing the result of the logical OR operation.
"""
return or_operation(self, other)
def __ror__(
self, other: Var[OTHER_VAR_TYPE] | Any
) -> Var[VAR_TYPE | OTHER_VAR_TYPE]:
"""Perform a logical OR operation on the current instance and another variable.
Args:
other: The variable to perform the logical OR operation with.
Returns:
A `BooleanVar` object representing the result of the logical OR operation.
"""
return or_operation(other, self)
def __invert__(self) -> BooleanVar:
"""Perform a logical NOT operation on the current instance.
Returns:
A `BooleanVar` object representing the result of the logical NOT operation.
"""
return ~self.bool()
def to_string(self, use_json: bool = True) -> StringVar:
"""Convert the var to a string.
Args:
use_json: Whether to use JSON stringify. If False, uses Object.prototype.toString.
Returns:
The string var.
"""
from .function import JSON_STRINGIFY, PROTOTYPE_TO_STRING
from .sequence import StringVar
return (
JSON_STRINGIFY.call(self).to(StringVar)
if use_json
else PROTOTYPE_TO_STRING.call(self).to(StringVar)
)
def _as_ref(self) -> Var:
"""Get a reference to the var.
Returns:
The reference to the var.
"""
return Var(
_js_expr=f"refs[{Var.create(str(self))}]",
_var_data=VarData(
imports={
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
}
),
).to(str)
def js_type(self) -> StringVar:
"""Returns the javascript type of the object.
This method uses the `typeof` function from the `FunctionStringVar` class
to determine the type of the object.
Returns:
StringVar: A string variable representing the type of the object.
"""
from .function import FunctionStringVar
from .sequence import StringVar
type_of = FunctionStringVar("typeof")
return type_of.call(self).to(StringVar)
def _without_data(self):
"""Create a copy of the var without the data.
Returns:
The var without the data.
"""
return dataclasses.replace(self, _var_data=None)
def _decode(self) -> Any:
"""Decode Var as a python value.
Note that Var with state set cannot be decoded python-side and will be
returned as full_name.
Returns:
The decoded value or the Var name.
"""
if isinstance(self, LiteralVar):
return self._var_value
try:
return json.loads(str(self))
except ValueError:
return str(self)
@property
def _var_state(self) -> str:
"""Compat method for getting the state.
Returns:
The state name associated with the var.
"""
var_data = self._get_all_var_data()
return var_data.state if var_data else ""
@overload
@classmethod
def range(cls, stop: int | NumberVar, /) -> ArrayVar[Sequence[int]]: ...
@overload
@classmethod
def range(
cls,
start: int | NumberVar,
end: int | NumberVar,
step: int | NumberVar = 1,
/,
) -> ArrayVar[Sequence[int]]: ...
@classmethod
def range(
cls,
first_endpoint: int | NumberVar,
second_endpoint: int | NumberVar | None = None,
step: int | NumberVar | None = None,
) -> ArrayVar[Sequence[int]]:
"""Create a range of numbers.
Args:
first_endpoint: The end of the range if second_endpoint is not provided, otherwise the start of the range.
second_endpoint: The end of the range.
step: The step of the range.
Returns:
The range of numbers.
"""
from .sequence import ArrayVar
return ArrayVar.range(first_endpoint, second_endpoint, step)
if not TYPE_CHECKING:
def __getitem__(self, key: Any) -> Var:
"""Get the item from the var.
Args:
key: The key to get.
Raises:
UntypedVarError: If the var type is Any.
TypeError: If the var type is Any.
# noqa: DAR101 self
"""
if self._var_type is Any:
raise exceptions.UntypedVarError(
self,
f"access the item '{key}'",
)
msg = f"Var of type {self._var_type} does not support item access."
raise TypeError(msg)
def __getattr__(self, name: str):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Raises:
VarAttributeError: If the attribute does not exist.
UntypedVarError: If the var type is Any.
TypeError: If the var type is Any.
# noqa: DAR101 self
"""
if name.startswith("_"):
msg = f"Attribute {name} not found."
raise VarAttributeError(msg)
if name == "contains":
msg = f"Var of type {self._var_type} does not support contains check."
raise TypeError(msg)
if name == "reverse":
msg = "Cannot reverse non-list var."
raise TypeError(msg)
if self._var_type is Any:
raise exceptions.UntypedVarError(
self,
f"access the attribute '{name}'",
)
msg = f"The State var {escape(self._js_expr)} of type {escape(str(self._var_type))} has no attribute '{name}' or may have been annotated wrongly."
raise VarAttributeError(msg)
def __bool__(self) -> bool:
"""Raise exception if using Var in a boolean context.
Raises:
VarTypeError: when attempting to bool-ify the Var.
# noqa: DAR101 self
"""
msg = (
f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
"Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
)
raise VarTypeError(msg)
def __iter__(self) -> Any:
"""Raise exception if using Var in an iterable context.
Raises:
VarTypeError: when attempting to iterate over the Var.
# noqa: DAR101 self
"""
msg = f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
raise VarTypeError(msg)
def __contains__(self, _: Any) -> Var:
"""Override the 'in' operator to alert the user that it is not supported.
Raises:
VarTypeError: the operation is not supported
# noqa: DAR101 self
"""
msg = (
"'in' operator not supported for Var types, use Var.contains() instead."
)
raise VarTypeError(msg)
OUTPUT = TypeVar("OUTPUT", bound=Var)
VAR_SUBCLASS = TypeVar("VAR_SUBCLASS", bound=Var)
VAR_INSIDE = TypeVar("VAR_INSIDE")
class ToOperation:
"""A var operation that converts a var to another type."""
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
from .object import ObjectVar
if isinstance(self, ObjectVar) and name != "_js_expr":
return ObjectVar.__getattr__(self, name)
return getattr(self._original, name)
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_js_expr")
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash(self._original)
def _get_all_var_data(self) -> VarData | None:
"""Get all the var data.
Returns:
The var data.
"""
return VarData.merge(
self._original._get_all_var_data(),
self._var_data,
)
@classmethod
def create(
cls,
value: Var,
_var_type: GenericType | None = None,
_var_data: VarData | None = None,
):
"""Create a ToOperation.
Args:
value: The value of the var.
_var_type: The type of the Var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The ToOperation.
"""
return cls(
_js_expr="", # pyright: ignore [reportCallIssue]
_var_data=_var_data, # pyright: ignore [reportCallIssue]
_var_type=_var_type or cls._default_var_type, # pyright: ignore [reportCallIssue, reportAttributeAccessIssue]
_original=value, # pyright: ignore [reportCallIssue]
)
class LiteralVar(Var[VAR_TYPE]):
"""Base class for immutable literal vars."""
def __init_subclass__(cls, **kwargs):
"""Initialize the subclass.
Args:
**kwargs: Additional keyword arguments.
Raises:
TypeError: If the LiteralVar subclass does not have a corresponding Var subclass.
"""
super().__init_subclass__(**kwargs)
bases = cls.__bases__
bases_normalized = [
base if isinstance(base, type) else get_origin(base) for base in bases
]
possible_bases = [
base
for base in bases_normalized
if safe_issubclass(base, Var) and base != LiteralVar
]
if not possible_bases:
msg = f"LiteralVar subclass {cls} must have a base class that is a subclass of Var and not LiteralVar."
raise TypeError(msg)
var_subclasses = [
var_subclass
for var_subclass in _var_subclasses
if var_subclass.var_subclass in possible_bases
]
if not var_subclasses:
msg = f"LiteralVar {cls} must have a base class annotated with `python_types`."
raise TypeError(msg)
if len(var_subclasses) != 1:
msg = f"LiteralVar {cls} must have exactly one base class annotated with `python_types`."
raise TypeError(msg)
var_subclass = var_subclasses[0]
# Remove the old subclass, happens because __init_subclass__ is called twice
# for each subclass. This is because of __slots__ in dataclasses.
for var_literal_subclass in list(_var_literal_subclasses):
if var_literal_subclass[1] is var_subclass:
_var_literal_subclasses.remove(var_literal_subclass)
_var_literal_subclasses.append((cls, var_subclass))
@classmethod
def _create_literal_var(
cls,
value: Any,
_var_data: VarData | None = None,
) -> Var:
"""Create a var from a value.
Args:
value: The value to create the var from.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
Raises:
TypeError: If the value is not a supported type for LiteralVar.
"""
from .object import LiteralObjectVar
from .sequence import ArrayVar, LiteralStringVar
if isinstance(value, Var):
if _var_data is None:
return value
return value._replace(merge_var_data=_var_data)
for literal_subclass, var_subclass in _var_literal_subclasses[::-1]:
if isinstance(value, var_subclass.python_types):
return literal_subclass.create(value, _var_data=_var_data)
if (
(as_var_method := getattr(value, "_as_var", None)) is not None
and callable(as_var_method)
and isinstance((resulting_var := as_var_method()), Var)
):
return resulting_var
from reflex_base.event import EventHandler
from reflex_base.utils.format import get_event_handler_parts
if isinstance(value, EventHandler):
return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value))))
serialized_value = serializers.serialize(value)
if serialized_value is not None:
if isinstance(serialized_value, Mapping):
return LiteralObjectVar.create(
serialized_value,
_var_type=type(value),
_var_data=_var_data,
)
if isinstance(serialized_value, str):
return LiteralStringVar.create(
serialized_value, _var_type=type(value), _var_data=_var_data
)
return LiteralVar.create(serialized_value, _var_data=_var_data)
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return LiteralObjectVar.create(
{
k.name: (None if callable(v := getattr(value, k.name)) else v)
for k in dataclasses.fields(value)
},
_var_type=type(value),
_var_data=_var_data,
)
if isinstance(value, range):
return ArrayVar.range(value.start, value.stop, value.step)
msg = f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}."
raise TypeError(msg)
if not TYPE_CHECKING:
create = _create_literal_var
def __post_init__(self):
"""Post-initialize the var."""
@classmethod
def _get_all_var_data_without_creating_var(
cls,
value: Any,
) -> VarData | None:
return cls.create(value)._get_all_var_data()
@classmethod
def _get_all_var_data_without_creating_var_dispatch(
cls,
value: Any,
) -> VarData | None:
"""Get all the var data without creating a var.
Args:
value: The value to get the var data from.
Returns:
The var data or None.
Raises:
TypeError: If the value is not a supported type for LiteralVar.
"""
from .object import LiteralObjectVar
from .sequence import LiteralStringVar
if isinstance(value, Var):
return value._get_all_var_data()
for literal_subclass, var_subclass in _var_literal_subclasses[::-1]:
if isinstance(value, var_subclass.python_types):
return literal_subclass._get_all_var_data_without_creating_var(value)
if (
(as_var_method := getattr(value, "_as_var", None)) is not None
and callable(as_var_method)
and isinstance((resulting_var := as_var_method()), Var)
):
return resulting_var._get_all_var_data()
from reflex_base.event import EventHandler
from reflex_base.utils.format import get_event_handler_parts
if isinstance(value, EventHandler):
return Var(
_js_expr=".".join(filter(None, get_event_handler_parts(value)))
)._get_all_var_data()
serialized_value = serializers.serialize(value)
if serialized_value is not None:
if isinstance(serialized_value, Mapping):
return LiteralObjectVar._get_all_var_data_without_creating_var(
serialized_value
)
if isinstance(serialized_value, str):
return LiteralStringVar._get_all_var_data_without_creating_var(
serialized_value
)
return LiteralVar._get_all_var_data_without_creating_var_dispatch(
serialized_value
)
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return LiteralObjectVar._get_all_var_data_without_creating_var({
k.name: (None if callable(v := getattr(value, k.name)) else v)
for k in dataclasses.fields(value)
})
if isinstance(value, range):
return None
msg = f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}."
raise TypeError(msg)
@property
def _var_value(self) -> Any:
msg = "LiteralVar subclasses must implement the _var_value property."
raise NotImplementedError(msg)
def json(self) -> str:
"""Serialize the var to a JSON string.
Raises:
NotImplementedError: If the method is not implemented.
"""
msg = "LiteralVar subclasses must implement the json method."
raise NotImplementedError(msg)
@serializers.serializer
def serialize_literal(value: LiteralVar):
"""Serialize a Literal type.
Args:
value: The Literal to serialize.
Returns:
The serialized Literal.
"""
return value._var_value
def get_python_literal(value: LiteralVar | Any) -> Any | None:
"""Get the Python literal value.
Args:
value: The value to get the Python literal value of.
Returns:
The Python literal value.
"""
if isinstance(value, LiteralVar):
return value._var_value
if isinstance(value, Var):
return None
return value
P = ParamSpec("P")
T = TypeVar("T")
# NoReturn is used to match CustomVarOperationReturn with no type hint.
@overload
def var_operation( # pyright: ignore [reportOverlappingOverload]
func: Callable[P, CustomVarOperationReturn[NoReturn]],
) -> Callable[P, Var]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[None]],
) -> Callable[P, NoneVar]: ...
@overload
def var_operation( # pyright: ignore [reportOverlappingOverload]
func: Callable[P, CustomVarOperationReturn[bool]]
| Callable[P, CustomVarOperationReturn[bool | None]],
) -> Callable[P, BooleanVar]: ...
NUMBER_T = TypeVar("NUMBER_T", int, float, int | float)
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[NUMBER_T]]
| Callable[P, CustomVarOperationReturn[NUMBER_T | None]],
) -> Callable[P, NumberVar[NUMBER_T]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[str]]
| Callable[P, CustomVarOperationReturn[str | None]],
) -> Callable[P, StringVar]: ...
LIST_T = TypeVar("LIST_T", bound=Sequence)
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[LIST_T]]
| Callable[P, CustomVarOperationReturn[LIST_T | None]],
) -> Callable[P, ArrayVar[LIST_T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]]
| Callable[P, CustomVarOperationReturn[OBJECT_TYPE | None]],
) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[T]]
| Callable[P, CustomVarOperationReturn[T | None]],
) -> Callable[P, Var[T]]: ...
def var_operation( # pyright: ignore [reportInconsistentOverload]
func: Callable[P, CustomVarOperationReturn[T]],
) -> Callable[P, Var[T]]:
"""Decorator for creating a var operation.
Example:
```python
@var_operation
def add(a: NumberVar, b: NumberVar):
return custom_var_operation(f"{a} + {b}")
```
Args:
func: The function to decorate.
Returns:
The decorated function.
"""
func_args = list(inspect.signature(func).parameters)
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]:
args_vars = {
func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg)
for i, arg in enumerate(args)
}
kwargs_vars = {
key: LiteralVar.create(value) if not isinstance(value, Var) else value
for key, value in kwargs.items()
}
return CustomVarOperation.create(
name=func.__name__,
args=tuple(list(args_vars.items()) + list(kwargs_vars.items())),
return_var=func(*args_vars.values(), **kwargs_vars), # pyright: ignore [reportCallIssue, reportReturnType]
).guess_type()
return wrapper
def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value.
Args:
value: The value to figure out the type of.
Returns:
The type of the value.
"""
if isinstance(value, (list, set, tuple, Mapping, Var)):
if isinstance(value, Var):
return value._var_type
if has_args(value_type := type(value)):
return value_type
if isinstance(value, list):
if not value:
return Sequence[NoReturn]
return Sequence[unionize(*{figure_out_type(v) for v in value[:100]})]
if isinstance(value, set):
return set[unionize(*{figure_out_type(v) for v in value})]
if isinstance(value, tuple):
if not value:
return tuple[NoReturn, ...]
if len(value) <= 5:
return tuple[tuple(figure_out_type(v) for v in value)]
return tuple[unionize(*{figure_out_type(v) for v in value[:100]}), ...]
if isinstance(value, Mapping):
if not value:
return Mapping[NoReturn, NoReturn]
return Mapping[
unionize(*{figure_out_type(k) for k in list(value.keys())[:100]}),
unionize(*{figure_out_type(v) for v in list(value.values())[:100]}),
]
return type(value)
GLOBAL_CACHE = {}
class cached_property: # noqa: N801
"""A cached property that caches the result of the function."""
def __init__(self, func: Callable):
"""Initialize the cached_property.
Args:
func: The function to cache.
"""
self._func = func
self._attrname = None
def __set_name__(self, owner: Any, name: str):
"""Set the name of the cached property.
Args:
owner: The owner of the cached property.
name: The name of the cached property.
Raises:
TypeError: If the cached property is assigned to two different names.
"""
if self._attrname is None:
self._attrname = name
original_del = getattr(owner, "__del__", None)
def delete_property(this: Any):
"""Delete the cached property.
Args:
this: The object to delete the cached property from.
"""
cached_field_name = "_reflex_cache_" + name
try:
unique_id = object.__getattribute__(this, cached_field_name)
except AttributeError:
if original_del is not None:
original_del(this)
return
GLOBAL_CACHE.pop(unique_id, None)
if original_del is not None:
original_del(this)
owner.__del__ = delete_property
elif name != self._attrname:
msg = (
"Cannot assign the same cached_property to two different names "
f"({self._attrname!r} and {name!r})."
)
raise TypeError(msg)
def __get__(self, instance: Any, owner: type | None = None):
"""Get the cached property.
Args:
instance: The instance to get the cached property from.
owner: The owner of the cached property.
Returns:
The cached property.
Raises:
TypeError: If the class does not have __set_name__.
"""
if self._attrname is None:
msg = "Cannot use cached_property on a class without __set_name__."
raise TypeError(msg)
cached_field_name = "_reflex_cache_" + self._attrname
try:
unique_id = object.__getattribute__(instance, cached_field_name)
except AttributeError:
unique_id = uuid.uuid4().int
object.__setattr__(instance, cached_field_name, unique_id)
if unique_id not in GLOBAL_CACHE:
GLOBAL_CACHE[unique_id] = self._func(instance)
return GLOBAL_CACHE[unique_id]
cached_property_no_lock = cached_property
class VarProtocol(Protocol):
"""A protocol for Var."""
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
@property
def _js_expr(self) -> str: ...
@property
def _var_type(self) -> types.GenericType: ...
@property
def _var_data(self) -> VarData: ...
class CachedVarOperation:
"""Base class for cached var operations to lower boilerplate code."""
def __post_init__(self):
"""Post-initialize the CachedVarOperation."""
object.__delattr__(self, "_js_expr")
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute.
"""
if name == "_js_expr":
return self._cached_var_name
parent_classes = inspect.getmro(type(self))
next_class = parent_classes[parent_classes.index(CachedVarOperation) + 1]
return next_class.__getattr__(self, name)
def _get_all_var_data(self) -> VarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
@cached_property_no_lock
def _cached_get_all_var_data(self: VarProtocol) -> VarData | None:
"""Get the cached VarData.
Returns:
The cached VarData.
"""
return VarData.merge(
*(
value._get_all_var_data() if isinstance(value, Var) else None
for value in (
getattr(self, field.name) for field in dataclasses.fields(self)
)
),
self._var_data,
)
def __hash__(self: DataclassInstance) -> int:
"""Calculate the hash of the object.
Returns:
The hash of the object.
"""
return hash((
type(self).__name__,
*[
getattr(self, field.name)
for field in dataclasses.fields(self)
if field.name not in ["_js_expr", "_var_data", "_var_type"]
],
))
def and_operation(
a: Var[VAR_TYPE] | Any, b: Var[OTHER_VAR_TYPE] | Any
) -> Var[VAR_TYPE | OTHER_VAR_TYPE]:
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return _and_operation(a, b)
@var_operation
def _and_operation(a: Var, b: Var):
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return var_operation_return(
js_expression=f"({a} && {b})",
var_type=unionize(a._var_type, b._var_type),
)
def or_operation(
a: Var[VAR_TYPE] | Any, b: Var[OTHER_VAR_TYPE] | Any
) -> Var[VAR_TYPE | OTHER_VAR_TYPE]:
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return _or_operation(a, b)
@var_operation
def _or_operation(a: Var, b: Var):
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return var_operation_return(
js_expression=f"({a} || {b})",
var_type=unionize(a._var_type, b._var_type),
)
RETURN_TYPE = TypeVar("RETURN_TYPE")
DICT_KEY = TypeVar("DICT_KEY")
DICT_VAL = TypeVar("DICT_VAL")
LIST_INSIDE = TypeVar("LIST_INSIDE")
class FakeComputedVarBaseClass(property):
"""A fake base class for ComputedVar to avoid inheriting from property."""
__pydantic_run_validation__ = False
def is_computed_var(obj: Any) -> TypeGuard[ComputedVar]:
"""Check if the object is a ComputedVar.
Args:
obj: The object to check.
Returns:
Whether the object is a ComputedVar.
"""
return isinstance(obj, FakeComputedVarBaseClass)
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class ComputedVar(Var[RETURN_TYPE]):
"""A field with computed getters."""
# Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False)
# Whether the computed var is a backend var
_backend: bool = dataclasses.field(default=False)
# The initial value of the computed var
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
# Explicit var dependencies to track
_static_deps: dict[str | None, set[str]] = dataclasses.field(default_factory=dict)
# Whether var dependencies should be auto-determined
_auto_deps: bool = dataclasses.field(default=True)
# Interval at which the computed var should be updated
_update_interval: datetime.timedelta | None = dataclasses.field(default=None)
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
default_factory=lambda: lambda _: None
) # pyright: ignore [reportAssignmentType]
_name: str = dataclasses.field(default="")
def __init__(
self,
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
cache: bool = True,
deps: list[str | Var] | None = None,
auto_deps: bool = True,
interval: int | datetime.timedelta | None = None,
backend: bool | None = None,
**kwargs,
):
"""Initialize a ComputedVar.
Args:
fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
**kwargs: additional attributes to set on the instance
Raises:
TypeError: If the computed var dependencies are not Var instances or var names.
UntypedComputedVarError: If the computed var is untyped.
"""
hint = kwargs.pop("return_type", None) or get_type_hints(fget).get(
"return", Any
)
if hint is Any:
raise UntypedComputedVarError(var_name=fget.__name__)
is_using_fget_name = "_js_expr" not in kwargs
js_expr = kwargs.pop("_js_expr", fget.__name__ + FIELD_MARKER)
kwargs.setdefault("_var_type", hint)
Var.__init__(
self,
_js_expr=js_expr,
_var_type=kwargs.pop("_var_type"),
_var_data=kwargs.pop(
"_var_data",
VarData(field_name=fget.__name__) if is_using_fget_name else None,
),
)
if kwargs:
msg = f"Unexpected keyword arguments: {tuple(kwargs)}"
raise TypeError(msg)
if backend is None:
backend = fget.__name__.startswith("_")
object.__setattr__(self, "_backend", backend)
object.__setattr__(self, "_initial_value", initial_value)
object.__setattr__(self, "_cache", cache)
object.__setattr__(self, "_name", fget.__name__)
if isinstance(interval, int):
interval = datetime.timedelta(seconds=interval)
object.__setattr__(self, "_update_interval", interval)
object.__setattr__(
self,
"_static_deps",
self._calculate_static_deps(deps),
)
object.__setattr__(self, "_auto_deps", auto_deps)
object.__setattr__(self, "_fget", fget)
def _calculate_static_deps(
self,
deps: list[str | Var] | dict[str | None, set[str]] | None = None,
) -> dict[str | None, set[str]]:
"""Calculate the static dependencies of the computed var from user input or existing dependencies.
Args:
deps: The user input dependencies or existing dependencies.
Returns:
The static dependencies.
"""
if isinstance(deps, dict):
# Assume a dict is coming from _replace, so no special processing.
return deps
static_deps = {}
if deps is not None:
for dep in deps:
static_deps = self._add_static_dep(dep, static_deps)
return static_deps
def _add_static_dep(
self, dep: str | Var, deps: dict[str | None, set[str]] | None = None
) -> dict[str | None, set[str]]:
"""Add a static dependency to the computed var or existing dependency set.
Args:
dep: The dependency to add.
deps: The existing dependency set.
Returns:
The updated dependency set.
Raises:
TypeError: If the computed var dependencies are not Var instances or var names.
"""
if deps is None:
deps = self._static_deps
if isinstance(dep, Var):
state_name = (
all_var_data.state
if (all_var_data := dep._get_all_var_data()) and all_var_data.state
else None
)
if all_var_data is not None:
var_name = all_var_data.field_name
else:
var_name = dep._js_expr
deps.setdefault(state_name, set()).add(var_name)
elif isinstance(dep, str) and dep != "":
deps.setdefault(None, set()).add(dep)
else:
msg = "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
raise TypeError(msg)
return deps
@override
def _replace(
self,
merge_var_data: VarData | None = None,
**kwargs: Any,
) -> Self:
"""Replace the attributes of the ComputedVar.
Args:
merge_var_data: VarData to merge into the existing VarData.
**kwargs: Var fields to update.
Returns:
The new ComputedVar instance.
Raises:
TypeError: If kwargs contains keys that are not allowed.
"""
if "deps" in kwargs:
kwargs["deps"] = self._calculate_static_deps(kwargs["deps"])
field_values = {
"fget": kwargs.pop("fget", self._fget),
"initial_value": kwargs.pop("initial_value", self._initial_value),
"cache": kwargs.pop("cache", self._cache),
"deps": kwargs.pop("deps", copy.copy(self._static_deps)),
"auto_deps": kwargs.pop("auto_deps", self._auto_deps),
"interval": kwargs.pop("interval", self._update_interval),
"backend": kwargs.pop("backend", self._backend),
"_js_expr": kwargs.pop("_js_expr", self._js_expr),
"_var_type": kwargs.pop("_var_type", self._var_type),
"_var_data": kwargs.pop(
"_var_data", VarData.merge(self._var_data, merge_var_data)
),
"return_type": kwargs.pop("return_type", self._var_type),
}
if kwargs:
unexpected_kwargs = ", ".join(kwargs.keys())
msg = f"Unexpected keyword arguments: {unexpected_kwargs}"
raise TypeError(msg)
return type(self)(**field_values)
@property
def _cache_attr(self) -> str:
"""Get the attribute used to cache the value on the instance.
Returns:
An attribute name.
"""
return f"__cached_{self._js_expr}"
@property
def _last_updated_attr(self) -> str:
"""Get the attribute used to store the last updated timestamp.
Returns:
An attribute name.
"""
return f"__last_updated_{self._js_expr}"
def needs_update(self, instance: BaseState) -> bool:
"""Check if the computed var needs to be updated.
Args:
instance: The state instance that the computed var is attached to.
Returns:
True if the computed var needs to be updated, False otherwise.
"""
if self._update_interval is None:
return False
last_updated = getattr(instance, self._last_updated_attr, None)
if last_updated is None:
return True
return datetime.datetime.now() - last_updated > self._update_interval
@overload
def __get__(
self: ComputedVar[bool],
instance: None,
owner: type,
) -> BooleanVar: ...
@overload
def __get__(
self: ComputedVar[int] | ComputedVar[float],
instance: None,
owner: type,
) -> NumberVar: ...
@overload
def __get__(
self: ComputedVar[str],
instance: None,
owner: type,
) -> StringVar: ...
@overload
def __get__(
self: ComputedVar[MAPPING_TYPE],
instance: None,
owner: type,
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def __get__(
self: ComputedVar[list[LIST_INSIDE]],
instance: None,
owner: type,
) -> ArrayVar[list[LIST_INSIDE]]: ...
@overload
def __get__(
self: ComputedVar[tuple[LIST_INSIDE, ...]],
instance: None,
owner: type,
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
@overload
def __get__(
self: ComputedVar[SQLA_TYPE],
instance: None,
owner: type,
) -> ObjectVar[SQLA_TYPE]: ...
if TYPE_CHECKING:
@overload
def __get__(
self: ComputedVar[DATACLASS_TYPE], instance: None, owner: Any
) -> ObjectVar[DATACLASS_TYPE]: ...
@overload
def __get__(self, instance: None, owner: type) -> ComputedVar[RETURN_TYPE]: ...
@overload
def __get__(self, instance: BaseState, owner: type) -> RETURN_TYPE: ...
def __get__(self, instance: BaseState | None, owner: type):
"""Get the ComputedVar value.
If the value is already cached on the instance, return the cached value.
Args:
instance: the instance of the class accessing this computed var.
owner: the class that this descriptor is attached to.
Returns:
The value of the var for the given instance.
"""
if instance is None:
state_where_defined = owner
while self._name in state_where_defined.inherited_vars:
state_where_defined = state_where_defined.get_parent_state()
field_name = (
format_state_name(state_where_defined.get_full_name())
+ "."
+ self._js_expr
)
return dispatch(
field_name,
var_data=VarData.from_state(state_where_defined, self._name),
result_var_type=self._var_type,
existing_var=self,
)
if not self._cache:
value = self.fget(instance)
else:
# handle caching
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
self._check_deprecated_return_type(instance, value)
return value
def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False):
console.error(
f"Computed var '{type(instance).__name__}.{self._name}' must return"
f" a value of type '{escape(str(self._var_type))}', got '{value!s}' of type {type(value)}."
)
def _deps(
self,
objclass: type[BaseState],
obj: FunctionType | CodeType | None = None,
) -> dict[str, set[str]]:
"""Determine var dependencies of this ComputedVar.
Save references to attributes accessed on "self" or other fetched states.
Recursively called when the function makes a method call on "self" or
define comprehensions or nested functions that may reference "self".
Args:
objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function).
Returns:
A dictionary mapping state names to the set of variable names
accessed by the given obj.
"""
from .dep_tracking import DependencyTracker
d = {}
if self._static_deps:
d.update(self._static_deps)
# None is a placeholder for the current state class.
if None in d:
d[objclass.get_full_name()] = d.pop(None)
if not self._auto_deps:
return d
if obj is None:
fget = self._fget
if fget is not None:
obj = cast(FunctionType, fget)
else:
return d
try:
return DependencyTracker(
func=obj, state_cls=objclass, dependencies=d
).dependencies
except Exception as e:
console.warn(
"Failed to automatically determine dependencies for computed var "
f"{objclass.__name__}.{self._name}: {e}. "
"Set auto_deps=False and provide accurate deps=['var1', 'var2'] to suppress this warning."
)
return d
def mark_dirty(self, instance: BaseState) -> None:
"""Mark this ComputedVar as dirty.
Args:
instance: the state instance that needs to recompute the value.
"""
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
def add_dependency(self, objclass: type[BaseState], dep: Var):
"""Explicitly add a dependency to the ComputedVar.
After adding the dependency, when the `dep` changes, this computed var
will be marked dirty.
Args:
objclass: The class obj this ComputedVar is attached to.
dep: The dependency to add.
Raises:
VarDependencyError: If the dependency is not a Var instance with a
state and field name
"""
if all_var_data := dep._get_all_var_data():
state_name = all_var_data.state
if state_name:
var_name = all_var_data.field_name
if var_name:
self._static_deps.setdefault(state_name, set()).add(var_name)
target_state_class = objclass.get_root_state().get_class_substate(
state_name
)
target_state_class._var_dependencies.setdefault(
var_name, set()
).add((
objclass.get_full_name(),
self._name,
))
target_state_class._potentially_dirty_states.add(
objclass.get_full_name()
)
return
msg = (
"ComputedVar dependencies must be Var instances with a state and "
f"field name, got {dep!r}."
)
raise VarDependencyError(msg)
def _determine_var_type(self) -> type:
"""Get the type of the var.
Returns:
The type of the var.
"""
hints = get_type_hints(self._fget)
if "return" in hints:
return hints["return"]
return Any # pyright: ignore [reportReturnType]
@property
def __class__(self) -> type:
"""Get the class of the var.
Returns:
The class of the var.
"""
return FakeComputedVarBaseClass
@property
def fget(self) -> Callable[[BaseState], RETURN_TYPE]:
"""Get the getter function.
Returns:
The getter function.
"""
return self._fget
class DynamicRouteVar(ComputedVar[str | list[str]]):
"""A ComputedVar that represents a dynamic route."""
async def _default_async_computed_var(_self: BaseState) -> Any: # noqa: RUF029
return None
@dataclasses.dataclass(
eq=False,
frozen=True,
init=False,
slots=True,
)
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
"""A computed var that wraps a coroutinefunction."""
_fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
dataclasses.field(default=_default_async_computed_var)
)
@overload
def __get__(
self: AsyncComputedVar[bool],
instance: None,
owner: type,
) -> BooleanVar: ...
@overload
def __get__(
self: AsyncComputedVar[int] | ComputedVar[float],
instance: None,
owner: type,
) -> NumberVar: ...
@overload
def __get__(
self: AsyncComputedVar[str],
instance: None,
owner: type,
) -> StringVar: ...
@overload
def __get__(
self: AsyncComputedVar[MAPPING_TYPE],
instance: None,
owner: type,
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def __get__(
self: AsyncComputedVar[list[LIST_INSIDE]],
instance: None,
owner: type,
) -> ArrayVar[list[LIST_INSIDE]]: ...
@overload
def __get__(
self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
instance: None,
owner: type,
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
@overload
def __get__(
self: AsyncComputedVar[SQLA_TYPE],
instance: None,
owner: type,
) -> ObjectVar[SQLA_TYPE]: ...
if TYPE_CHECKING:
@overload
def __get__(
self: AsyncComputedVar[DATACLASS_TYPE], instance: None, owner: Any
) -> ObjectVar[DATACLASS_TYPE]: ...
@overload
def __get__(self, instance: None, owner: type) -> AsyncComputedVar[RETURN_TYPE]: ...
@overload
def __get__(
self, instance: BaseState, owner: type
) -> Coroutine[None, None, RETURN_TYPE]: ...
def __get__(
self, instance: BaseState | None, owner
) -> Var | Coroutine[None, None, RETURN_TYPE]:
"""Get the ComputedVar value.
If the value is already cached on the instance, return the cached value.
Args:
instance: the instance of the class accessing this computed var.
owner: the class that this descriptor is attached to.
Returns:
The value of the var for the given instance.
"""
if instance is None:
return super(AsyncComputedVar, self).__get__(instance, owner)
if not self._cache:
async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
value = await self.fget(instance)
self._check_deprecated_return_type(instance, value)
return value
return _awaitable_result()
# handle caching
async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, await self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
self._check_deprecated_return_type(instance, value)
return value
return _awaitable_result()
@property
def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
"""Get the getter function.
Returns:
The getter function.
"""
return self._fget
if TYPE_CHECKING:
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
class _ComputedVarDecorator(Protocol):
"""A protocol for the ComputedVar decorator."""
@overload
def __call__(
self,
fget: Callable[[BASE_STATE], Coroutine[Any, Any, RETURN_TYPE]],
) -> AsyncComputedVar[RETURN_TYPE]: ...
@overload
def __call__(
self,
fget: Callable[[BASE_STATE], RETURN_TYPE],
) -> ComputedVar[RETURN_TYPE]: ...
def __call__(
self,
fget: Callable[[BASE_STATE], Any],
) -> ComputedVar[Any]: ...
@overload
def computed_var(
fget: None = None,
initial_value: Any | types.Unset = types.Unset(),
cache: bool = True,
deps: list[str | Var] | None = None,
auto_deps: bool = True,
interval: datetime.timedelta | int | None = None,
backend: bool | None = None,
**kwargs,
) -> _ComputedVarDecorator: ...
@overload
def computed_var(
fget: Callable[[BASE_STATE], Coroutine[Any, Any, RETURN_TYPE]],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
cache: bool = True,
deps: list[str | Var] | None = None,
auto_deps: bool = True,
interval: datetime.timedelta | int | None = None,
backend: bool | None = None,
**kwargs,
) -> AsyncComputedVar[RETURN_TYPE]: ...
@overload
def computed_var(
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
cache: bool = True,
deps: list[str | Var] | None = None,
auto_deps: bool = True,
interval: datetime.timedelta | int | None = None,
backend: bool | None = None,
**kwargs,
) -> ComputedVar[RETURN_TYPE]: ...
def computed_var(
fget: Callable[[BASE_STATE], Any] | None = None,
initial_value: Any | types.Unset = types.Unset(),
cache: bool = True,
deps: list[str | Var] | None = None,
auto_deps: bool = True,
interval: datetime.timedelta | int | None = None,
backend: bool | None = None,
**kwargs,
) -> ComputedVar | Callable[[Callable[[BASE_STATE], Any]], ComputedVar]:
"""A ComputedVar decorator with or without kwargs.
Args:
fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
**kwargs: additional attributes to set on the instance
Returns:
A ComputedVar instance.
Raises:
ValueError: If caching is disabled and an update interval is set.
VarDependencyError: If user supplies dependencies without caching.
ComputedVarSignatureError: If the getter function has more than one argument.
"""
if cache is False and interval is not None:
msg = "Cannot set update interval without caching."
raise ValueError(msg)
if cache is False and (deps is not None or auto_deps is False):
msg = "Cannot track dependencies without caching."
raise VarDependencyError(msg)
if fget is not None:
sign = inspect.signature(fget)
if len(sign.parameters) != 1:
raise ComputedVarSignatureError(fget.__name__, signature=str(sign))
if inspect.iscoroutinefunction(fget):
computed_var_cls = AsyncComputedVar
else:
computed_var_cls = ComputedVar
return computed_var_cls(
fget,
initial_value=initial_value,
cache=cache,
deps=deps,
auto_deps=auto_deps,
interval=interval,
backend=backend,
**kwargs,
)
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
if inspect.iscoroutinefunction(fget):
computed_var_cls = AsyncComputedVar
else:
computed_var_cls = ComputedVar
return computed_var_cls(
fget,
initial_value=initial_value,
cache=cache,
deps=deps,
auto_deps=auto_deps,
interval=interval,
backend=backend,
**kwargs,
)
return wrapper
RETURN = TypeVar("RETURN")
class CustomVarOperationReturn(Var[RETURN]):
"""Base class for custom var operations."""
@classmethod
def create(
cls,
js_expression: str,
_var_type: type[RETURN] | None = None,
_var_data: VarData | None = None,
) -> CustomVarOperationReturn[RETURN]:
"""Create a CustomVarOperation.
Args:
js_expression: The JavaScript expression to evaluate.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The CustomVarOperation.
"""
return CustomVarOperationReturn(
_js_expr=js_expression,
_var_type=_var_type or Any,
_var_data=_var_data,
)
def var_operation_return(
js_expression: str,
var_type: type[RETURN] | GenericType | None = None,
var_data: VarData | None = None,
) -> CustomVarOperationReturn[RETURN]:
"""Shortcut for creating a CustomVarOperationReturn.
Args:
js_expression: The JavaScript expression to evaluate.
var_type: The type of the var.
var_data: Additional hooks and imports associated with the Var.
Returns:
The CustomVarOperationReturn.
"""
return CustomVarOperationReturn.create(
js_expression,
var_type,
var_data,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class CustomVarOperation(CachedVarOperation, Var[T]):
"""Base class for custom var operations."""
_name: str = dataclasses.field(default="")
_args: tuple[tuple[str, Var], ...] = dataclasses.field(default_factory=tuple)
_return: CustomVarOperationReturn[T] = dataclasses.field(
default_factory=lambda: CustomVarOperationReturn.create("")
)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""Get the cached var name.
Returns:
The cached var name.
"""
return str(self._return)
@cached_property_no_lock
def _cached_get_all_var_data(self) -> VarData | None:
"""Get the cached VarData.
Returns:
The cached VarData.
"""
return VarData.merge(
*(arg[1]._get_all_var_data() for arg in self._args),
self._return._get_all_var_data(),
self._var_data,
)
@classmethod
def create(
cls,
name: str,
args: tuple[tuple[str, Var], ...],
return_var: CustomVarOperationReturn[T],
_var_data: VarData | None = None,
) -> CustomVarOperation[T]:
"""Create a CustomVarOperation.
Args:
name: The name of the operation.
args: The arguments to the operation.
return_var: The return var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The CustomVarOperation.
"""
return CustomVarOperation(
_js_expr="",
_var_type=return_var._var_type,
_var_data=_var_data,
_name=name,
_args=args,
_return=return_var,
)
class NoneVar(Var[None], python_types=type(None)):
"""A var representing None."""
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class LiteralNoneVar(LiteralVar[None], NoneVar):
"""A var representing None."""
_var_value: None = None
def json(self) -> str:
"""Serialize the var to a JSON string.
Returns:
The JSON string.
"""
return "null"
@classmethod
def _get_all_var_data_without_creating_var(cls, value: None) -> VarData | None:
return None
@classmethod
def create(
cls,
value: None = None,
_var_data: VarData | None = None,
) -> LiteralNoneVar:
"""Create a var from a value.
Args:
value: The value of the var. Must be None. Existed for compatibility with LiteralVar.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
return LiteralNoneVar(
_js_expr="null",
_var_type=None,
_var_data=_var_data,
)
def get_to_operation(var_subclass: type[Var]) -> type[ToOperation]:
"""Get the ToOperation class for a given Var subclass.
Args:
var_subclass: The Var subclass.
Returns:
The ToOperation class.
Raises:
ValueError: If the ToOperation class cannot be found.
"""
possible_classes = [
saved_var_subclass.to_var_subclass
for saved_var_subclass in _var_subclasses
if saved_var_subclass.var_subclass is var_subclass
]
if not possible_classes:
msg = f"Could not find ToOperation for {var_subclass}."
raise ValueError(msg)
return possible_classes[0]
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class StateOperation(CachedVarOperation, Var):
"""A var operation that accesses a field on an object."""
_state_name: str = dataclasses.field(default="")
_field: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""Get the cached var name.
Returns:
The cached var name.
"""
return f"{self._state_name!s}.{self._field!s}"
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute.
"""
if name == "_js_expr":
return self._cached_var_name
return getattr(self._field, name)
@classmethod
def create(
cls,
state_name: str,
field: Var,
_var_data: VarData | None = None,
) -> StateOperation:
"""Create a DotOperation.
Args:
state_name: The name of the state.
field: The field of the state.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The DotOperation.
"""
return StateOperation(
_js_expr="",
_var_type=field._var_type,
_var_data=_var_data,
_state_name=state_name,
_field=field,
)
def get_uuid_string_var() -> Var:
"""Return a Var that generates a single memoized UUID via .web/utils/state.js.
useMemo with an empty dependency array ensures that the generated UUID is
consistent across re-renders of the component.
Returns:
A Var that generates a UUID at runtime.
"""
from reflex_base.utils.imports import ImportVar
from reflex_base.vars import Var
unique_uuid_var = get_unique_variable_name()
unique_uuid_var_data = VarData(
imports={
f"$/{constants.Dirs.STATE_PATH}": ImportVar(tag="generateUUID"),
"react": "useMemo",
},
hooks={f"const {unique_uuid_var} = useMemo(generateUUID, [])": None},
)
return Var(
_js_expr=unique_uuid_var,
_var_type=str,
_var_data=unique_uuid_var_data,
)
# Set of unique variable names.
USED_VARIABLES = set()
@once
def _rng():
import random
return random.Random(42)
def get_unique_variable_name() -> str:
"""Get a unique variable name.
Returns:
The unique variable name.
"""
name = "".join([_rng().choice(string.ascii_lowercase) for _ in range(8)])
if name not in USED_VARIABLES:
USED_VARIABLES.add(name)
return name
return get_unique_variable_name()
# Compile regex for finding reflex var tags.
_decode_var_pattern_re = (
rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
)
_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
# Defined global immutable vars.
_global_vars: dict[int, Var] = {}
dispatchers: dict[GenericType, Callable[[Var], Var]] = {}
def transform(fn: Callable[[Var], Var]) -> Callable[[Var], Var]:
"""Register a function to transform a Var.
Args:
fn: The function to register.
Returns:
The decorator.
Raises:
TypeError: If the return type of the function is not a Var.
TypeError: If the Var return type does not have a generic type.
ValueError: If a function for the generic type is already registered.
"""
types = get_type_hints(fn)
return_type = types["return"]
origin = get_origin(return_type)
if origin is not Var:
msg = f"Expected return type of {fn.__name__} to be a Var, got {origin}."
raise TypeError(msg)
generic_args = get_args(return_type)
if not generic_args:
msg = f"Expected Var return type of {fn.__name__} to have a generic type."
raise TypeError(msg)
generic_type = get_origin(generic_args[0]) or generic_args[0]
if generic_type in dispatchers:
msg = f"Function for {generic_type} already registered."
raise ValueError(msg)
dispatchers[generic_type] = fn
return fn
def dispatch(
field_name: str,
var_data: VarData,
result_var_type: GenericType,
existing_var: Var | None = None,
) -> Var:
"""Dispatch a Var to the appropriate transformation function.
Args:
field_name: The name of the field.
var_data: The VarData associated with the Var.
result_var_type: The type of the Var.
existing_var: The existing Var to transform. Optional.
Returns:
The transformed Var.
Raises:
TypeError: If the return type of the function is not a Var.
TypeError: If the Var return type does not have a generic type.
TypeError: If the first argument of the function is not a Var.
TypeError: If the first argument of the function does not have a generic type
"""
result_origin_var_type = get_origin(result_var_type) or result_var_type
if result_origin_var_type in dispatchers:
fn = dispatchers[result_origin_var_type]
fn_types = get_type_hints(fn)
fn_first_arg_type = fn_types.get(
next(iter(inspect.signature(fn).parameters.values())).name, Any
)
fn_return = fn_types.get("return", Any)
fn_return_origin = get_origin(fn_return) or fn_return
if fn_return_origin is not Var:
msg = f"Expected return type of {fn.__name__} to be a Var, got {fn_return}."
raise TypeError(msg)
fn_return_generic_args = get_args(fn_return)
if not fn_return_generic_args:
msg = f"Expected generic type of {fn_return} to be a type."
raise TypeError(msg)
arg_origin = get_origin(fn_first_arg_type) or fn_first_arg_type
if arg_origin is not Var:
msg = f"Expected first argument of {fn.__name__} to be a Var, got {fn_first_arg_type}."
raise TypeError(msg)
arg_generic_args = get_args(fn_first_arg_type)
if not arg_generic_args:
msg = f"Expected generic type of {fn_first_arg_type} to be a type."
raise TypeError(msg)
fn_return_type = fn_return_generic_args[0]
var = (
Var(
field_name,
_var_data=var_data,
_var_type=fn_return_type,
).guess_type()
if existing_var is None
else existing_var._replace(
_var_type=fn_return_type,
_var_data=var_data,
_js_expr=field_name,
).guess_type()
)
return fn(var)
if existing_var is not None:
return existing_var._replace(
_js_expr=field_name,
_var_data=var_data,
_var_type=result_var_type,
).guess_type()
return Var(
field_name,
_var_data=var_data,
_var_type=result_var_type,
).guess_type()
if TYPE_CHECKING:
from _typeshed import DataclassInstance
from sqlalchemy.orm import DeclarativeBase
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
V = TypeVar("V")
FIELD_TYPE = TypeVar("FIELD_TYPE")
class Field(Generic[FIELD_TYPE]):
"""A field for a state."""
if TYPE_CHECKING:
type_: GenericType
default: FIELD_TYPE | _MISSING_TYPE
default_factory: Callable[[], FIELD_TYPE] | None
def __init__(
self,
default: FIELD_TYPE | _MISSING_TYPE = MISSING,
default_factory: Callable[[], FIELD_TYPE] | None = None,
is_var: bool = True,
annotated_type: GenericType # pyright: ignore [reportRedeclaration]
| _MISSING_TYPE = MISSING,
) -> None:
"""Initialize the field.
Args:
default: The default value for the field.
default_factory: The default factory for the field.
is_var: Whether the field is a Var.
annotated_type: The annotated type for the field.
"""
self.default = default
self.default_factory = default_factory
self.is_var = is_var
if annotated_type is not MISSING:
type_origin = get_origin(annotated_type) or annotated_type
if type_origin is Field and (
args := getattr(annotated_type, "__args__", None)
):
annotated_type: GenericType = args[0]
type_origin = get_origin(annotated_type) or annotated_type
if self.default is MISSING and self.default_factory is None:
default_value = types.get_default_value_for_type(annotated_type)
if default_value is None and not types.is_optional(annotated_type):
annotated_type = annotated_type | None
if types.is_immutable(default_value):
self.default = default_value
else:
self.default_factory = functools.partial(
copy.deepcopy, default_value
)
self.outer_type_ = self.annotated_type = annotated_type
if type_origin is Annotated:
type_origin = annotated_type.__origin__ # pyright: ignore [reportAttributeAccessIssue]
self.type_ = self.type_origin = type_origin
else:
self.outer_type_ = self.annotated_type = self.type_ = self.type_origin = Any
def default_value(self) -> FIELD_TYPE:
"""Get the default value for the field.
Returns:
The default value for the field.
Raises:
ValueError: If no default value or factory is provided.
"""
if self.default is not MISSING:
return self.default
if self.default_factory is not None:
return self.default_factory()
msg = "No default value or factory provided."
raise ValueError(msg)
def __repr__(self) -> str:
"""Represent the field in a readable format.
Returns:
The string representation of the field.
"""
annotated_type_str = (
f", annotated_type={self.annotated_type!r}"
if self.annotated_type is not MISSING
else ""
)
if self.default is not MISSING:
return f"Field(default={self.default!r}, is_var={self.is_var}{annotated_type_str})"
return f"Field(default_factory={self.default_factory!r}, is_var={self.is_var}{annotated_type_str})"
if TYPE_CHECKING:
def __set__(self, instance: Any, value: FIELD_TYPE):
"""Set the Var.
Args:
instance: The instance of the class setting the Var.
value: The value to set the Var to.
# noqa: DAR101 self
"""
@overload
def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
@overload
def __get__(
self: Field[bool] | Field[bool | None], instance: None, owner: Any
) -> BooleanVar: ...
@overload
def __get__(
self: Field[int] | Field[int | None],
instance: None,
owner: Any,
) -> NumberVar[int]: ...
@overload
def __get__(
self: Field[float]
| Field[int | float]
| Field[float | None]
| Field[int | float | None],
instance: None,
owner: Any,
) -> NumberVar: ...
@overload
def __get__(
self: Field[str] | Field[str | None], instance: None, owner: Any
) -> StringVar: ...
@overload
def __get__(
self: Field[list[V]]
| Field[set[V]]
| Field[list[V] | None]
| Field[set[V] | None],
instance: None,
owner: Any,
) -> ArrayVar[Sequence[V]]: ...
@overload
def __get__(
self: Field[SEQUENCE_TYPE] | Field[SEQUENCE_TYPE | None],
instance: None,
owner: Any,
) -> ArrayVar[SEQUENCE_TYPE]: ...
@overload
def __get__(
self: Field[MAPPING_TYPE] | Field[MAPPING_TYPE | None],
instance: None,
owner: Any,
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def __get__(
self: Field[SQLA_TYPE] | Field[SQLA_TYPE | None], instance: None, owner: Any
) -> ObjectVar[SQLA_TYPE]: ...
if TYPE_CHECKING:
@overload
def __get__(
self: Field[DATACLASS_TYPE] | Field[DATACLASS_TYPE | None],
instance: None,
owner: Any,
) -> ObjectVar[DATACLASS_TYPE]: ...
@overload
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
@overload
def __get__(self, instance: Any, owner: Any) -> FIELD_TYPE: ...
def __get__(self, instance: Any, owner: Any): # pyright: ignore [reportInconsistentOverload]
"""Get the Var.
Args:
instance: The instance of the class accessing the Var.
owner: The class that the Var is attached to.
"""
@overload
def field(
default: FIELD_TYPE | _MISSING_TYPE = MISSING,
*,
is_var: Literal[False],
default_factory: Callable[[], FIELD_TYPE] | None = None,
) -> FIELD_TYPE: ...
@overload
def field(
default: FIELD_TYPE | _MISSING_TYPE = MISSING,
*,
default_factory: Callable[[], FIELD_TYPE] | None = None,
is_var: Literal[True] = True,
) -> Field[FIELD_TYPE]: ...
def field(
default: FIELD_TYPE | _MISSING_TYPE = MISSING,
*,
default_factory: Callable[[], FIELD_TYPE] | None = None,
is_var: bool = True,
) -> Field[FIELD_TYPE] | FIELD_TYPE:
"""Create a field for a state.
Args:
default: The default value for the field.
default_factory: The default factory for the field.
is_var: Whether the field is a Var.
Returns:
The field for the state.
Raises:
ValueError: If both default and default_factory are specified.
"""
if default is not MISSING and default_factory is not None:
msg = "cannot specify both default and default_factory"
raise ValueError(msg)
if default is not MISSING and not types.is_immutable(default):
console.warn(
"Mutable default values are not recommended. "
"Use default_factory instead to avoid unexpected behavior."
)
return Field(
default_factory=functools.partial(copy.deepcopy, default),
is_var=is_var,
)
return Field(
default=default,
default_factory=default_factory,
is_var=is_var,
)
@dataclass_transform(kw_only_default=True, field_specifiers=(field,))
class BaseStateMeta(ABCMeta):
"""Meta class for BaseState."""
if TYPE_CHECKING:
__inherited_fields__: Mapping[str, Field]
__own_fields__: dict[str, Field]
__fields__: dict[str, Field]
# Whether this state class is a mixin and should not be instantiated.
_mixin: bool = False
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
mixin: bool = False,
) -> type:
"""Create a new class.
Args:
name: The name of the class.
bases: The bases of the class.
namespace: The namespace of the class.
mixin: Whether the class is a mixin and should not be instantiated.
Returns:
The new class.
"""
state_bases = [
base for base in bases if issubclass(base, EvenMoreBasicBaseState)
]
mixin = mixin or (
bool(state_bases) and all(base._mixin for base in state_bases)
)
# Add the field to the class
inherited_fields: dict[str, Field] = {}
own_fields: dict[str, Field] = {}
resolved_annotations = types.resolve_annotations(
annotations_from_namespace(namespace), namespace["__module__"]
)
for base in bases[::-1]:
if hasattr(base, "__inherited_fields__"):
inherited_fields.update(base.__inherited_fields__)
for base in bases[::-1]:
if hasattr(base, "__own_fields__"):
inherited_fields.update(base.__own_fields__)
for key, value in [
(key, value)
for key, value in namespace.items()
if key not in resolved_annotations
]:
if isinstance(value, Field):
if value.annotated_type is not Any:
new_value = value
elif value.default is not MISSING:
new_value = Field(
default=value.default,
is_var=value.is_var,
annotated_type=figure_out_type(value.default),
)
else:
new_value = Field(
default_factory=value.default_factory,
is_var=value.is_var,
annotated_type=Any,
)
elif (
not key.startswith("__")
and not callable(value)
and not isinstance(value, (staticmethod, classmethod, property, Var))
):
if types.is_immutable(value):
new_value = Field(
default=value,
annotated_type=figure_out_type(value),
)
else:
new_value = Field(
default_factory=functools.partial(copy.deepcopy, value),
annotated_type=figure_out_type(value),
)
else:
continue
own_fields[key] = new_value
for key, annotation in resolved_annotations.items():
value = namespace.get(key, MISSING)
if types.is_classvar(annotation):
# If the annotation is a classvar, skip it.
continue
if value is MISSING:
value = Field(
annotated_type=annotation,
)
elif not isinstance(value, Field):
if types.is_immutable(value):
value = Field(
default=value,
annotated_type=annotation,
)
else:
value = Field(
default_factory=functools.partial(copy.deepcopy, value),
annotated_type=annotation,
)
else:
value = Field(
default=value.default,
default_factory=value.default_factory,
is_var=value.is_var,
annotated_type=annotation,
)
own_fields[key] = value
namespace["__own_fields__"] = own_fields
namespace["__inherited_fields__"] = inherited_fields
namespace["__fields__"] = inherited_fields | own_fields
namespace["_mixin"] = mixin
return super().__new__(cls, name, bases, namespace)
class EvenMoreBasicBaseState(metaclass=BaseStateMeta):
"""A simplified base state class that provides basic functionality."""
def __init__(
self,
**kwargs,
):
"""Initialize the state with the given kwargs.
Args:
**kwargs: The kwargs to pass to the state.
"""
super().__init__()
for key, value in kwargs.items():
object.__setattr__(self, key, value)
for name, value in type(self).get_fields().items():
if name not in kwargs:
default_value = value.default_value()
object.__setattr__(self, name, default_value)
def set(self, **kwargs):
"""Mutate the state by setting the given kwargs. Returns the state.
Args:
**kwargs: The kwargs to set.
Returns:
The state with the fields set to the given kwargs.
"""
for key, value in kwargs.items():
setattr(self, key, value)
return self
@classmethod
def get_fields(cls) -> Mapping[str, Field]:
"""Get the fields of the component.
Returns:
The fields of the component.
"""
return cls.__fields__
@classmethod
def add_field(cls, name: str, var: Var, default_value: Any):
"""Add a field to the class after class definition.
Used by State.add_var() to correctly handle the new variable.
Args:
name: The name of the field to add.
var: The variable to add a field for.
default_value: The default value of the field.
"""
if types.is_immutable(default_value):
new_field = Field(
default=default_value,
annotated_type=var._var_type,
)
else:
new_field = Field(
default_factory=functools.partial(copy.deepcopy, default_value),
annotated_type=var._var_type,
)
cls.__fields__[name] = new_field