eptm_dashboard/.venv/lib/python3.12/site-packages/reflex/model.py

720 lines
24 KiB
Python

"""Database built into Reflex."""
from __future__ import annotations
import re
from collections import defaultdict
from contextlib import suppress
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, ClassVar
from reflex_base.config import get_config
from reflex_base.environment import environment
from reflex_base.utils import console
from reflex_base.utils.serializers import serializer
if TYPE_CHECKING:
# Type-checking-only imports and aliases: TYPE_CHECKING is False at runtime,
# so sqlalchemy/sqlmodel are not imported here when the module is executed.
from typing import TypeVar
import sqlalchemy
import sqlmodel
# A model class that is either a SQLModel subclass or a SQLAlchemy declarative class.
SQLModelOrSqlAlchemy = (
type[sqlmodel.SQLModel] | type[sqlalchemy.orm.DeclarativeBase]
)
# TypeVar bound to the alias above, used so ModelRegistry.register returns its argument's exact type.
SQLModelOrSqlAlchemyT = TypeVar("SQLModelOrSqlAlchemyT", bound=SQLModelOrSqlAlchemy)
def _safe_db_url_for_logging(url: str) -> str:
"""Remove username and password from the database URL for logging.
Args:
url: The database URL.
Returns:
The database URL with the username and password removed.
"""
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
def _print_db_not_available(*args, **kwargs):
msg = (
"Database is not available. Please install the required packages: "
"`pip install reflex[db]`."
)
raise ImportError(msg)
class _ClassThatErrorsOnInit:
    """Placeholder for DB-backed classes when the DB packages are not installed.

    Instantiating it raises the ImportError produced by
    ``_print_db_not_available``.
    """

    def __init__(self, *init_args, **init_kwargs):
        """Refuse construction; the arguments are forwarded but ignored."""
        _print_db_not_available(*init_args, **init_kwargs)
if find_spec("sqlalchemy"):
# Everything under this guard needs sqlalchemy; when it is not installed the
# `else` branch binds the same names to stubs that raise ImportError.
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext.asyncio
import sqlalchemy.orm
# Module-level engine caches keyed by database url, filled lazily by
# get_engine() and get_async_engine().
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
def get_engine_args(url: str | None = None) -> dict[str, Any]:
    """Collect keyword arguments for SQLAlchemy's engine factory functions.

    Echo and connection-pool options come from the ``SQLALCHEMY_*``
    environment settings. For SQLite URLs, ``check_same_thread`` is
    disabled so a connection may be used outside the thread that opened it.

    Args:
        url: The database url; when falsy, ``get_config().db_url`` is used.

    Returns:
        The database engine arguments as a dict.
    """
    engine_kwargs: dict[str, Any] = {
        # Emit executed SQL to the log when enabled.
        "echo": environment.SQLALCHEMY_ECHO.get(),
        # Test pooled connections for liveness before handing them out.
        "pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
        "pool_size": environment.SQLALCHEMY_POOL_SIZE.get(),
        "max_overflow": environment.SQLALCHEMY_MAX_OVERFLOW.get(),
        "pool_recycle": environment.SQLALCHEMY_POOL_RECYCLE.get(),
        "pool_timeout": environment.SQLALCHEMY_POOL_TIMEOUT.get(),
    }
    configured = get_config()
    effective_url = url or configured.db_url
    is_sqlite = effective_url is not None and effective_url.startswith("sqlite")
    if is_sqlite:
        # Needed for the admin dash on sqlite.
        engine_kwargs["connect_args"] = {"check_same_thread": False}
    return engine_kwargs
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
    """Return the synchronous engine for a database url, creating it on first use.

    Engines are cached per url in ``_ENGINE``, so repeated calls with the
    same url share one engine. A (deduplicated) warning is logged when the
    alembic config file does not exist yet.

    Args:
        url: The DB url to use; when falsy, ``get_config().db_url`` is used.

    Returns:
        The (possibly cached) database engine.

    Raises:
        ValueError: If neither ``url`` nor the configured ``db_url`` is set.
    """
    configured = get_config()
    db_url = url or configured.db_url
    if db_url is None:
        raise ValueError("No database url configured")

    cached_engine = _ENGINE.get(db_url)
    if cached_engine is not None:
        return cached_engine

    alembic_ini_present = environment.ALEMBIC_CONFIG.get().exists()
    if not alembic_ini_present:
        console.warn(
            "Database is not initialized, run [bold]reflex db init[/bold] first.",
            dedupe=True,
        )
    new_engine = sqlalchemy.engine.create_engine(db_url, **get_engine_args(db_url))
    _ENGINE[db_url] = new_engine
    return new_engine
def create_all():
    """Create any missing tables from the registry's metadata.

    Uses the default engine returned by ``get_engine()`` and the merged
    metadata from ``ModelRegistry.get_metadata()``.
    """
    target_engine = get_engine()
    registry_metadata = ModelRegistry.get_metadata()
    registry_metadata.create_all(bind=target_engine)
def get_async_engine(url: str | None = None) -> sqlalchemy.ext.asyncio.AsyncEngine:
    """Get the async database engine, creating and caching it on first use.

    Args:
        url: The async database url. When None (the default, matching
            ``get_engine``), ``get_config().async_db_url`` is used. In that
            case a warning is logged once if the configured async url and
            ``db_url`` differ after their ``scheme://`` prefix.

    Returns:
        The async database engine, cached per url in ``_ASYNC_ENGINE``.

    Raises:
        ValueError: If the async database url is None.
    """
    if url is None:
        conf = get_config()
        url = conf.async_db_url
        if url is not None and conf.db_url is not None:
            # Only the scheme/driver part is allowed to differ; everything
            # after "://" should point at the same database.
            async_db_url_tail = url.partition("://")[2]
            db_url_tail = conf.db_url.partition("://")[2]
            if async_db_url_tail != db_url_tail:
                console.warn(
                    f"async_db_url `{_safe_db_url_for_logging(url)}` "
                    "should reference the same database as "
                    f"db_url `{_safe_db_url_for_logging(conf.db_url)}`.",
                    # The config does not change between calls; warn once.
                    dedupe=True,
                )
    if url is None:
        msg = "No async database url configured"
        raise ValueError(msg)

    cached_engine = _ASYNC_ENGINE.get(url)
    if cached_engine is not None:
        return cached_engine

    if not environment.ALEMBIC_CONFIG.get().exists():
        console.warn(
            "Database is not initialized, run [bold]reflex db init[/bold] first.",
            dedupe=True,
        )
    engine = sqlalchemy.ext.asyncio.create_async_engine(
        url,
        **get_engine_args(url),
    )
    _ASYNC_ENGINE[url] = engine
    return engine
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
    """Open a plain SQLAlchemy ORM session bound to the engine for ``url``.

    Args:
        url: The database url; passed through to ``get_engine``.

    Returns:
        A new database session; closing it is the caller's responsibility.
    """
    bound_engine = get_engine(url)
    return sqlalchemy.orm.Session(bind=bound_engine)
class ModelRegistry:
    """Registry for all models.

    Attributes:
        models: The registered model classes (SQLModel or SQLAlchemy declarative).
        _metadata: Cache the metadata to avoid re-creating it; reset whenever a
            model is registered.
    """

    models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
    _metadata: ClassVar[sqlalchemy.MetaData | None] = None

    @classmethod
    def register(cls, model: SQLModelOrSqlAlchemyT) -> SQLModelOrSqlAlchemyT:
        """Register a model. Can be used directly or as a decorator.

        Registering invalidates the cached metadata, so the next
        ``get_metadata()`` call includes the new model's tables.

        Args:
            model: The model to register.

        Returns:
            The model passed in as an argument (Allows decorator usage)
        """
        cls.models.add(model)
        # Any cached metadata was built from the previous set of models.
        cls._metadata = None
        return model

    @classmethod
    def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
        """Get registered models.

        Args:
            include_empty: If True, include models with empty metadata.

        Returns:
            The registered models.
        """
        if include_empty:
            return cls.models
        return {
            model for model in cls.models if not cls._model_metadata_is_empty(model)
        }

    @staticmethod
    def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
        """Check if the model metadata is empty.

        Args:
            model: The model to check.

        Returns:
            True if the model metadata is empty, False otherwise.
        """
        return len(model.metadata.tables) == 0

    @classmethod
    def get_metadata(cls) -> sqlalchemy.MetaData:
        """Get the database metadata for all registered, non-empty models.

        If every registered model uses the same MetaData object, that object
        is returned directly. Otherwise the tables of each distinct MetaData
        are copied into a new merged MetaData. The result is cached until the
        next ``register`` call.

        Returns:
            The database metadata.
        """
        if cls._metadata is not None:
            return cls._metadata
        models = cls.get_models(include_empty=False)
        # Several models may share one MetaData object (SQLModel subclasses
        # typically share SQLModel.metadata). Deduplicate by identity so the
        # same tables are not copied more than once.
        distinct_metadata = list(
            {id(model.metadata): model.metadata for model in models}.values()
        )
        if len(distinct_metadata) == 1:
            metadata = distinct_metadata[0]
        else:
            # Merge the metadata from all the models.
            # This allows mixing bare sqlalchemy models with sqlmodel models in one database.
            metadata = sqlalchemy.MetaData()
            for source_metadata in distinct_metadata:
                for table in source_metadata.tables.values():
                    table.to_metadata(metadata)
        # Cache the metadata
        cls._metadata = metadata
        return metadata
else:
get_engine_args = _print_db_not_available
get_engine = _print_db_not_available
get_async_engine = _print_db_not_available
sqla_session = _print_db_not_available
ModelRegistry = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]
if find_spec("sqlalchemy") and find_spec("alembic"):
import alembic.autogenerate
import alembic.command
import alembic.config
import alembic.operations.ops
import alembic.runtime.environment
import alembic.script
from alembic.runtime.migration import MigrationContext
from alembic.script.base import Script
def format_revision(
    rev: Script,
    current_rev: str | None,
    current_reached_ref: list[bool],
) -> str:
    """Render one migration revision as a Rich-markup status line.

    Callers pass revisions in chronological order (base first). A revision
    counts as applied when it is the current database revision or precedes
    it. ``current_reached_ref[0]`` is set to True once the current revision
    is seen, so every later revision is reported as pending. When
    ``current_rev`` is None, nothing is considered applied.

    Args:
        rev: The alembic script object
        current_rev: The currently applied revision ID
        current_reached_ref: One-element list used as a mutable flag shared
            across calls; set to True when ``rev`` is the current revision.

    Returns:
        Formatted string for display
    """
    revision_id = rev.revision
    description = rev.doc
    is_current = current_rev is not None and revision_id == current_rev
    if is_current:
        current_reached_ref[0] = True
    applied = current_rev is not None and (is_current or not current_reached_ref[0])
    icon = "[green]✓[/green]" if applied else "[red]✗[/red]"
    suffix = " (head)" if rev.is_head else ""
    return f" [{icon}] {revision_id}{suffix}, {description}"
def _alembic_config():
    """Load the alembic config and the script directory it refers to.

    If the config file does not set ``script_location``, it falls back to
    ``"version"``.

    Returns:
        tuple of (config, script_directory)
    """
    alembic_cfg = alembic.config.Config(environment.ALEMBIC_CONFIG.get())
    script_location = alembic_cfg.get_main_option("script_location")
    if not script_location:
        alembic_cfg.set_main_option("script_location", "version")
    scripts = alembic.script.ScriptDirectory.from_config(alembic_cfg)
    return alembic_cfg, scripts
def _alembic_render_item(
type_: str,
obj: Any,
autogen_context: alembic.autogenerate.api.AutogenContext,
):
"""Alembic render_item hook call.
This method is called to provide python code for the given obj,
but currently it is only used to add `sqlmodel` to the import list
when generating migration scripts.
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
Args:
type_: One of "schema", "table", "column", "index",
"unique_constraint", or "foreign_key_constraint".
obj: The object being rendered.
autogen_context: Shared AutogenContext passed to each render_item call.
Returns:
False - Indicating that the default rendering should be used.
"""
if find_spec("sqlmodel"):
autogen_context.imports.add("import sqlmodel")
return False
def alembic_init():
    """Scaffold alembic for the project.

    Runs ``alembic init`` against the configured alembic config file. The
    migration environment goes into an ``alembic`` directory next to that
    file.
    """
    config_path = environment.ALEMBIC_CONFIG.get()
    alembic.command.init(
        config=alembic.config.Config(config_path),
        directory=str(config_path.parent / "alembic"),
    )
def get_migration_history():
    """Return the database's current revision and all known revisions.

    Returns:
        tuple: (current_revision, revisions_list) where revisions_list is in
        chronological order (base first) and current_revision is whatever
        alembic reports for the default engine's database.
    """
    with get_engine().connect() as db_connection:
        current_rev = MigrationContext.configure(db_connection).get_current_revision()
    _, scripts = _alembic_config()
    # walk_revisions() yields newest-first; flip it so the base comes first.
    chronological = list(scripts.walk_revisions())[::-1]
    return current_rev, chronological
def alembic_autogenerate(
    connection: sqlalchemy.engine.Connection,
    message: str | None = None,
    write_migration_scripts: bool = True,
) -> bool:
    """Generate migration scripts for alembic-detectable changes.

    The database reachable through ``connection`` is compared against
    ``ModelRegistry.get_metadata()``. When a newly added column has a
    Python-side default that can be rendered as a SQL literal, that value is
    also set as the column's ``server_default``.

    Args:
        connection: SQLAlchemy connection to use when detecting changes.
        message: Human readable identifier describing the generated revision.
        write_migration_scripts: If True, write autogenerated revisions to script directory.

    Returns:
        True when changes have been detected. False when there are no
        changes or the alembic config file does not exist.
    """
    if not environment.ALEMBIC_CONFIG.get().exists():
        return False

    config, script_directory = _alembic_config()
    revision_context = alembic.autogenerate.api.RevisionContext(
        config=config,
        script_directory=script_directory,
        command_args=defaultdict(
            lambda: None,
            autogenerate=True,
            head="head",
            message=message,
        ),
    )
    writer = alembic.autogenerate.rewriter.Rewriter()

    @writer.rewrites(alembic.operations.ops.AddColumnOp)
    def render_add_column_with_server_default(
        context: MigrationContext,
        revision: str | None,
        op: Any,
    ):
        # Carry the sqlmodel default as server_default so that newly added
        # columns get the desired default value in existing rows.
        default = op.column.default
        if (
            default is not None
            and op.column.server_default is None
            # A callable default (presumably from a default_factory) cannot be
            # rendered as a SQL literal, and a Sequence has no `.arg`; either
            # would make autogenerate fail, so leave such columns alone.
            and not getattr(default, "is_callable", False)
            and not getattr(default, "is_sequence", False)
        ):
            op.column.server_default = sqlalchemy.DefaultClause(
                sqlalchemy.sql.expression.literal(default.arg),
            )
        return op

    def run_autogenerate(rev: str, context: MigrationContext):
        # Only collect the diff; returning no steps means nothing is executed.
        revision_context.run_autogenerate(rev, context)
        return []

    with alembic.runtime.environment.EnvironmentContext(
        config=config,
        script=script_directory,
        fn=run_autogenerate,
    ) as env:
        env.configure(
            connection=connection,
            target_metadata=ModelRegistry.get_metadata(),
            render_item=_alembic_render_item,
            process_revision_directives=writer,
            compare_type=False,
            include_schemas=environment.ALEMBIC_INCLUDE_SCHEMAS.get(),
            render_as_batch=True,  # for sqlite compatibility
        )
        env.run_migrations()

    changes_detected = False
    if revision_context.generated_revisions:
        upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
        if upgrade_ops is not None:
            changes_detected = bool(upgrade_ops.ops)
    if changes_detected and write_migration_scripts:
        # Must iterate the generator to actually write the scripts.
        _ = tuple(revision_context.generate_scripts())
    return changes_detected
def _alembic_upgrade(
    connection: sqlalchemy.engine.Connection,
    to_rev: str = "head",
) -> None:
    """Run alembic upgrade steps on ``connection`` until ``to_rev`` is reached.

    Args:
        connection: SQLAlchemy connection to use when performing upgrade.
        to_rev: Revision to migrate towards.
    """
    config, script_directory = _alembic_config()

    def _steps_to_target(rev: str, context: MigrationContext):
        # Given the revision(s) alembic reports as current, list the upgrade
        # steps that lead to ``to_rev``.
        return script_directory._upgrade_revs(to_rev, rev)

    environment_context = alembic.runtime.environment.EnvironmentContext(
        config=config,
        script=script_directory,
        fn=_steps_to_target,
    )
    with environment_context as env:
        env.configure(connection=connection)
        env.run_migrations()
def migrate(autogenerate: bool = False) -> bool | None:
    """Bring the database schema up to date using alembic.

    Does nothing (returns None) when the alembic config file does not exist.
    Otherwise all of the following happens on one connection:

    1. Existing revisions up to "head" are applied.
    2. If ``autogenerate`` is True, a revision is generated for any detected
       model changes (including the initial one when no revisions exist
       yet), and it is applied too.
    3. The transaction is committed.

    Models changed in ways alembic cannot express automatically may need
    hand-edited migration scripts before the app can start.

    Args:
        autogenerate: If True, generate migration script and use it to upgrade schema
            (otherwise, just bring the schema to current "head" revision).

    Returns:
        True - indicating the process was successful.
        None - indicating the process was skipped.
    """
    if not environment.ALEMBIC_CONFIG.get().exists():
        return None
    with get_engine().connect() as connection:
        _alembic_upgrade(connection=connection)
        if autogenerate and alembic_autogenerate(connection=connection):
            # The freshly written revision script still has to be applied.
            _alembic_upgrade(connection=connection)
        connection.commit()
    return True
else:
alembic_init = _print_db_not_available
get_migration_history = _print_db_not_available
alembic_autogenerate = _print_db_not_available
migrate = _print_db_not_available
if find_spec("sqlmodel") and find_spec("sqlalchemy") and find_spec("pydantic"):
# Everything under this guard needs sqlmodel, sqlalchemy and pydantic; the
# matching `else` branch binds the same names to ImportError-raising stubs.
import sqlmodel
from sqlmodel.ext.asyncio.session import AsyncSession
# Async session factories cached by the `url` argument passed to asession()
# (the None key means "use the configured async_db_url").
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
def get_db_status() -> dict[str, bool]:
    """Report whether the configured database answers a trivial query.

    Runs ``SELECT 1`` on a fresh connection from ``get_engine()``. Any
    failure is logged via ``console.error`` (with ``dedupe=True``) and
    reported as unhealthy instead of being raised.

    Returns:
        ``{"db": True}`` when the query succeeds, otherwise ``{"db": False}``.
    """
    try:
        with get_engine().connect() as connection:
            connection.execute(sqlalchemy.text("SELECT 1"))
    except Exception as exc:
        # A health check must not crash its caller; report the failure instead.
        console.error(
            f"Database health check failed: {exc} (subsequent errors will not be logged)",
            dedupe=True,
        )
        return {"db": False}
    return {"db": True}
@serializer
def serialize_sqlmodel(m: sqlmodel.SQLModel) -> dict[str, Any]:
    """Convert a SQLModel instance into a plain dict for serialization.

    The result starts from ``m.model_dump()`` and then adds each declared
    relationship attribute. SQLModel relationships are not part of the
    dumped fields, so they are read one by one. A relationship whose access
    raises ``DetachedInstanceError`` (it was never loaded before the session
    closed) is left out.

    Args:
        m: The SQLModel object to serialize.

    Returns:
        The serialized object as a dictionary.
    """
    serialized = dict(m.model_dump())
    for relationship_name in m.__sqlmodel_relationships__:
        try:
            related = getattr(m, relationship_name)
        except sqlalchemy.orm.exc.DetachedInstanceError:
            # Relationship was never loaded and the session is closed.
            continue
        serialized[relationship_name] = related
    return serialized
def _warn_about_model_deprecation():
    """Emit the deprecation notice for the ``reflex.Model`` API."""
    console.deprecate(
        feature_name="reflex.Model",
        deprecation_version="0.9.2",
        removal_version="1.0.0",
        reason="Directly use database ORM layer, like sqlalchemy or SQLModel",
    )
class Model(sqlmodel.SQLModel):
"""Base class to define a table in the database.
Deprecated: defining a subclass, and calling most of the helper methods
below, emits the reflex.Model deprecation warning; use sqlalchemy or
SQLModel directly instead.
Attributes:
id: The primary key for the table.
"""
id: int | None = sqlmodel.Field(default=None, primary_key=True)
# Pydantic model configuration inherited by every subclass.
model_config = { # pyright: ignore [reportAssignmentType]
"arbitrary_types_allowed": True,
"use_enum_values": True,
"extra": "allow",
}
def __init_subclass__(cls, **kwargs):
"""Emit the reflex.Model deprecation warning for each new subclass.
Subclasses are not added to ModelRegistry here; only Model itself is
registered, right after this class definition.
"""
super().__init_subclass__(**kwargs)
_warn_about_model_deprecation()
@staticmethod
def create_all():
"""Create all the tables (deprecated wrapper around the module-level create_all)."""
_warn_about_model_deprecation()
# Method bodies do not see class scope, so this calls the module-level
# create_all(), not this staticmethod.
create_all()
@staticmethod
def get_db_engine():
"""Get the database engine.
Returns:
The database engine.
"""
_warn_about_model_deprecation()
return get_engine()
@classmethod
def alembic_init(cls):
"""Initialize alembic for the project."""
_warn_about_model_deprecation()
# Calls the module-level alembic_init().
alembic_init()
@classmethod
def get_migration_history(cls):
"""Get migration history with current database state.
Returns:
tuple: (current_revision, revisions_list) where revisions_list is in chronological order
"""
_warn_about_model_deprecation()
return get_migration_history()
@classmethod
def alembic_autogenerate(
cls,
connection: sqlalchemy.engine.Connection,
message: str | None = None,
write_migration_scripts: bool = True,
) -> bool:
"""Generate migration scripts for alembic-detectable changes.
Args:
connection: SQLAlchemy connection to use when detecting changes.
message: Human readable identifier describing the generated revision.
write_migration_scripts: If True, write autogenerated revisions to script directory.
Returns:
True when changes have been detected.
"""
_warn_about_model_deprecation()
return alembic_autogenerate(
connection=connection,
message=message,
write_migration_scripts=write_migration_scripts,
)
@classmethod
def migrate(cls, autogenerate: bool = False) -> bool | None:
"""Execute alembic migrations for all sqlmodel Model classes.
If alembic is not installed or has not been initialized for the project,
then no action is performed.
If there are no revisions currently tracked by alembic and autogenerate
is True, then an initial revision will be created based on sqlmodel metadata.
If models in the app have changed in incompatible ways that alembic
cannot automatically generate revisions for, the app may not be able to
start up until migration scripts have been corrected by hand.
Args:
autogenerate: If True, generate migration script and use it to upgrade schema
(otherwise, just bring the schema to current "head" revision).
Returns:
True - indicating the process was successful.
None - indicating the process was skipped.
"""
_warn_about_model_deprecation()
return migrate(autogenerate=autogenerate)
@classmethod
def select(cls):
"""Select rows from the table.
Unlike the other helpers, this does not emit the deprecation warning.
Returns:
The select statement.
"""
return sqlmodel.select(cls)
# Register the base class so ModelRegistry.get_metadata() includes its metadata.
ModelRegistry.register(Model)
def session(url: str | None = None) -> sqlmodel.Session:
    """Open a synchronous sqlmodel session on the cached engine for ``url``.

    Args:
        url: The database url; passed through to ``get_engine``.

    Returns:
        A new database session; closing it is the caller's responsibility.
    """
    bound_engine = get_engine(url)
    return sqlmodel.Session(bind=bound_engine)
def asession(url: str | None = None) -> AsyncSession:
    """Get an async sqlmodel session to interact with the database.

    Usage::

        async with rx.asession() as asession:
            ...

    Most operations against the ``asession`` must be awaited. One
    ``async_sessionmaker`` is built per distinct ``url`` argument and reused
    on later calls.

    Args:
        url: The database url; None selects the configured async url via
            ``get_async_engine``.

    Returns:
        An async database session.
    """
    factory = _AsyncSessionLocal.get(url)
    if factory is None:
        factory = sqlalchemy.ext.asyncio.async_sessionmaker(
            bind=get_async_engine(url),
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
        _AsyncSessionLocal[url] = factory
    return factory()
else:
get_db_status = _print_db_not_available
session = _print_db_not_available
asession = _print_db_not_available
Model = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]