554 lines
16 KiB
Python
554 lines
16 KiB
Python
"""Immutable function vars."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from collections.abc import Callable, Sequence
|
|
from typing import Any, Concatenate, Generic, ParamSpec, Protocol, TypeVar, overload
|
|
|
|
from reflex_base.utils import format
|
|
from reflex_base.utils.types import GenericType
|
|
|
|
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
|
|
|
|
P = ParamSpec("P")
|
|
V1 = TypeVar("V1")
|
|
V2 = TypeVar("V2")
|
|
V3 = TypeVar("V3")
|
|
V4 = TypeVar("V4")
|
|
V5 = TypeVar("V5")
|
|
V6 = TypeVar("V6")
|
|
R = TypeVar("R")
|
|
|
|
|
|
class ReflexCallable(Protocol[P, R]):
|
|
"""Protocol for a callable."""
|
|
|
|
__call__: Callable[P, R]
|
|
|
|
|
|
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, covariant=True)
|
|
OTHER_CALLABLE_TYPE = TypeVar(
|
|
"OTHER_CALLABLE_TYPE", bound=ReflexCallable, covariant=True
|
|
)
|
|
|
|
|
|
def _is_js_identifier_start(char: str) -> bool:
|
|
"""Check whether a character can start a JavaScript identifier.
|
|
|
|
Returns:
|
|
True if the character is valid as the first character of a JS identifier.
|
|
"""
|
|
return char == "$" or char == "_" or char.isalpha()
|
|
|
|
|
|
def _is_js_identifier_char(char: str) -> bool:
|
|
"""Check whether a character can continue a JavaScript identifier.
|
|
|
|
Returns:
|
|
True if the character is valid within a JS identifier.
|
|
"""
|
|
return _is_js_identifier_start(char) or char.isdigit()
|
|
|
|
|
|
def _starts_with_arrow_function(expr: str) -> bool:
|
|
"""Check whether an expression starts with an inline arrow function.
|
|
|
|
Returns:
|
|
True if the expression begins with an arrow function.
|
|
"""
|
|
if "=>" not in expr:
|
|
return False
|
|
|
|
expr = expr.lstrip()
|
|
if not expr:
|
|
return False
|
|
|
|
if expr.startswith("async"):
|
|
async_remainder = expr[len("async") :]
|
|
if async_remainder[:1].isspace():
|
|
expr = async_remainder.lstrip()
|
|
|
|
if not expr:
|
|
return False
|
|
|
|
if _is_js_identifier_start(expr[0]):
|
|
end_index = 1
|
|
while end_index < len(expr) and _is_js_identifier_char(expr[end_index]):
|
|
end_index += 1
|
|
return expr[end_index:].lstrip().startswith("=>")
|
|
|
|
if not expr.startswith("("):
|
|
return False
|
|
|
|
depth = 0
|
|
string_delimiter: str | None = None
|
|
escaped = False
|
|
|
|
for index, char in enumerate(expr):
|
|
if string_delimiter is not None:
|
|
if escaped:
|
|
escaped = False
|
|
elif char == "\\":
|
|
escaped = True
|
|
elif char == string_delimiter:
|
|
string_delimiter = None
|
|
continue
|
|
|
|
if char in {"'", '"', "`"}:
|
|
string_delimiter = char
|
|
continue
|
|
|
|
if char == "(":
|
|
depth += 1
|
|
continue
|
|
|
|
if char == ")":
|
|
depth -= 1
|
|
if depth == 0:
|
|
return expr[index + 1 :].lstrip().startswith("=>")
|
|
|
|
return False
|
|
|
|
|
|
class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
|
|
"""Base class for immutable function vars."""
|
|
|
|
@overload
|
|
def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
|
|
|
|
@overload
|
|
def partial(
|
|
self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
|
|
arg1: V1 | Var[V1],
|
|
) -> FunctionVar[ReflexCallable[P, R]]: ...
|
|
|
|
@overload
|
|
def partial(
|
|
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
) -> FunctionVar[ReflexCallable[P, R]]: ...
|
|
|
|
@overload
|
|
def partial(
|
|
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
) -> FunctionVar[ReflexCallable[P, R]]: ...
|
|
|
|
@overload
|
|
def partial(
|
|
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
arg4: V4 | Var[V4],
|
|
) -> FunctionVar[ReflexCallable[P, R]]: ...
|
|
|
|
@overload
|
|
def partial(
|
|
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
arg4: V4 | Var[V4],
|
|
arg5: V5 | Var[V5],
|
|
) -> FunctionVar[ReflexCallable[P, R]]: ...
|
|
|
|
@overload
|
|
def partial(
|
|
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
arg4: V4 | Var[V4],
|
|
arg5: V5 | Var[V5],
|
|
arg6: V6 | Var[V6],
|
|
) -> FunctionVar[ReflexCallable[P, R]]: ...
|
|
|
|
@overload
|
|
def partial(
|
|
self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
|
|
) -> FunctionVar[ReflexCallable[P, R]]: ...
|
|
|
|
@overload
|
|
def partial(self, *args: Var | Any) -> FunctionVar: ...
|
|
|
|
def partial(self, *args: Var | Any) -> FunctionVar: # pyright: ignore [reportInconsistentOverload]
|
|
"""Partially apply the function with the given arguments.
|
|
|
|
Args:
|
|
*args: The arguments to partially apply the function with.
|
|
|
|
Returns:
|
|
The partially applied function.
|
|
"""
|
|
if not args:
|
|
return self
|
|
return ArgsFunctionOperation.create(
|
|
("...args",),
|
|
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
|
|
)
|
|
|
|
@overload
|
|
def call(
|
|
self: FunctionVar[ReflexCallable[[V1], R]], arg1: V1 | Var[V1]
|
|
) -> VarOperationCall[[V1], R]: ...
|
|
|
|
@overload
|
|
def call(
|
|
self: FunctionVar[ReflexCallable[[V1, V2], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
) -> VarOperationCall[[V1, V2], R]: ...
|
|
|
|
@overload
|
|
def call(
|
|
self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
) -> VarOperationCall[[V1, V2, V3], R]: ...
|
|
|
|
@overload
|
|
def call(
|
|
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
arg4: V4 | Var[V4],
|
|
) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
|
|
|
|
@overload
|
|
def call(
|
|
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
arg4: V4 | Var[V4],
|
|
arg5: V5 | Var[V5],
|
|
) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
|
|
|
|
@overload
|
|
def call(
|
|
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
|
|
arg1: V1 | Var[V1],
|
|
arg2: V2 | Var[V2],
|
|
arg3: V3 | Var[V3],
|
|
arg4: V4 | Var[V4],
|
|
arg5: V5 | Var[V5],
|
|
arg6: V6 | Var[V6],
|
|
) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
|
|
|
|
@overload
|
|
def call(
|
|
self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
|
|
) -> VarOperationCall[P, R]: ...
|
|
|
|
@overload
|
|
def call(self, *args: Var | Any) -> Var: ...
|
|
|
|
def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload]
|
|
"""Call the function with the given arguments.
|
|
|
|
Args:
|
|
*args: The arguments to call the function with.
|
|
|
|
Returns:
|
|
The function call operation.
|
|
"""
|
|
return VarOperationCall.create(self, *args).guess_type()
|
|
|
|
__call__ = call
|
|
|
|
|
|
class BuilderFunctionVar(
|
|
FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
|
|
):
|
|
"""Base class for immutable function vars with the builder pattern."""
|
|
|
|
__call__ = FunctionVar.partial
|
|
|
|
|
|
class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
|
|
"""Base class for immutable function vars from a string."""
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
func: str,
|
|
_var_type: type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
|
|
_var_data: VarData | None = None,
|
|
) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
|
|
"""Create a new function var from a string.
|
|
|
|
Args:
|
|
func: The function to call.
|
|
_var_type: The type of the Var.
|
|
_var_data: Additional hooks and imports associated with the Var.
|
|
|
|
Returns:
|
|
The function var.
|
|
"""
|
|
return FunctionStringVar(
|
|
_js_expr=func,
|
|
_var_type=_var_type,
|
|
_var_data=_var_data,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass(
|
|
eq=False,
|
|
frozen=True,
|
|
slots=True,
|
|
)
|
|
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
|
|
"""Base class for immutable vars that are the result of a function call."""
|
|
|
|
_func: FunctionVar[ReflexCallable[P, R]] | None = dataclasses.field(default=None)
|
|
_args: tuple[Var | Any, ...] = dataclasses.field(default_factory=tuple)
|
|
|
|
@cached_property_no_lock
|
|
def _cached_var_name(self) -> str:
|
|
"""The name of the var.
|
|
|
|
Returns:
|
|
The name of the var.
|
|
"""
|
|
func_expr = str(self._func)
|
|
if _starts_with_arrow_function(func_expr) and not format.is_wrapped(
|
|
func_expr, "("
|
|
):
|
|
func_expr = format.wrap(func_expr, "(")
|
|
|
|
return f"({func_expr}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
|
|
|
|
@cached_property_no_lock
|
|
def _cached_get_all_var_data(self) -> VarData | None:
|
|
"""Get all the var data associated with the var.
|
|
|
|
Returns:
|
|
All the var data associated with the var.
|
|
"""
|
|
return VarData.merge(
|
|
self._func._get_all_var_data() if self._func is not None else None,
|
|
*[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
|
|
self._var_data,
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
func: FunctionVar[ReflexCallable[P, R]],
|
|
*args: Var | Any,
|
|
_var_type: GenericType = Any,
|
|
_var_data: VarData | None = None,
|
|
) -> VarOperationCall:
|
|
"""Create a new function call var.
|
|
|
|
Args:
|
|
func: The function to call.
|
|
*args: The arguments to call the function with.
|
|
_var_type: The type of the Var.
|
|
_var_data: Additional hooks and imports associated with the Var.
|
|
|
|
Returns:
|
|
The function call var.
|
|
"""
|
|
function_return_type = (
|
|
func._var_type.__args__[1]
|
|
if getattr(func._var_type, "__args__", None)
|
|
else Any
|
|
)
|
|
var_type = _var_type if _var_type is not Any else function_return_type
|
|
return cls(
|
|
_js_expr="",
|
|
_var_type=var_type,
|
|
_var_data=_var_data,
|
|
_func=func,
|
|
_args=args,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class DestructuredArg:
|
|
"""Class for destructured arguments."""
|
|
|
|
fields: tuple[str, ...] = ()
|
|
rest: str | None = None
|
|
|
|
def to_javascript(self) -> str:
|
|
"""Convert the destructured argument to JavaScript.
|
|
|
|
Returns:
|
|
The destructured argument in JavaScript.
|
|
"""
|
|
inner = ", ".join(self.fields)
|
|
if self.rest:
|
|
inner = f"{inner}, ...{self.rest}" if inner else f"...{self.rest}"
|
|
return format.wrap(inner, "{", "}")
|
|
|
|
|
|
@dataclasses.dataclass(
|
|
frozen=True,
|
|
)
|
|
class FunctionArgs:
|
|
"""Class for function arguments."""
|
|
|
|
args: tuple[str | DestructuredArg, ...] = ()
|
|
rest: str | None = None
|
|
|
|
|
|
def format_args_function_operation(
|
|
args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
|
|
) -> str:
|
|
"""Format an args function operation.
|
|
|
|
Args:
|
|
args: The function arguments.
|
|
return_expr: The return expression.
|
|
explicit_return: Whether to use explicit return syntax.
|
|
|
|
Returns:
|
|
The formatted args function operation.
|
|
"""
|
|
arg_names_str = ", ".join([
|
|
arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args
|
|
]) + (f", ...{args.rest}" if args.rest else "")
|
|
|
|
return_expr_str = str(LiteralVar.create(return_expr))
|
|
|
|
# Wrap return expression in curly braces if explicit return syntax is used.
|
|
return_expr_str_wrapped = (
|
|
format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
|
|
)
|
|
|
|
return f"(({arg_names_str}) => {return_expr_str_wrapped})"
|
|
|
|
|
|
@dataclasses.dataclass(
|
|
eq=False,
|
|
frozen=True,
|
|
slots=True,
|
|
)
|
|
class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
"""Base class for immutable function defined via arguments and return expression."""
|
|
|
|
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
|
|
_return_expr: Var | Any = dataclasses.field(default=None)
|
|
_explicit_return: bool = dataclasses.field(default=False)
|
|
|
|
@cached_property_no_lock
|
|
def _cached_var_name(self) -> str:
|
|
"""The name of the var.
|
|
|
|
Returns:
|
|
The name of the var.
|
|
"""
|
|
return format_args_function_operation(
|
|
self._args, self._return_expr, self._explicit_return
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
args_names: Sequence[str | DestructuredArg],
|
|
return_expr: Var | Any,
|
|
rest: str | None = None,
|
|
explicit_return: bool = False,
|
|
_var_type: GenericType = Callable,
|
|
_var_data: VarData | None = None,
|
|
):
|
|
"""Create a new function var.
|
|
|
|
Args:
|
|
args_names: The names of the arguments.
|
|
return_expr: The return expression of the function.
|
|
rest: The name of the rest argument.
|
|
explicit_return: Whether to use explicit return syntax.
|
|
_var_type: The type of the Var.
|
|
_var_data: Additional hooks and imports associated with the Var.
|
|
|
|
Returns:
|
|
The function var.
|
|
"""
|
|
return_expr = Var.create(return_expr)
|
|
return cls(
|
|
_js_expr="",
|
|
_var_type=_var_type,
|
|
_var_data=_var_data,
|
|
_args=FunctionArgs(args=tuple(args_names), rest=rest),
|
|
_return_expr=return_expr,
|
|
_explicit_return=explicit_return,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass(
|
|
eq=False,
|
|
frozen=True,
|
|
slots=True,
|
|
)
|
|
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
|
|
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""
|
|
|
|
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
|
|
_return_expr: Var | Any = dataclasses.field(default=None)
|
|
_explicit_return: bool = dataclasses.field(default=False)
|
|
|
|
@cached_property_no_lock
|
|
def _cached_var_name(self) -> str:
|
|
"""The name of the var.
|
|
|
|
Returns:
|
|
The name of the var.
|
|
"""
|
|
return format_args_function_operation(
|
|
self._args, self._return_expr, self._explicit_return
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
args_names: Sequence[str | DestructuredArg],
|
|
return_expr: Var | Any,
|
|
rest: str | None = None,
|
|
explicit_return: bool = False,
|
|
_var_type: GenericType = Callable,
|
|
_var_data: VarData | None = None,
|
|
):
|
|
"""Create a new function var.
|
|
|
|
Args:
|
|
args_names: The names of the arguments.
|
|
return_expr: The return expression of the function.
|
|
rest: The name of the rest argument.
|
|
explicit_return: Whether to use explicit return syntax.
|
|
_var_type: The type of the Var.
|
|
_var_data: Additional hooks and imports associated with the Var.
|
|
|
|
Returns:
|
|
The function var.
|
|
"""
|
|
return_expr = Var.create(return_expr)
|
|
return cls(
|
|
_js_expr="",
|
|
_var_type=_var_type,
|
|
_var_data=_var_data,
|
|
_args=FunctionArgs(args=tuple(args_names), rest=rest),
|
|
_return_expr=return_expr,
|
|
_explicit_return=explicit_return,
|
|
)
|
|
|
|
|
|
JSON_STRINGIFY = FunctionStringVar.create(
|
|
"JSON.stringify", _var_type=ReflexCallable[[Any], str]
|
|
)
|
|
ARRAY_ISARRAY = FunctionStringVar.create(
|
|
"Array.isArray", _var_type=ReflexCallable[[Any], bool]
|
|
)
|
|
PROTOTYPE_TO_STRING = FunctionStringVar.create(
|
|
"((__to_string) => __to_string.toString())",
|
|
_var_type=ReflexCallable[[Any], str],
|
|
)
|