"""Immutable function vars."""

from __future__ import annotations

import dataclasses
from collections.abc import Callable, Sequence
from typing import Any, Concatenate, Generic, ParamSpec, Protocol, TypeVar, overload

from reflex_base.utils import format
from reflex_base.utils.types import GenericType

from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock

P = ParamSpec("P")
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")
R = TypeVar("R")


class ReflexCallable(Protocol[P, R]):
    """Protocol for a callable."""

    __call__: Callable[P, R]


CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, covariant=True)
OTHER_CALLABLE_TYPE = TypeVar(
    "OTHER_CALLABLE_TYPE", bound=ReflexCallable, covariant=True
)


def _is_js_identifier_start(char: str) -> bool:
    """Check whether a character can start a JavaScript identifier.

    Args:
        char: The character to check.

    Returns:
        True if the character is valid as the first character of a JS identifier.
    """
    return char == "$" or char == "_" or char.isalpha()


def _is_js_identifier_char(char: str) -> bool:
    """Check whether a character can continue a JavaScript identifier.

    Args:
        char: The character to check.

    Returns:
        True if the character is valid within a JS identifier.
    """
    return _is_js_identifier_start(char) or char.isdigit()


def _starts_with_arrow_function(expr: str) -> bool:
    """Check whether an expression starts with an inline arrow function.

    Recognizes ``ident => ...`` and ``(params) => ...``, optionally prefixed
    by the ``async`` keyword. While scanning a parenthesized parameter list,
    the contents of quoted strings are skipped so parentheses inside them do
    not affect nesting.

    Args:
        expr: The JavaScript expression to inspect.

    Returns:
        True if the expression begins with an arrow function.
    """
    # Cheap rejection: without an arrow token there can be no arrow function.
    if "=>" not in expr:
        return False
    expr = expr.lstrip()
    if not expr:
        return False

    if expr.startswith("async"):
        async_remainder = expr[len("async") :]
        stripped = async_remainder.lstrip()
        # `async (x) => ...`, `async(x) => ...` and `async x => ...` are async
        # arrows: drop the keyword so the checks below see the parameters.
        # Otherwise (`async => ...`, `asyncFoo => ...`) "async" is itself part
        # of an identifier and must stay in place.
        if stripped.startswith("(") or (
            async_remainder[:1].isspace()
            and stripped
            and _is_js_identifier_start(stripped[0])
        ):
            expr = stripped

    # Single bare parameter: `name => ...`.
    if _is_js_identifier_start(expr[0]):
        end_index = 1
        while end_index < len(expr) and _is_js_identifier_char(expr[end_index]):
            end_index += 1
        return expr[end_index:].lstrip().startswith("=>")

    if not expr.startswith("("):
        return False

    # Parenthesized parameter list: find the matching close paren, then check
    # whether `=>` follows it.
    depth = 0
    string_delimiter: str | None = None
    escaped = False
    for index, char in enumerate(expr):
        if string_delimiter is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == string_delimiter:
                string_delimiter = None
            continue
        if char in {"'", '"', "`"}:
            string_delimiter = char
            continue
        if char == "(":
            depth += 1
            continue
        if char == ")":
            depth -= 1
            if depth == 0:
                return expr[index + 1 :].lstrip().startswith("=>")
    return False


class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
    """Base class for immutable function vars."""

    @overload
    def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
        arg1: V1 | Var[V1],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
        arg6: V6 | Var[V6],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(self, *args: Var | Any) -> FunctionVar: ...

    def partial(self, *args: Var | Any) -> FunctionVar:  # pyright: ignore [reportInconsistentOverload]
        """Partially apply the function with the given arguments.

        With no arguments the function var itself is returned. Otherwise a new
        arrow function ``(...args) => self(<args>, ...args)`` is built that
        forwards any further call arguments after the bound ones.

        Args:
            *args: The arguments to partially apply the function with.

        Returns:
            The partially applied function.
        """
        if not args:
            return self
        return ArgsFunctionOperation.create(
            ("...args",),
            VarOperationCall.create(self, *args, Var(_js_expr="...args")),
        )

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1], R]], arg1: V1 | Var[V1]
    ) -> VarOperationCall[[V1], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
    ) -> VarOperationCall[[V1, V2], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
    ) -> VarOperationCall[[V1, V2, V3], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
    ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
        arg6: V6 | Var[V6],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> VarOperationCall[P, R]: ...

    @overload
    def call(self, *args: Var | Any) -> Var: ...

    def call(self, *args: Var | Any) -> Var:  # pyright: ignore [reportInconsistentOverload]
        """Call the function with the given arguments.

        Args:
            *args: The arguments to call the function with.

        Returns:
            The function call operation, passed through ``guess_type()``.
        """
        return VarOperationCall.create(self, *args).guess_type()

    __call__ = call


class BuilderFunctionVar(
    FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
):
    """Base class for immutable function vars with the builder pattern."""

    # Calling a builder var partially applies it instead of invoking it.
    __call__ = FunctionVar.partial


class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
    """Base class for immutable function vars from a string."""

    @classmethod
    def create(
        cls,
        func: str,
        _var_type: type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
        _var_data: VarData | None = None,
    ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
        """Create a new function var from a string.

        Args:
            func: The function to call.
            _var_type: The type of the Var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var.
        """
        return FunctionStringVar(
            _js_expr=func,
            _var_type=_var_type,
            _var_data=_var_data,
        )


@dataclasses.dataclass(
    eq=False,
    frozen=True,
    slots=True,
)
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
    """Base class for immutable vars that are the result of a function call."""

    _func: FunctionVar[ReflexCallable[P, R]] | None = dataclasses.field(default=None)
    _args: tuple[Var | Any, ...] = dataclasses.field(default_factory=tuple)

    @cached_property_no_lock
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The JS call expression ``(func(arg1, arg2, ...))``.
        """
        func_expr = str(self._func)
        # An inline arrow function must be parenthesized before it is called;
        # otherwise the argument list would attach to the arrow's body.
        if _starts_with_arrow_function(func_expr) and not format.is_wrapped(
            func_expr, "("
        ):
            func_expr = format.wrap(func_expr, "(")
        args_expr = ", ".join([str(LiteralVar.create(arg)) for arg in self._args])
        return f"({func_expr}({args_expr}))"

    @cached_property_no_lock
    def _cached_get_all_var_data(self) -> VarData | None:
        """Get all the var data associated with the var.

        Returns:
            The merged var data of the function, every argument, and this var.
        """
        return VarData.merge(
            self._func._get_all_var_data() if self._func is not None else None,
            *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
            self._var_data,
        )

    @classmethod
    def create(
        cls,
        func: FunctionVar[ReflexCallable[P, R]],
        *args: Var | Any,
        _var_type: GenericType = Any,
        _var_data: VarData | None = None,
    ) -> VarOperationCall:
        """Create a new function call var.

        Args:
            func: The function to call.
            *args: The arguments to call the function with.
            _var_type: The type of the Var. When left as ``Any``, the return
                type is taken from the function's type arguments, if any.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function call var.
        """
        type_args = getattr(func._var_type, "__args__", None)
        # The return type is the LAST type argument. For `ReflexCallable[P, R]`
        # the args are `(params, R)`, but `collections.abc.Callable[[A, B], R]`
        # flattens them to `(A, B, R)`, so a fixed index of 1 would pick a
        # parameter type (or raise on single-argument generics).
        function_return_type = type_args[-1] if type_args else Any
        var_type = _var_type if _var_type is not Any else function_return_type

        return cls(
            _js_expr="",
            _var_type=var_type,
            _var_data=_var_data,
            _func=func,
            _args=args,
        )


@dataclasses.dataclass(frozen=True)
class DestructuredArg:
    """Class for destructured arguments."""

    # Names pulled out of the destructured object.
    fields: tuple[str, ...] = ()
    # Optional name collecting the remaining properties (`...rest`).
    rest: str | None = None

    def to_javascript(self) -> str:
        """Convert the destructured argument to JavaScript.

        Returns:
            The destructured argument in JavaScript, e.g. ``{a, b, ...rest}``.
        """
        inner = ", ".join(self.fields)
        if self.rest:
            inner = f"{inner}, ...{self.rest}" if inner else f"...{self.rest}"
        return format.wrap(inner, "{", "}")


@dataclasses.dataclass(
    frozen=True,
)
class FunctionArgs:
    """Class for function arguments."""

    # Positional parameters: plain names or destructuring patterns.
    args: tuple[str | DestructuredArg, ...] = ()
    # Optional name of the trailing rest parameter (`...rest`).
    rest: str | None = None


def format_args_function_operation(
    args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
) -> str:
    """Format an args function operation.

    Args:
        args: The function arguments.
        return_expr: The return expression.
        explicit_return: Whether to use explicit return syntax.

    Returns:
        The formatted args function operation, e.g. ``((a, ...rest) => expr)``.
    """
    params = [
        arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args
    ]
    # Append the rest parameter as its own list element so that a rest-only
    # signature renders as `(...rest)` rather than the invalid `(, ...rest)`.
    if args.rest:
        params.append(f"...{args.rest}")
    arg_names_str = ", ".join(params)

    return_expr_str = str(LiteralVar.create(return_expr))

    # Wrap return expression in curly braces if explicit return syntax is used.
    return_expr_str_wrapped = (
        format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
    )

    return f"(({arg_names_str}) => {return_expr_str_wrapped})"


@dataclasses.dataclass(
    eq=False,
    frozen=True,
    slots=True,
)
class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
    """Base class for immutable function defined via arguments and return expression."""

    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
    _return_expr: Var | Any = dataclasses.field(default=None)
    _explicit_return: bool = dataclasses.field(default=False)

    @cached_property_no_lock
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The rendered arrow function.
        """
        return format_args_function_operation(
            self._args, self._return_expr, self._explicit_return
        )

    @classmethod
    def create(
        cls,
        args_names: Sequence[str | DestructuredArg],
        return_expr: Var | Any,
        rest: str | None = None,
        explicit_return: bool = False,
        _var_type: GenericType = Callable,
        _var_data: VarData | None = None,
    ):
        """Create a new function var.

        Args:
            args_names: The names of the arguments.
            return_expr: The return expression of the function.
            rest: The name of the rest argument.
            explicit_return: Whether to use explicit return syntax.
            _var_type: The type of the Var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var.
        """
        return_expr = Var.create(return_expr)
        return cls(
            _js_expr="",
            _var_type=_var_type,
            _var_data=_var_data,
            _args=FunctionArgs(args=tuple(args_names), rest=rest),
            _return_expr=return_expr,
            _explicit_return=explicit_return,
        )


@dataclasses.dataclass(
    eq=False,
    frozen=True,
    slots=True,
)
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
    """Base class for immutable function defined via arguments and return expression with the builder pattern."""

    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
    _return_expr: Var | Any = dataclasses.field(default=None)
    _explicit_return: bool = dataclasses.field(default=False)

    @cached_property_no_lock
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The rendered arrow function.
        """
        return format_args_function_operation(
            self._args, self._return_expr, self._explicit_return
        )

    @classmethod
    def create(
        cls,
        args_names: Sequence[str | DestructuredArg],
        return_expr: Var | Any,
        rest: str | None = None,
        explicit_return: bool = False,
        _var_type: GenericType = Callable,
        _var_data: VarData | None = None,
    ):
        """Create a new function var.

        Args:
            args_names: The names of the arguments.
            return_expr: The return expression of the function.
            rest: The name of the rest argument.
            explicit_return: Whether to use explicit return syntax.
            _var_type: The type of the Var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var.
        """
        return_expr = Var.create(return_expr)
        return cls(
            _js_expr="",
            _var_type=_var_type,
            _var_data=_var_data,
            _args=FunctionArgs(args=tuple(args_names), rest=rest),
            _return_expr=return_expr,
            _explicit_return=explicit_return,
        )


JSON_STRINGIFY = FunctionStringVar.create(
    "JSON.stringify", _var_type=ReflexCallable[[Any], str]
)
ARRAY_ISARRAY = FunctionStringVar.create(
    "Array.isArray", _var_type=ReflexCallable[[Any], bool]
)
PROTOTYPE_TO_STRING = FunctionStringVar.create(
    "((__to_string) => __to_string.toString())",
    _var_type=ReflexCallable[[Any], str],
)