81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
"""Shared contextvars wrapper for contextual globals."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from contextvars import ContextVar, Token
|
|
from typing import ClassVar
|
|
|
|
from typing_extensions import Self
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
|
|
class BaseContext:
|
|
"""Base context class that acts as an async context manager to set the context var."""
|
|
|
|
_context_var: ClassVar[ContextVar[Self]]
|
|
_attached_context_token: ClassVar[dict[Self, Token[Self]]]
|
|
|
|
@classmethod
|
|
def __init_subclass__(cls, **kwargs):
|
|
"""Initialize the context variable for the subclass."""
|
|
super(BaseContext, cls).__init_subclass__(**kwargs)
|
|
cls._context_var = ContextVar(cls.__name__)
|
|
cls._attached_context_token = {}
|
|
|
|
@classmethod
|
|
def get(cls) -> Self:
|
|
"""Get the context from the context variable.
|
|
|
|
Returns:
|
|
The context instance.
|
|
"""
|
|
return cls._context_var.get()
|
|
|
|
@classmethod
|
|
def set(cls, context: Self) -> Token[Self]:
|
|
"""Set the context in the context variable.
|
|
|
|
Args:
|
|
context: The context instance to set.
|
|
|
|
Returns:
|
|
The token for resetting the context variable.
|
|
"""
|
|
return cls._context_var.set(context)
|
|
|
|
@classmethod
|
|
def reset(cls, token: Token[Self]) -> None:
|
|
"""Reset the context variable to a previous state.
|
|
|
|
Args:
|
|
token: The token to reset the context variable to.
|
|
"""
|
|
cls._context_var.reset(token)
|
|
|
|
def __enter__(self) -> Self:
|
|
"""Enter the context.
|
|
|
|
Returns:
|
|
This context instance.
|
|
"""
|
|
if self._attached_context_token.get(self) is not None:
|
|
msg = "Context is already attached, cannot enter context manager."
|
|
raise RuntimeError(msg)
|
|
self._attached_context_token[self] = self._context_var.set(self)
|
|
return self
|
|
|
|
def __exit__(self, *exc_info):
|
|
"""Exit the context."""
|
|
if (token := self._attached_context_token.pop(self)) is not None:
|
|
self._context_var.reset(token)
|
|
|
|
def ensure_context_attached(self):
|
|
"""Ensure that the context is attached to the current context variable.
|
|
|
|
Raises:
|
|
RuntimeError: If the context is not attached.
|
|
"""
|
|
if self._attached_context_token.get(self) is None:
|
|
msg = f"{type(self).__name__} must be entered before calling this method."
|
|
raise RuntimeError(msg)
|