720 lines
24 KiB
Python
720 lines
24 KiB
Python
"""Database built into Reflex."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from collections import defaultdict
|
|
from contextlib import suppress
|
|
from importlib.util import find_spec
|
|
from typing import TYPE_CHECKING, Any, ClassVar
|
|
|
|
from reflex_base.config import get_config
|
|
from reflex_base.environment import environment
|
|
from reflex_base.utils import console
|
|
from reflex_base.utils.serializers import serializer
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import TypeVar
|
|
|
|
import sqlalchemy
|
|
import sqlmodel
|
|
|
|
SQLModelOrSqlAlchemy = (
|
|
type[sqlmodel.SQLModel] | type[sqlalchemy.orm.DeclarativeBase]
|
|
)
|
|
|
|
SQLModelOrSqlAlchemyT = TypeVar("SQLModelOrSqlAlchemyT", bound=SQLModelOrSqlAlchemy)
|
|
|
|
|
|
def _safe_db_url_for_logging(url: str) -> str:
|
|
"""Remove username and password from the database URL for logging.
|
|
|
|
Args:
|
|
url: The database URL.
|
|
|
|
Returns:
|
|
The database URL with the username and password removed.
|
|
"""
|
|
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
|
|
|
|
|
|
def _print_db_not_available(*args, **kwargs):
|
|
msg = (
|
|
"Database is not available. Please install the required packages: "
|
|
"`pip install reflex[db]`."
|
|
)
|
|
raise ImportError(msg)
|
|
|
|
|
|
class _ClassThatErrorsOnInit:
|
|
def __init__(self, *args, **kwargs):
|
|
_print_db_not_available(*args, **kwargs)
|
|
|
|
|
|
if find_spec("sqlalchemy"):
|
|
import sqlalchemy
|
|
import sqlalchemy.exc
|
|
import sqlalchemy.ext.asyncio
|
|
import sqlalchemy.orm
|
|
|
|
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
|
|
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
|
|
|
|
def get_engine_args(url: str | None = None) -> dict[str, Any]:
|
|
"""Get the database engine arguments.
|
|
|
|
Args:
|
|
url: The database url.
|
|
|
|
Returns:
|
|
The database engine arguments as a dict.
|
|
"""
|
|
kwargs: dict[str, Any] = {
|
|
# Print the SQL queries if the log level is INFO or lower.
|
|
"echo": environment.SQLALCHEMY_ECHO.get(),
|
|
# Check connections before returning them.
|
|
"pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
|
|
"pool_size": environment.SQLALCHEMY_POOL_SIZE.get(),
|
|
"max_overflow": environment.SQLALCHEMY_MAX_OVERFLOW.get(),
|
|
"pool_recycle": environment.SQLALCHEMY_POOL_RECYCLE.get(),
|
|
"pool_timeout": environment.SQLALCHEMY_POOL_TIMEOUT.get(),
|
|
}
|
|
conf = get_config()
|
|
url = url or conf.db_url
|
|
if url is not None and url.startswith("sqlite"):
|
|
# Needed for the admin dash on sqlite.
|
|
kwargs["connect_args"] = {"check_same_thread": False}
|
|
return kwargs
|
|
|
|
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
|
"""Get the database engine.
|
|
|
|
Args:
|
|
url: the DB url to use.
|
|
|
|
Returns:
|
|
The database engine.
|
|
|
|
Raises:
|
|
ValueError: If the database url is None.
|
|
"""
|
|
conf = get_config()
|
|
url = url or conf.db_url
|
|
if url is None:
|
|
msg = "No database url configured"
|
|
raise ValueError(msg)
|
|
|
|
global _ENGINE
|
|
if url in _ENGINE:
|
|
return _ENGINE[url]
|
|
|
|
if not environment.ALEMBIC_CONFIG.get().exists():
|
|
console.warn(
|
|
"Database is not initialized, run [bold]reflex db init[/bold] first.",
|
|
dedupe=True,
|
|
)
|
|
_ENGINE[url] = sqlalchemy.engine.create_engine(
|
|
url,
|
|
**get_engine_args(url),
|
|
)
|
|
return _ENGINE[url]
|
|
|
|
def create_all():
|
|
"""Create all the tables."""
|
|
engine = get_engine()
|
|
ModelRegistry.get_metadata().create_all(engine)
|
|
|
|
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
|
|
"""Get the async database engine.
|
|
|
|
Args:
|
|
url: The database url.
|
|
|
|
Returns:
|
|
The async database engine.
|
|
|
|
Raises:
|
|
ValueError: If the async database url is None.
|
|
"""
|
|
if url is None:
|
|
conf = get_config()
|
|
url = conf.async_db_url
|
|
if url is not None and conf.db_url is not None:
|
|
async_db_url_tail = url.partition("://")[2]
|
|
db_url_tail = conf.db_url.partition("://")[2]
|
|
if async_db_url_tail != db_url_tail:
|
|
console.warn(
|
|
f"async_db_url `{_safe_db_url_for_logging(url)}` "
|
|
"should reference the same database as "
|
|
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
|
|
)
|
|
if url is None:
|
|
msg = "No async database url configured"
|
|
raise ValueError(msg)
|
|
|
|
global _ASYNC_ENGINE
|
|
if url in _ASYNC_ENGINE:
|
|
return _ASYNC_ENGINE[url]
|
|
|
|
if not environment.ALEMBIC_CONFIG.get().exists():
|
|
console.warn(
|
|
"Database is not initialized, run [bold]reflex db init[/bold] first.",
|
|
dedupe=True,
|
|
)
|
|
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
|
|
url,
|
|
**get_engine_args(url),
|
|
)
|
|
return _ASYNC_ENGINE[url]
|
|
|
|
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
|
|
"""Get a bare sqlalchemy session to interact with the database.
|
|
|
|
Args:
|
|
url: The database url.
|
|
|
|
Returns:
|
|
A database session.
|
|
"""
|
|
return sqlalchemy.orm.Session(get_engine(url))
|
|
|
|
class ModelRegistry:
|
|
"""Registry for all models.
|
|
|
|
Attributes:
|
|
_metadata: Cache the metadata to avoid re-creating it.
|
|
"""
|
|
|
|
models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
|
|
|
|
_metadata: ClassVar[sqlalchemy.MetaData | None] = None
|
|
|
|
@classmethod
|
|
def register(cls, model: SQLModelOrSqlAlchemyT) -> SQLModelOrSqlAlchemyT:
|
|
"""Register a model. Can be used directly or as a decorator.
|
|
|
|
Args:
|
|
model: The model to register.
|
|
|
|
Returns:
|
|
The model passed in as an argument (Allows decorator usage)
|
|
"""
|
|
cls.models.add(model)
|
|
return model
|
|
|
|
@classmethod
|
|
def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
|
|
"""Get registered models.
|
|
|
|
Args:
|
|
include_empty: If True, include models with empty metadata.
|
|
|
|
Returns:
|
|
The registered models.
|
|
"""
|
|
if include_empty:
|
|
return cls.models
|
|
return {
|
|
model for model in cls.models if not cls._model_metadata_is_empty(model)
|
|
}
|
|
|
|
@staticmethod
|
|
def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
|
|
"""Check if the model metadata is empty.
|
|
|
|
Args:
|
|
model: The model to check.
|
|
|
|
Returns:
|
|
True if the model metadata is empty, False otherwise.
|
|
"""
|
|
return len(model.metadata.tables) == 0
|
|
|
|
@classmethod
|
|
def get_metadata(cls) -> sqlalchemy.MetaData:
|
|
"""Get the database metadata.
|
|
|
|
Returns:
|
|
The database metadata.
|
|
"""
|
|
if cls._metadata is not None:
|
|
return cls._metadata
|
|
|
|
models = cls.get_models(include_empty=False)
|
|
|
|
if len(models) == 1:
|
|
metadata = next(iter(models)).metadata
|
|
else:
|
|
# Merge the metadata from all the models.
|
|
# This allows mixing bare sqlalchemy models with sqlmodel models in one database.
|
|
metadata = sqlalchemy.MetaData()
|
|
for model in cls.get_models():
|
|
for table in model.metadata.tables.values():
|
|
table.to_metadata(metadata)
|
|
|
|
# Cache the metadata
|
|
cls._metadata = metadata
|
|
|
|
return metadata
|
|
|
|
else:
|
|
get_engine_args = _print_db_not_available
|
|
get_engine = _print_db_not_available
|
|
get_async_engine = _print_db_not_available
|
|
sqla_session = _print_db_not_available
|
|
ModelRegistry = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]
|
|
|
|
if find_spec("sqlalchemy") and find_spec("alembic"):
|
|
import alembic.autogenerate
|
|
import alembic.command
|
|
import alembic.config
|
|
import alembic.operations.ops
|
|
import alembic.runtime.environment
|
|
import alembic.script
|
|
from alembic.runtime.migration import MigrationContext
|
|
from alembic.script.base import Script
|
|
|
|
def format_revision(
|
|
rev: Script,
|
|
current_rev: str | None,
|
|
current_reached_ref: list[bool],
|
|
) -> str:
|
|
"""Format a single revision for display.
|
|
|
|
Args:
|
|
rev: The alembic script object
|
|
current_rev: The currently applied revision ID
|
|
current_reached_ref: Mutable reference to track if we've reached current revision
|
|
|
|
Returns:
|
|
Formatted string for display
|
|
"""
|
|
current = rev.revision
|
|
message = rev.doc
|
|
|
|
# Determine if this migration is applied
|
|
if current_rev is None:
|
|
is_applied = False
|
|
elif current == current_rev:
|
|
is_applied = True
|
|
current_reached_ref[0] = True
|
|
else:
|
|
is_applied = not current_reached_ref[0]
|
|
|
|
# Show checkmark or X with colors
|
|
status_icon = "[green]✓[/green]" if is_applied else "[red]✗[/red]"
|
|
head_marker = " (head)" if rev.is_head else ""
|
|
|
|
# Format output with message
|
|
return f" [{status_icon}] {current}{head_marker}, {message}"
|
|
|
|
def _alembic_config():
|
|
"""Get the alembic configuration and script_directory.
|
|
|
|
Returns:
|
|
tuple of (config, script_directory)
|
|
"""
|
|
config = alembic.config.Config(environment.ALEMBIC_CONFIG.get())
|
|
if not config.get_main_option("script_location"):
|
|
config.set_main_option("script_location", "version")
|
|
return config, alembic.script.ScriptDirectory.from_config(config)
|
|
|
|
def _alembic_render_item(
|
|
type_: str,
|
|
obj: Any,
|
|
autogen_context: alembic.autogenerate.api.AutogenContext,
|
|
):
|
|
"""Alembic render_item hook call.
|
|
|
|
This method is called to provide python code for the given obj,
|
|
but currently it is only used to add `sqlmodel` to the import list
|
|
when generating migration scripts.
|
|
|
|
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
|
|
|
|
Args:
|
|
type_: One of "schema", "table", "column", "index",
|
|
"unique_constraint", or "foreign_key_constraint".
|
|
obj: The object being rendered.
|
|
autogen_context: Shared AutogenContext passed to each render_item call.
|
|
|
|
Returns:
|
|
False - Indicating that the default rendering should be used.
|
|
"""
|
|
if find_spec("sqlmodel"):
|
|
autogen_context.imports.add("import sqlmodel")
|
|
return False
|
|
|
|
def alembic_init():
|
|
"""Initialize alembic for the project."""
|
|
alembic.command.init(
|
|
config=alembic.config.Config(environment.ALEMBIC_CONFIG.get()),
|
|
directory=str(environment.ALEMBIC_CONFIG.get().parent / "alembic"),
|
|
)
|
|
|
|
def get_migration_history():
|
|
"""Get migration history with current database state.
|
|
|
|
Returns:
|
|
tuple: (current_revision, revisions_list) where revisions_list is in chronological order
|
|
"""
|
|
# Get current revision from database
|
|
with get_engine().connect() as connection:
|
|
context = MigrationContext.configure(connection)
|
|
current_rev = context.get_current_revision()
|
|
|
|
# Get all revisions from base to head
|
|
_, script_dir = _alembic_config()
|
|
revisions = list(script_dir.walk_revisions())
|
|
revisions.reverse() # Reverse to get chronological order (base first)
|
|
|
|
return current_rev, revisions
|
|
|
|
def alembic_autogenerate(
|
|
connection: sqlalchemy.engine.Connection,
|
|
message: str | None = None,
|
|
write_migration_scripts: bool = True,
|
|
) -> bool:
|
|
"""Generate migration scripts for alembic-detectable changes.
|
|
|
|
Args:
|
|
connection: SQLAlchemy connection to use when detecting changes.
|
|
message: Human readable identifier describing the generated revision.
|
|
write_migration_scripts: If True, write autogenerated revisions to script directory.
|
|
|
|
Returns:
|
|
True when changes have been detected.
|
|
"""
|
|
if not environment.ALEMBIC_CONFIG.get().exists():
|
|
return False
|
|
|
|
config, script_directory = _alembic_config()
|
|
revision_context = alembic.autogenerate.api.RevisionContext(
|
|
config=config,
|
|
script_directory=script_directory,
|
|
command_args=defaultdict(
|
|
lambda: None,
|
|
autogenerate=True,
|
|
head="head",
|
|
message=message,
|
|
),
|
|
)
|
|
writer = alembic.autogenerate.rewriter.Rewriter()
|
|
|
|
@writer.rewrites(alembic.operations.ops.AddColumnOp)
|
|
def render_add_column_with_server_default(
|
|
context: MigrationContext,
|
|
revision: str | None,
|
|
op: Any,
|
|
):
|
|
# Carry the sqlmodel default as server_default so that newly added
|
|
# columns get the desired default value in existing rows.
|
|
if op.column.default is not None and op.column.server_default is None:
|
|
op.column.server_default = sqlalchemy.DefaultClause(
|
|
sqlalchemy.sql.expression.literal(op.column.default.arg),
|
|
)
|
|
return op
|
|
|
|
def run_autogenerate(rev: str, context: MigrationContext):
|
|
revision_context.run_autogenerate(rev, context)
|
|
return []
|
|
|
|
with alembic.runtime.environment.EnvironmentContext(
|
|
config=config,
|
|
script=script_directory,
|
|
fn=run_autogenerate,
|
|
) as env:
|
|
env.configure(
|
|
connection=connection,
|
|
target_metadata=ModelRegistry.get_metadata(),
|
|
render_item=_alembic_render_item,
|
|
process_revision_directives=writer,
|
|
compare_type=False,
|
|
include_schemas=environment.ALEMBIC_INCLUDE_SCHEMAS.get(),
|
|
render_as_batch=True, # for sqlite compatibility
|
|
)
|
|
env.run_migrations()
|
|
changes_detected = False
|
|
if revision_context.generated_revisions:
|
|
upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
|
|
if upgrade_ops is not None:
|
|
changes_detected = bool(upgrade_ops.ops)
|
|
if changes_detected and write_migration_scripts:
|
|
# Must iterate the generator to actually write the scripts.
|
|
_ = tuple(revision_context.generate_scripts())
|
|
return changes_detected
|
|
|
|
def _alembic_upgrade(
|
|
connection: sqlalchemy.engine.Connection,
|
|
to_rev: str = "head",
|
|
) -> None:
|
|
"""Apply alembic migrations up to the given revision.
|
|
|
|
Args:
|
|
connection: SQLAlchemy connection to use when performing upgrade.
|
|
to_rev: Revision to migrate towards.
|
|
"""
|
|
config, script_directory = _alembic_config()
|
|
|
|
def run_upgrade(rev: str, context: MigrationContext):
|
|
return script_directory._upgrade_revs(to_rev, rev)
|
|
|
|
with alembic.runtime.environment.EnvironmentContext(
|
|
config=config,
|
|
script=script_directory,
|
|
fn=run_upgrade,
|
|
) as env:
|
|
env.configure(connection=connection)
|
|
env.run_migrations()
|
|
|
|
def migrate(autogenerate: bool = False) -> bool | None:
|
|
"""Execute alembic migrations for all model classes.
|
|
|
|
If alembic is not installed or has not been initialized for the project,
|
|
then no action is performed.
|
|
|
|
If there are no revisions currently tracked by alembic, then
|
|
an initial revision will be created based on sqlmodel metadata.
|
|
|
|
If models in the app have changed in incompatible ways that alembic
|
|
cannot automatically generate revisions for, the app may not be able to
|
|
start up until migration scripts have been corrected by hand.
|
|
|
|
Args:
|
|
autogenerate: If True, generate migration script and use it to upgrade schema
|
|
(otherwise, just bring the schema to current "head" revision).
|
|
|
|
Returns:
|
|
True - indicating the process was successful.
|
|
None - indicating the process was skipped.
|
|
"""
|
|
if not environment.ALEMBIC_CONFIG.get().exists():
|
|
return None
|
|
|
|
with get_engine().connect() as connection:
|
|
_alembic_upgrade(connection=connection)
|
|
if autogenerate:
|
|
changes_detected = alembic_autogenerate(connection=connection)
|
|
if changes_detected:
|
|
_alembic_upgrade(connection=connection)
|
|
connection.commit()
|
|
return True
|
|
|
|
else:
|
|
alembic_init = _print_db_not_available
|
|
get_migration_history = _print_db_not_available
|
|
alembic_autogenerate = _print_db_not_available
|
|
migrate = _print_db_not_available
|
|
|
|
if find_spec("sqlmodel") and find_spec("sqlalchemy") and find_spec("pydantic"):
|
|
import sqlmodel
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
|
|
|
|
def get_db_status() -> dict[str, bool]:
|
|
"""Checks the status of the database connection.
|
|
|
|
Attempts to connect to the database and execute a simple query to verify connectivity.
|
|
|
|
Returns:
|
|
The status of the database connection.
|
|
"""
|
|
status = True
|
|
try:
|
|
engine = get_engine()
|
|
with engine.connect() as connection:
|
|
connection.execute(sqlalchemy.text("SELECT 1"))
|
|
except Exception as exc:
|
|
status = False
|
|
console.error(
|
|
f"Database health check failed: {exc} (subsequent errors will not be logged)",
|
|
dedupe=True,
|
|
)
|
|
|
|
return {"db": status}
|
|
|
|
@serializer
|
|
def serialize_sqlmodel(m: sqlmodel.SQLModel) -> dict[str, Any]:
|
|
"""Serialize a SQLModel object to a dictionary.
|
|
|
|
Args:
|
|
m: The SQLModel object to serialize.
|
|
|
|
Returns:
|
|
The serialized object as a dictionary.
|
|
"""
|
|
base_fields = m.model_dump()
|
|
relationships = {}
|
|
# SQLModel relationships do not appear in __fields__, but should be included if present.
|
|
for name in m.__sqlmodel_relationships__:
|
|
with suppress(
|
|
sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed.
|
|
):
|
|
relationships[name] = getattr(m, name)
|
|
return {
|
|
**base_fields,
|
|
**relationships,
|
|
}
|
|
|
|
def _warn_about_model_deprecation():
|
|
console.deprecate(
|
|
feature_name="reflex.Model",
|
|
reason="Directly use database ORM layer, like sqlalchemy or SQLModel",
|
|
deprecation_version="0.9.2",
|
|
removal_version="1.0.0",
|
|
)
|
|
|
|
class Model(sqlmodel.SQLModel):
|
|
"""Base class to define a table in the database.
|
|
|
|
Attributes:
|
|
id: The primary key for the table.
|
|
"""
|
|
|
|
id: int | None = sqlmodel.Field(default=None, primary_key=True)
|
|
|
|
model_config = { # pyright: ignore [reportAssignmentType]
|
|
"arbitrary_types_allowed": True,
|
|
"use_enum_values": True,
|
|
"extra": "allow",
|
|
}
|
|
|
|
def __init_subclass__(cls, **kwargs):
|
|
"""Automatically register subclasses as models."""
|
|
super().__init_subclass__(**kwargs)
|
|
_warn_about_model_deprecation()
|
|
|
|
@staticmethod
|
|
def create_all():
|
|
"""Create all the tables."""
|
|
_warn_about_model_deprecation()
|
|
create_all()
|
|
|
|
@staticmethod
|
|
def get_db_engine():
|
|
"""Get the database engine.
|
|
|
|
Returns:
|
|
The database engine.
|
|
"""
|
|
_warn_about_model_deprecation()
|
|
return get_engine()
|
|
|
|
@classmethod
|
|
def alembic_init(cls):
|
|
"""Initialize alembic for the project."""
|
|
_warn_about_model_deprecation()
|
|
alembic_init()
|
|
|
|
@classmethod
|
|
def get_migration_history(cls):
|
|
"""Get migration history with current database state.
|
|
|
|
Returns:
|
|
tuple: (current_revision, revisions_list) where revisions_list is in chronological order
|
|
"""
|
|
_warn_about_model_deprecation()
|
|
return get_migration_history()
|
|
|
|
@classmethod
|
|
def alembic_autogenerate(
|
|
cls,
|
|
connection: sqlalchemy.engine.Connection,
|
|
message: str | None = None,
|
|
write_migration_scripts: bool = True,
|
|
) -> bool:
|
|
"""Generate migration scripts for alembic-detectable changes.
|
|
|
|
Args:
|
|
connection: SQLAlchemy connection to use when detecting changes.
|
|
message: Human readable identifier describing the generated revision.
|
|
write_migration_scripts: If True, write autogenerated revisions to script directory.
|
|
|
|
Returns:
|
|
True when changes have been detected.
|
|
"""
|
|
_warn_about_model_deprecation()
|
|
return alembic_autogenerate(
|
|
connection=connection,
|
|
message=message,
|
|
write_migration_scripts=write_migration_scripts,
|
|
)
|
|
|
|
@classmethod
|
|
def migrate(cls, autogenerate: bool = False) -> bool | None:
|
|
"""Execute alembic migrations for all sqlmodel Model classes.
|
|
|
|
If alembic is not installed or has not been initialized for the project,
|
|
then no action is performed.
|
|
|
|
If there are no revisions currently tracked by alembic, then
|
|
an initial revision will be created based on sqlmodel metadata.
|
|
|
|
If models in the app have changed in incompatible ways that alembic
|
|
cannot automatically generate revisions for, the app may not be able to
|
|
start up until migration scripts have been corrected by hand.
|
|
|
|
Args:
|
|
autogenerate: If True, generate migration script and use it to upgrade schema
|
|
(otherwise, just bring the schema to current "head" revision).
|
|
|
|
Returns:
|
|
True - indicating the process was successful.
|
|
None - indicating the process was skipped.
|
|
"""
|
|
_warn_about_model_deprecation()
|
|
return migrate(autogenerate=autogenerate)
|
|
|
|
@classmethod
|
|
def select(cls):
|
|
"""Select rows from the table.
|
|
|
|
Returns:
|
|
The select statement.
|
|
"""
|
|
return sqlmodel.select(cls)
|
|
|
|
ModelRegistry.register(Model)
|
|
|
|
def session(url: str | None = None) -> sqlmodel.Session:
|
|
"""Get a sqlmodel session to interact with the database.
|
|
|
|
Args:
|
|
url: The database url.
|
|
|
|
Returns:
|
|
A database session.
|
|
"""
|
|
return sqlmodel.Session(get_engine(url))
|
|
|
|
def asession(url: str | None = None) -> AsyncSession:
|
|
"""Get an async sqlmodel session to interact with the database.
|
|
|
|
async with rx.asession() as asession:
|
|
...
|
|
|
|
Most operations against the `asession` must be awaited.
|
|
|
|
Args:
|
|
url: The database url.
|
|
|
|
Returns:
|
|
An async database session.
|
|
"""
|
|
global _AsyncSessionLocal
|
|
if url not in _AsyncSessionLocal:
|
|
_AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
|
|
bind=get_async_engine(url),
|
|
class_=AsyncSession,
|
|
expire_on_commit=False,
|
|
autocommit=False,
|
|
autoflush=False,
|
|
)
|
|
return _AsyncSessionLocal[url]()
|
|
|
|
else:
|
|
get_db_status = _print_db_not_available
|
|
session = _print_db_not_available
|
|
asession = _print_db_not_available
|
|
Model = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]
|