990 lines
36 KiB
Python
990 lines
36 KiB
Python
"""reflex.testing - tools for testing reflex apps."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import contextvars
|
|
import dataclasses
|
|
import functools
|
|
import inspect
|
|
import os
|
|
import platform
|
|
import re
|
|
import signal
|
|
import socket
|
|
import socketserver
|
|
import subprocess
|
|
import sys
|
|
import textwrap
|
|
import threading
|
|
import time
|
|
import types
|
|
from collections.abc import Callable, Coroutine, Sequence
|
|
from copy import deepcopy
|
|
from http.server import SimpleHTTPRequestHandler
|
|
from importlib.util import find_spec
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar
|
|
|
|
import uvicorn
|
|
from reflex_base.components.component import CUSTOM_COMPONENTS, CustomComponent
|
|
from reflex_base.config import get_config
|
|
from reflex_base.environment import environment
|
|
from reflex_base.registry import RegistrationContext
|
|
from reflex_base.utils.types import ASGIApp
|
|
from typing_extensions import Self
|
|
|
|
import reflex
|
|
import reflex.reflex
|
|
import reflex.utils.build
|
|
import reflex.utils.format
|
|
import reflex.utils.prerequisites
|
|
import reflex.utils.processes
|
|
from reflex.experimental.memo import EXPERIMENTAL_MEMOS
|
|
from reflex.istate.shared import SharedState as SharedState # To register it.
|
|
from reflex.state import reload_state_module
|
|
from reflex.utils import console, js_runtimes
|
|
from reflex.utils.export import export
|
|
from reflex.utils.token_manager import TokenManager
|
|
|
|
try:
|
|
from selenium import webdriver
|
|
from selenium.webdriver.remote.webdriver import WebDriver
|
|
|
|
if TYPE_CHECKING:
|
|
from selenium.webdriver.common.options import ArgOptions
|
|
from selenium.webdriver.remote.webelement import WebElement
|
|
|
|
has_selenium = True
|
|
except ImportError:
|
|
has_selenium = False
|
|
|
|
# The timeout (minutes) to check for the port.
|
|
DEFAULT_TIMEOUT = 15
|
|
POLL_INTERVAL = 0.25
|
|
FRONTEND_POPEN_ARGS = {}
|
|
T = TypeVar("T")
|
|
TimeoutType = int | float | None
|
|
if platform.system() == "Windows":
|
|
FRONTEND_POPEN_ARGS["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP # pyright: ignore [reportAttributeAccessIssue]
|
|
FRONTEND_POPEN_ARGS["shell"] = True
|
|
else:
|
|
FRONTEND_POPEN_ARGS["start_new_session"] = True
|
|
|
|
|
|
# borrowed from py3.11
|
|
class chdir(contextlib.AbstractContextManager):  # noqa: N801
    """Temporarily switch the process working directory (not thread-safe).

    Backport of ``contextlib.chdir`` from Python 3.11. Each ``__enter__``
    pushes the current directory onto a stack, and the matching ``__exit__``
    pops it and changes back, so one instance may be entered re-entrantly.
    """

    def __init__(self, path: str | Path):
        """Remember the directory to switch into.

        Args:
            path: the path to change to
        """
        self.path = path
        self._old_cwd = []

    def __enter__(self):
        """Record the current directory, then move into ``self.path``."""
        previous_dir = Path.cwd()
        self._old_cwd.append(previous_dir)
        os.chdir(self.path)

    def __exit__(self, *excinfo):
        """Return to the directory recorded by the matching ``__enter__``.

        Args:
            excinfo: sys.exc_info captured in the context block
        """
        previous_dir = self._old_cwd.pop()
        os.chdir(previous_dir)
|
|
|
|
|
|
@dataclasses.dataclass
class AppHarness:
    """AppHarness executes a reflex app in-process for testing.

    The backend runs via uvicorn in a background thread of the current
    process, and the dev frontend runs as a separate subprocess. Use it as a
    context manager (``with AppHarness.create(...) as harness:``) or call
    :meth:`start` / :meth:`stop` explicitly.
    """

    # Name of the app; also used as the package and module name on disk.
    app_name: str
    # Function/partial whose body becomes the app module, a module object,
    # raw source text, or None to use an app that already exists in app_path.
    app_source: (
        Callable[[], None] | types.ModuleType | str | functools.partial[Any] | None
    )
    # Directory that contains the app under test.
    app_path: Path
    # Path of the main app module (<app_path>/<app_name>/<app_name>.py).
    app_module_path: Path
    # Populated by _initialize_app().
    app_module: types.ModuleType | None = None
    app_instance: reflex.App | None = None
    app_asgi: ASGIApp | None = None
    # Populated by _start_frontend() / _wait_frontend().
    frontend_process: subprocess.Popen | None = None
    frontend_url: str | None = None
    frontend_output_thread: threading.Thread | None = None
    # Populated by _start_backend().
    backend_thread: threading.Thread | None = None
    backend: uvicorn.Server | None = None
    # Webdrivers opened via frontend(); they are quit in stop().
    _frontends: list[WebDriver] = dataclasses.field(default_factory=list)
    # Token from installing this harness's RegistrationContext; reset in stop().
    _registry_token: contextvars.Token[RegistrationContext] | None = None
    # RegistrationContext captured before the first harness ran. Every app gets
    # a deep copy so registrations do not leak between apps.
    _base_registration_context: ClassVar[RegistrationContext | None] = None

    @classmethod
    def create(
        cls,
        root: Path,
        app_source: (
            Callable[[], None] | types.ModuleType | str | functools.partial[Any] | None
        ) = None,
        app_name: str | None = None,
    ) -> Self:
        """Create an AppHarness instance at root.

        Args:
            root: the directory that will contain the app under test.
            app_source: if specified, the source code from this function or module is used
                as the main module for the app. It may also be the raw source code text, as a str.
                If unspecified, then root must already contain a working reflex app and will be used directly.
            app_name: provide the name of the app, otherwise will be derived from app_source or root.

        Returns:
            AppHarness instance

        Raises:
            ValueError: when app_source is a string and app_name is not provided.
        """
        if app_name is None:
            if app_source is None:
                app_name = root.name
            elif isinstance(app_source, functools.partial):
                # Build the name from the partial's keyword values so each
                # parametrization of one app function gets its own directory.
                keywords = app_source.keywords
                slug_suffix = "_".join([str(v) for v in keywords.values()])
                func_name = app_source.func.__name__
                app_name = f"{func_name}_{slug_suffix}"
                app_name = re.sub(r"[^a-zA-Z0-9_]", "_", app_name)
            elif isinstance(app_source, str):
                msg = "app_name must be provided when app_source is a string."
                raise ValueError(msg)
            else:
                app_name = app_source.__name__

        # Normalize: lowercase and collapse runs of underscores.
        app_name = app_name.lower()
        while "__" in app_name:
            app_name = app_name.replace("__", "_")
        return cls(
            app_name=app_name,
            app_source=app_source,
            app_path=root,
            app_module_path=root / app_name / f"{app_name}.py",
        )

    def get_state_name(self, state_cls_name: str) -> str:
        """Get the state name for the given state class name.

        Args:
            state_cls_name: The state class name

        Returns:
            The state name
        """
        return reflex.utils.format.to_snake_case(
            f"{self.app_name}___{self.app_name}___" + state_cls_name
        )

    def get_full_state_name(self, path: list[str]) -> str:
        """Get the full state name for the given state class name.

        Args:
            path: A list of state class names

        Returns:
            The full state name
        """
        # NOTE: using State.get_name() somehow causes trouble here, so the
        # root state's name is spelled out literally.
        path = ["reflex___state____state"] + [self.get_state_name(p) for p in path]
        return ".".join(path)

    def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
        """Get the globals from a function or module object.

        Parameter defaults become globals. For a functools.partial, the bound
        keyword arguments override those defaults.

        Args:
            func: function or module object

        Returns:
            dict of globals
        """
        overrides = {}
        glbs = {}
        if not callable(func):
            return glbs
        if isinstance(func, functools.partial):
            overrides = func.keywords
            func = func.func
        for param in inspect.signature(func).parameters.values():
            if param.default is not inspect.Parameter.empty:
                glbs[param.name] = param.default
        glbs.update(overrides)
        return glbs

    def _get_source_from_app_source(self, app_source: Any) -> str:
        """Get the source from app_source.

        Strings are returned unchanged. Otherwise the object's source is
        fetched, a leading ``def ...:`` header (if any) is removed, and the
        remaining body is dedented so it can serve as module source.

        Args:
            app_source: function or module or str

        Returns:
            source code
        """
        if isinstance(app_source, str):
            return app_source
        source = inspect.getsource(app_source)
        source = re.sub(
            r"^\s*def\s+\w+\s*\(.*?\)(\s+->\s+\w+)?:", "", source, flags=re.DOTALL
        )
        return textwrap.dedent(source)

    def _initialize_app(self):
        """Prepare the app on disk and load it.

        Clears the global memo registries. When ``app_source`` is given, it
        initializes a project via ``reflex.reflex._init`` and writes the
        generated module. Otherwise it only initializes frontend dependencies.
        Finally it installs a fresh RegistrationContext, then imports and
        validates the app to populate ``app_instance``, ``app_module`` and
        ``app_asgi``.
        """
        # disable telemetry reporting for tests
        os.environ["REFLEX_TELEMETRY_ENABLED"] = "false"
        # Reset global memo registries so previous AppHarness apps do not
        # leak compiled component definitions into the next test app.
        CUSTOM_COMPONENTS.clear()
        EXPERIMENTAL_MEMOS.clear()
        CustomComponent.create().get_component.cache_clear()
        self.app_path.mkdir(parents=True, exist_ok=True)
        if self.app_source is not None:
            app_globals = self._get_globals_from_signature(self.app_source)
            if isinstance(self.app_source, functools.partial):
                self.app_source = self.app_source.func
            # get the source from a function or module object
            source_code = "\n".join([
                "\n".join([
                    self.get_app_global_source(k, v) for k, v in app_globals.items()
                ]),
                self._get_source_from_app_source(self.app_source),
            ])
            get_config().loglevel = reflex.constants.LogLevel.INFO
            with chdir(self.app_path):
                reflex.reflex._init(
                    name=self.app_name,
                    template=reflex.constants.Templates.DEFAULT,
                )
                self.app_module_path.write_text(source_code)
        else:
            # Just initialize the web folder.
            with chdir(self.app_path):
                reflex.utils.prerequisites.initialize_frontend_dependencies()
        with chdir(self.app_path):
            # Use a new registration context for a new app.
            if AppHarness._base_registration_context is None:
                # Save the initial RegistrationContext for the app if we haven't already
                AppHarness._base_registration_context = (
                    RegistrationContext.ensure_context()
                )
            new_registration_context = deepcopy(AppHarness._base_registration_context)
            self._registry_token = RegistrationContext.set(new_registration_context)
            # ensure config and app are reloaded when testing different app
            config = get_config(reload=True)
            # Ensure the AppHarness test does not skip State assignment due to running via pytest
            os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
            os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
            # Ensure we compile generated apps, and reload pre-existing app modules
            # that were already imported so they can re-register memo definitions.
            should_reload_app = (
                self.app_source is not None or config.module in sys.modules
            )
            self.app_instance, self.app_module = (
                reflex.utils.prerequisites.get_and_validate_app(
                    reload=should_reload_app
                )
            )
            self.app_asgi = self.app_instance()

    def _reload_state_module(self):
        """Reload the rx.State module to avoid conflict when reloading."""
        reload_state_module(module=f"{self.app_name}.{self.app_name}")

    def _get_backend_shutdown_handler(self):
        """Build a replacement for the uvicorn server's ``shutdown`` coroutine.

        The returned coroutine function closes the app's state manager, shuts
        down the socket.io server, and disposes the async database engine (when
        sqlmodel is installed). It then awaits the server's original shutdown.

        Returns:
            The async shutdown function.

        Raises:
            RuntimeError: when the backend has not been created yet.
        """
        if self.backend is None:
            msg = "Backend was not initialized."
            raise RuntimeError(msg)

        original_shutdown = self.backend.shutdown

        async def _shutdown(*args, **kwargs) -> None:
            # ensure redis is closed before event loop
            if (
                self.app_instance is not None
                and self.app_instance._state_manager is not None
            ):
                with contextlib.suppress(ValueError):
                    await self.app_instance._state_manager.close()

            # socketio shutdown handler
            if self.app_instance is not None and self.app_instance.sio is not None:
                with contextlib.suppress(TypeError):
                    await self.app_instance.sio.shutdown()

            # sqlalchemy async engine shutdown handler
            if find_spec("sqlmodel"):
                try:
                    async_engine = reflex.model.get_async_engine(None)
                except ValueError:
                    pass
                else:
                    await async_engine.dispose()

            await original_shutdown(*args, **kwargs)

        return _shutdown

    def _start_backend(self, port: int = 0):
        """Create the uvicorn server and run it in a background thread.

        Args:
            port: port to bind on 127.0.0.1 (0 picks a free port).

        Raises:
            RuntimeError: when the app has not been initialized.
        """
        if self.app_asgi is None:
            msg = "App was not initialized."
            raise RuntimeError(msg)
        self.backend = uvicorn.Server(
            uvicorn.Config(
                app=self.app_asgi,
                host="127.0.0.1",
                port=port,
            )
        )
        self.backend.shutdown = self._get_backend_shutdown_handler()

        def _run_backend(context: contextvars.Context) -> None:
            # Run inside a copy of the caller's context so context variables
            # (e.g. the RegistrationContext set above) are visible to the server.
            if self.backend is not None:
                context.run(self.backend.run)

        with chdir(self.app_path):
            print(  # noqa: T201
                "Creating backend in a new thread..."
            )  # for pytest diagnosis
            self.backend_thread = threading.Thread(
                target=_run_backend, args=(contextvars.copy_context(),)
            )
            self.backend_thread.start()
            print("Backend started.")  # for pytest diagnosis #noqa: T201

    def _start_frontend(self):
        """Point the config at the live backend, build and launch the dev frontend."""
        # Set up the frontend.
        with chdir(self.app_path):
            config = get_config()
            print("Polling for servers...")  # for pytest diagnosis #noqa: T201
            config.api_url = "http://{}:{}".format(
                *self._poll_for_servers(timeout=30).getsockname(),
            )
            print("Building frontend...")  # for pytest diagnosis #noqa: T201
            reflex.utils.build.setup_frontend(self.app_path)

        print("Frontend starting...")  # for pytest diagnosis #noqa: T201

        # Start the frontend.
        self.frontend_process = reflex.utils.processes.new_process(
            [
                *js_runtimes.get_js_package_executor(raise_on_none=True)[0],
                "run",
                "dev",
            ],
            cwd=self.app_path / reflex.utils.prerequisites.get_web_dir(),
            env={"PORT": "0", "NO_COLOR": "1"},
            **FRONTEND_POPEN_ARGS,
        )

    def _wait_frontend(self):
        """Wait for the dev server to report its URL, then drain its output.

        Reads the frontend's stdout until a line matches
        ``FRONTEND_LISTENING_REGEX``, then stores the URL in ``frontend_url``
        and ``config.deploy_url``. After that, a background thread keeps
        consuming stdout (presumably so the pipe does not fill up and stall the
        subprocess).

        Raises:
            RuntimeError: when the process has no stdout or never reports a URL.
        """
        if self.frontend_process is None or self.frontend_process.stdout is None:
            msg = "Frontend process has no stdout."
            raise RuntimeError(msg)
        while self.frontend_url is None:
            line = self.frontend_process.stdout.readline()
            if not line:
                break
            print(line)  # for pytest diagnosis #noqa: T201
            m = re.search(reflex.constants.ReactRouter.FRONTEND_LISTENING_REGEX, line)
            if m is not None:
                self.frontend_url = m.group(1)
                config = get_config()
                config.deploy_url = self.frontend_url
                break
        if self.frontend_url is None:
            msg = "Frontend did not start"
            raise RuntimeError(msg)

        def consume_frontend_output():
            while True:
                try:
                    line = (
                        self.frontend_process.stdout.readline()  # pyright: ignore [reportOptionalMemberAccess]
                    )
                # catch I/O operation on closed file.
                except ValueError as e:
                    console.error(str(e))
                    break
                if not line:
                    break

        self.frontend_output_thread = threading.Thread(target=consume_frontend_output)
        self.frontend_output_thread.start()

    def start(self) -> Self:
        """Start the backend in a new thread and dev frontend as a separate process.

        Returns:
            self
        """
        self._initialize_app()
        self._start_backend()
        self._start_frontend()
        self._wait_frontend()
        return self

    @staticmethod
    def get_app_global_source(key: str, value: Any):
        """Get the source code of a global object.

        If value is a function or class we render the actual
        source of value otherwise we assign value to key.

        Args:
            key: variable name to assign value to.
            value: value of the global variable.

        Returns:
            The rendered app global code.
        """
        if not isinstance(value, type) and not inspect.isfunction(value):
            return f"{key} = {value!r}"
        return inspect.getsource(value)

    def __enter__(self) -> Self:
        """Contextmanager protocol for `start()`.

        Returns:
            Instance of AppHarness after calling start()
        """
        return self.start()

    def stop(self) -> None:
        """Stop the frontend and backend servers.

        Quits any webdrivers opened by :meth:`frontend`, restores the previous
        RegistrationContext, signals the backend to exit, terminates the
        frontend process group, and joins the helper threads. Repeated calls
        do not quit drivers or reset the registry token a second time.
        """
        import psutil

        # Quit browsers first to avoid any lingering events being sent during shutdown.
        for driver in self._frontends:
            driver.quit()
        self._frontends.clear()

        self._reload_state_module()
        if self._registry_token is not None:
            RegistrationContext.reset(self._registry_token)
            # A contextvars.Token can only be used once; forget it so that a
            # second stop() (e.g. explicit stop() followed by __exit__) is safe.
            self._registry_token = None

        if self.backend is not None:
            self.backend.should_exit = True
        if self.frontend_process is not None:
            # https://stackoverflow.com/a/70565806
            frontend_children = psutil.Process(self.frontend_process.pid).children(
                recursive=True,
            )
            if sys.platform == "win32":
                self.frontend_process.terminate()
            else:
                pgrp = os.getpgid(self.frontend_process.pid)
                os.killpg(pgrp, signal.SIGTERM)
            # kill any remaining child processes
            for child in frontend_children:
                # It's okay if the process is already gone.
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.terminate()
            _, still_alive = psutil.wait_procs(frontend_children, timeout=3)
            for child in still_alive:
                # It's okay if the process is already gone.
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.kill()
            # wait for main process to exit
            self.frontend_process.communicate()
        if self.backend_thread is not None:
            self.backend_thread.join()
        if self.frontend_output_thread is not None:
            self.frontend_output_thread.join()

    def __exit__(self, *excinfo) -> None:
        """Contextmanager protocol for `stop()`.

        Args:
            excinfo: sys.exc_info captured in the context block
        """
        self.stop()

    @staticmethod
    def _poll_for(
        target: Callable[[], T],
        timeout: TimeoutType = None,
        step: TimeoutType = None,
    ) -> T | Literal[False]:
        """Generic polling logic.

        Exceptions raised by ``target`` are suppressed and treated as a falsy
        result, so polling simply continues.

        Args:
            target: callable that returns truthy if polling condition is met.
            timeout: max polling time in seconds (defaults to DEFAULT_TIMEOUT).
            step: interval between checking target() (defaults to POLL_INTERVAL).

        Returns:
            return value of target() if truthy within timeout
            False if timeout elapses
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT
        if step is None:
            step = POLL_INTERVAL
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # Reset each attempt: if target() raises, we must not read an
            # unbound (first attempt) or stale (later attempt) result.
            success = None
            with contextlib.suppress(Exception):
                success = target()
            if success:
                return success
            time.sleep(step)
        return False

    @staticmethod
    async def _poll_for_async(
        target: Callable[[], Coroutine[None, None, T]],
        timeout: TimeoutType = None,
        step: TimeoutType = None,
    ) -> T | bool:
        """Generic polling logic for async functions.

        Unlike :meth:`_poll_for`, exceptions raised by ``target`` propagate.

        Args:
            target: callable that returns truthy if polling condition is met.
            timeout: max polling time in seconds (defaults to DEFAULT_TIMEOUT).
            step: interval between checking target() (defaults to POLL_INTERVAL).

        Returns:
            return value of target() if truthy within timeout
            False if timeout elapses
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT
        if step is None:
            step = POLL_INTERVAL
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            success = await target()
            if success:
                return success
            await asyncio.sleep(step)
        return False

    def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
        """Poll backend server for listening sockets.

        Args:
            timeout: how long to wait for listening socket.

        Returns:
            first active listening socket on the backend

        Raises:
            RuntimeError: when the backend hasn't started running
            TimeoutError: when server or sockets are not ready
        """
        if self.backend is None:
            msg = "Backend is not running."
            raise RuntimeError(msg)
        backend = self.backend
        # check for servers to be initialized
        if not self._poll_for(
            target=lambda: getattr(backend, "servers", False),
            timeout=timeout,
        ):
            msg = "Backend servers are not initialized."
            raise TimeoutError(msg)
        # check for sockets to be listening
        if not self._poll_for(
            target=lambda: getattr(backend.servers[0], "sockets", False),
            timeout=timeout,
        ):
            msg = "Backend is not listening."
            raise TimeoutError(msg)
        return backend.servers[0].sockets[0]

    def frontend(
        self,
        driver_clz: type[WebDriver] | None = None,
        driver_kwargs: dict[str, Any] | None = None,
        driver_options: ArgOptions | None = None,
        driver_option_args: list[str] | None = None,
        driver_option_capabilities: dict[str, Any] | None = None,
    ) -> WebDriver:
        """Get a selenium webdriver instance pointed at the app.

        Args:
            driver_clz: webdriver.Chrome (default), webdriver.Firefox, webdriver.Safari,
                webdriver.Edge, etc
            driver_kwargs: additional keyword arguments to pass to the webdriver constructor
            driver_options: selenium ArgOptions instance to pass to the webdriver constructor
            driver_option_args: additional arguments for the webdriver options
            driver_option_capabilities: additional capabilities for the webdriver options

        Returns:
            Instance of the given webdriver navigated to the frontend url of the app.

        Raises:
            RuntimeError: when selenium is not importable or frontend is not running
        """
        if not has_selenium:
            msg = (
                "Frontend functionality requires `selenium` to be installed, "
                "and it could not be imported."
            )
            raise RuntimeError(msg)
        if self.frontend_url is None:
            msg = "Frontend is not running."
            raise RuntimeError(msg)
        want_headless = bool(environment.APP_HARNESS_HEADLESS.get())
        if driver_clz is None:
            requested_driver = environment.APP_HARNESS_DRIVER.get()
            driver_clz = getattr(webdriver, requested_driver)  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                driver_options = getattr(webdriver, f"{requested_driver}Options")()  # pyright: ignore [reportPossiblyUnboundVariable]
        if driver_clz is webdriver.Chrome:  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                from selenium.webdriver.chrome.options import Options

                driver_options = Options()  # pyright: ignore [reportPossiblyUnboundVariable]
            driver_options.add_argument("--class=AppHarness")
            if want_headless:
                driver_options.add_argument("--headless=new")
        elif driver_clz is webdriver.Firefox:  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                from selenium.webdriver.firefox.options import Options

                driver_options = Options()  # pyright: ignore [reportPossiblyUnboundVariable]
            if want_headless:
                driver_options.add_argument("-headless")
        elif driver_clz is webdriver.Edge:  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                from selenium.webdriver.edge.options import Options

                driver_options = Options()  # pyright: ignore [reportPossiblyUnboundVariable]
            if want_headless:
                driver_options.add_argument("headless")
        if driver_options is None:
            msg = f"Could not determine options for {driver_clz}"
            raise RuntimeError(msg)
        if args := environment.APP_HARNESS_DRIVER_ARGS.get():
            for arg in args.split(","):
                driver_options.add_argument(arg)
        if driver_option_args is not None:
            for arg in driver_option_args:
                driver_options.add_argument(arg)
        if driver_option_capabilities is not None:
            for key, value in driver_option_capabilities.items():
                driver_options.set_capability(key, value)
        if driver_kwargs is None:
            driver_kwargs = {}
        driver = driver_clz(options=driver_options, **driver_kwargs)  # pyright: ignore [reportOptionalCall, reportArgumentType]
        driver.get(self.frontend_url)
        self._frontends.append(driver)
        return driver

    def token_manager(self) -> TokenManager:
        """Get the token manager for the app instance.

        Returns:
            The current token_manager attached to the app's EventNamespace.
        """
        assert self.app_instance is not None
        app_event_namespace = self.app_instance.event_namespace
        assert app_event_namespace is not None
        app_token_manager = app_event_namespace._token_manager
        assert app_token_manager is not None
        return app_token_manager

    def poll_for_content(
        self,
        element: WebElement,
        timeout: TimeoutType = None,
        exp_not_equal: str = "",
    ) -> str:
        """Poll element.text for change.

        Args:
            element: selenium webdriver element to check
            timeout: how long to poll element.text
            exp_not_equal: exit the polling loop when the element text does not match

        Returns:
            The element text when the polling loop exited

        Raises:
            TimeoutError: when the timeout expires before text changes
        """
        if not self._poll_for(
            target=lambda: element.text != exp_not_equal,
            timeout=timeout,
        ):
            msg = f"{element} content remains {exp_not_equal!r} while polling."
            raise TimeoutError(msg)
        return element.text

    def poll_for_value(
        self,
        element: WebElement,
        timeout: TimeoutType = None,
        exp_not_equal: str | Sequence[str] = "",
    ) -> str | None:
        """Poll element.get_attribute("value") for change.

        Args:
            element: selenium webdriver element to check
            timeout: how long to poll element value attribute
            exp_not_equal: exit the polling loop when the value does not match

        Returns:
            The element value when the polling loop exited

        Raises:
            TimeoutError: when the timeout expires before value changes
        """
        exp_not_equal = (
            (exp_not_equal,) if isinstance(exp_not_equal, str) else exp_not_equal
        )
        if not self._poll_for(
            target=lambda: element.get_attribute("value") not in exp_not_equal,
            timeout=timeout,
        ):
            msg = f"{element} content remains {exp_not_equal!r} while polling."
            raise TimeoutError(msg)
        return element.get_attribute("value")

    @staticmethod
    def poll_for_or_raise_timeout(
        target: Callable[[], T],
        timeout: TimeoutType = None,
        step: TimeoutType = None,
    ) -> T:
        """Poll target callable for a truthy return value.

        Like `_poll_for`, but raises a `TimeoutError` if the target does not
        return a truthy value within the timeout.

        Args:
            target: callable that returns truthy if polling condition is met.
            timeout: max polling time
            step: interval between checking target()

        Returns:
            return value of target() if truthy within timeout

        Raises:
            TimeoutError: when target does not return a truthy value within timeout
        """
        result = AppHarness._poll_for(
            target=target,
            timeout=timeout,
            step=step,
        )
        if result is False:
            msg = "Target did not return a truthy value while polling."
            raise TimeoutError(msg)
        return result

    @staticmethod
    def expect(
        target: Callable[[], T],
        timeout: TimeoutType = None,
        step: TimeoutType = None,
    ):
        """Expect a target callable to return a truthy value within the timeout.

        Args:
            target: callable that returns truthy if polling condition is met.
            timeout: max polling time
            step: interval between checking target()
        """
        AppHarness.poll_for_or_raise_timeout(
            target=target,
            timeout=timeout,
            step=step,
        )
|
|
|
|
|
|
class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
|
|
"""SimpleHTTPRequestHandler with custom error page handling."""
|
|
|
|
def __init__(self, *args, error_page_map: dict[int, Path], **kwargs):
|
|
"""Initialize the handler.
|
|
|
|
Args:
|
|
error_page_map: map of error code to error page path
|
|
*args: passed through to superclass
|
|
**kwargs: passed through to superclass
|
|
"""
|
|
self.error_page_map = error_page_map
|
|
super().__init__(*args, **kwargs)
|
|
|
|
def send_error(
|
|
self, code: int, message: str | None = None, explain: str | None = None
|
|
) -> None:
|
|
"""Send the error page for the given error code.
|
|
|
|
If the code matches a custom error page, then message and explain are
|
|
ignored.
|
|
|
|
Args:
|
|
code: the error code
|
|
message: the error message
|
|
explain: the error explanation
|
|
"""
|
|
error_page = self.error_page_map.get(code)
|
|
if error_page:
|
|
self.send_response(code, message)
|
|
self.send_header("Connection", "close")
|
|
body = error_page.read_bytes()
|
|
self.send_header("Content-Type", self.error_content_type)
|
|
self.send_header("Content-Length", str(len(body)))
|
|
self.end_headers()
|
|
self.wfile.write(body)
|
|
else:
|
|
super().send_error(code, message, explain)
|
|
|
|
|
|
class Subdir404TCPServer(socketserver.TCPServer):
    """TCP server that serves files from ``root`` with custom error pages.

    Meant to be paired with SimpleHTTPRequestHandlerCustomErrors. Each handler
    it creates is given the serving directory and the error page map.
    """

    def __init__(
        self,
        *args,
        root: Path,
        error_page_map: dict[int, Path] | None,
        **kwargs,
    ):
        """Store the serving configuration, then set up the TCP server.

        Args:
            root: the root directory to serve from
            error_page_map: map of error code to error page path
            *args: passed through to superclass
            **kwargs: passed through to superclass
        """
        self.root = root
        self.error_page_map = error_page_map if error_page_map else {}
        super().__init__(*args, **kwargs)

    def finish_request(self, request: socket.socket, client_address: tuple[str, int]):
        """Handle one request with a handler bound to this server's settings.

        Args:
            request: the requesting socket
            client_address: (host, port) referring to the client's address.
        """
        handler_options = {
            "directory": str(self.root),
            "error_page_map": self.error_page_map,
        }
        self.RequestHandlerClass(  # pyright: ignore [reportCallIssue]
            request, client_address, self, **handler_options
        )
|
|
|
|
|
|
class AppHarnessProd(AppHarness):
    """AppHarnessProd executes a reflex app in-process for testing.

    In prod mode, instead of running `react-router dev` the app is exported as static
    files and served via the builtin python http.server with a custom 404 page.
    The backend's uvicorn config is created with
    ``workers=reflex.utils.processes.get_num_workers()``.
    NOTE(review): ``uvicorn.Server.run()`` may not spawn workers by itself;
    confirm whether multi-worker mode is actually in effect.
    """

    # Thread that runs _run_frontend().
    frontend_thread: threading.Thread | None = None
    # Static file server; assigned once it is bound, before serving starts.
    frontend_server: Subdir404TCPServer | None = None

    def _run_frontend(self):
        """Serve the exported static frontend until the server is shut down.

        Binds an ephemeral port, stores the resulting URL in ``frontend_url``,
        then blocks in ``serve_forever`` (this runs in ``frontend_thread``).
        """
        web_root = (
            self.app_path
            / reflex.utils.prerequisites.get_web_dir()
            / reflex.constants.Dirs.STATIC
        )
        config = get_config()
        with Subdir404TCPServer(
            ("", 0),
            SimpleHTTPRequestHandlerCustomErrors,
            root=web_root,
            error_page_map={
                404: web_root / config.prepend_frontend_path("/404.html").lstrip("/"),
            },
        ) as self.frontend_server:
            frontend_path = config.frontend_path.strip("/")
            self.frontend_url = "http://localhost:{1}".format(
                *self.frontend_server.socket.getsockname()
            ) + (f"/{frontend_path}/" if frontend_path else "/")
            self.frontend_server.serve_forever()

    def _start_frontend(self):
        """Point the config at the live backend, export the static site, and serve it."""
        # Set up the frontend.
        with chdir(self.app_path):
            config = get_config()
            print("Polling for servers...")  # for pytest diagnosis #noqa: T201
            config.api_url = "http://{}:{}".format(
                *self._poll_for_servers(timeout=30).getsockname(),
            )
            print("Building frontend...")  # for pytest diagnosis #noqa: T201

            get_config().loglevel = reflex.constants.LogLevel.INFO

            reflex.utils.prerequisites.assert_in_reflex_dir()

            if reflex.utils.prerequisites.needs_reinit():
                reflex.reflex._init(name=get_config().app_name)

            export(
                zipping=False,
                frontend=True,
                backend=False,
                loglevel=reflex.constants.LogLevel.INFO,
                env=reflex.constants.Env.PROD,
            )

        print("Frontend starting...")  # for pytest diagnosis #noqa: T201

        self.frontend_thread = threading.Thread(target=self._run_frontend)
        self.frontend_thread.start()

    def _wait_frontend(self):
        """Block until the static frontend server is bound and its URL is known.

        Raises:
            RuntimeError: when the server did not come up, or its socket is closed.
        """
        # _run_frontend assigns frontend_server before frontend_url, so wait
        # for both; otherwise frontend() could see frontend_url still unset.
        self._poll_for(
            lambda: self.frontend_server is not None and self.frontend_url is not None
        )
        if (
            self.frontend_server is None
            or self.frontend_url is None
            # socket.fileno() returns -1 (truthy) once the socket is closed.
            or self.frontend_server.socket.fileno() < 0
        ):
            msg = "Frontend did not start"
            raise RuntimeError(msg)

    def _start_backend(self, port: int = 0):
        """Create the uvicorn server with compilation skipped and run it in a thread.

        ``REFLEX_SKIP_COMPILE`` is set here and cleared again by
        :meth:`_poll_for_servers`.

        Args:
            port: port to bind on 127.0.0.1 (0 picks a free port).

        Raises:
            RuntimeError: when the app has not been initialized.
        """
        if self.app_asgi is None:
            msg = "App was not initialized."
            raise RuntimeError(msg)
        environment.REFLEX_SKIP_COMPILE.set(True)
        self.backend = uvicorn.Server(
            uvicorn.Config(
                app=self.app_asgi,
                host="127.0.0.1",
                port=port,
                workers=reflex.utils.processes.get_num_workers(),
            ),
        )
        self.backend.shutdown = self._get_backend_shutdown_handler()

        def _run_backend(context: contextvars.Context) -> None:
            # Run inside a copy of the caller's context so context variables
            # set by the caller are visible to the server.
            if self.backend is not None:
                context.run(self.backend.run)

        print(  # noqa: T201
            "Creating backend in a new thread..."
        )
        self.backend_thread = threading.Thread(
            target=_run_backend, args=(contextvars.copy_context(),)
        )
        self.backend_thread.start()
        print("Backend started.")  # for pytest diagnosis #noqa: T201

    def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
        """Poll for the backend socket, then clear ``REFLEX_SKIP_COMPILE``.

        Args:
            timeout: how long to wait for listening socket.

        Returns:
            first active listening socket on the backend
        """
        try:
            return super()._poll_for_servers(timeout)
        finally:
            environment.REFLEX_SKIP_COMPILE.set(None)

    def stop(self):
        """Stop the backend and frontend process, then the frontend python webserver."""
        super().stop()
        if self.frontend_server is not None:
            self.frontend_server.shutdown()
        if self.frontend_thread is not None:
            self.frontend_thread.join()
|