498 lines
17 KiB
Python
498 lines
17 KiB
Python
import asyncio
|
|
import inspect
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from typing import List, Optional, Tuple, Type, Union
|
|
|
|
from redis.asyncio import Redis as AsyncRedis
|
|
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
|
|
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
|
|
from redis.backoff import NoBackoff
|
|
from redis.client import Redis as SyncRedis
|
|
from redis.cluster import RedisCluster as SyncRedisCluster
|
|
from redis.http.http_client import HttpClient
|
|
from redis.multidb.exception import UnhealthyDatabaseException
|
|
from redis.retry import Retry
|
|
|
|
# Either flavour of async Redis client a health check can be handed:
# a standalone client (AsyncRedis) or a cluster client (AsyncRedisCluster).
AsyncRedisClientT = Union[AsyncRedis, AsyncRedisCluster]
|
|
|
|
|
|
def _get_init_params(cls: Type) -> frozenset:
|
|
"""Extract parameter names from a class's __init__ method."""
|
|
sig = inspect.signature(cls.__init__)
|
|
return frozenset(
|
|
name
|
|
for name, param in sig.parameters.items()
|
|
if name != "self"
|
|
and param.kind
|
|
in (
|
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
inspect.Parameter.KEYWORD_ONLY,
|
|
)
|
|
)
|
|
|
|
|
|
def _filter_kwargs(kwargs: dict, cls: Type) -> dict:
|
|
"""Filter kwargs to only include parameters accepted by the class's __init__."""
|
|
allowed = _get_init_params(cls)
|
|
return {k: v for k, v in kwargs.items() if k in allowed}
|
|
|
|
|
|
# Default number of probes a policy runs per health check.
DEFAULT_HEALTH_CHECK_PROBES: int = 3
# NOTE(review): not consumed in this module; presumably the interval in
# seconds between health-check rounds — confirm with callers.
DEFAULT_HEALTH_CHECK_INTERVAL: int = 5
# Per-check timeout handed to asyncio.wait_for (seconds), covering all probes.
DEFAULT_HEALTH_CHECK_TIMEOUT: int = 3
# Pause between consecutive probes, handed to asyncio.sleep (seconds).
DEFAULT_HEALTH_CHECK_DELAY: float = 0.5
# Lag tolerance sent as availability_lag_tolerance_ms (milliseconds).
DEFAULT_LAG_AWARE_TOLERANCE: int = 5000

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HealthCheck(ABC):
    """
    Interface every health check implements.

    A health check exposes its probing configuration (probe count, pause
    between probes, overall timeout) and a coroutine performing one probe.
    """

    @property
    @abstractmethod
    def health_check_probes(self) -> int:
        """How many probes are executed for this health check."""

    @property
    @abstractmethod
    def health_check_delay(self) -> float:
        """Pause between two consecutive probes."""

    @property
    @abstractmethod
    def health_check_timeout(self) -> float:
        """Timeout for the whole health check operation, all probes included."""

    @abstractmethod
    async def check_health(self, database, hc_client: AsyncRedisClientT) -> bool:
        """
        Run one probe and report the database's health.

        Args:
            database: The database under test.
            hc_client: Redis client (AsyncRedis or AsyncRedisCluster) dedicated
                to health checks; it follows topology changes on its own.

        Returns:
            True when the database is healthy, False otherwise.
        """
|
|
|
|
|
|
class HealthCheckPolicy(ABC):
    """
    Strategy deciding how health checks are executed and aggregated.
    """

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Run *health_checks* against *database* and return its health status."""

    @abstractmethod
    async def _execute(self, health_check: HealthCheck, database) -> bool:
        """
        Run a single health check (all of its probes) against *database*.
        """

    @abstractmethod
    async def get_client(self, database) -> AsyncRedisClientT:
        """
        Return the client used to health-check *database*.
        """

    @abstractmethod
    async def close(self) -> None:
        """Close every health check client owned by the policy."""
|
|
|
|
|
|
class AbstractHealthCheckPolicy(HealthCheckPolicy):
    """
    Abstract health check policy.

    Provides the shared machinery for concrete policies:
    - concurrent execution of health checks, each under its own timeout;
    - a cache holding one health-check client per database.

    Subclasses implement :meth:`_execute`, which decides how the probes of a
    single health check are aggregated.
    """

    def __init__(self):
        # Single client per database, keyed by id() of the database object.
        # NOTE(review): id() values can be reused after an object is garbage
        # collected; entries are only dropped by close().
        self._clients: dict[int, AsyncRedisClientT] = {}

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Execute all health checks concurrently with individual timeouts.

        Each health check runs under ``asyncio.wait_for`` with its own
        ``health_check_timeout``, and all run in parallel. All exception
        handling is centralized here: ``_execute()`` implementations simply
        let exceptions propagate.

        Args:
            health_checks: Health checks to run; an empty list yields True.
            database: The database being checked.

        Returns:
            False if any health check returned a falsy result, True otherwise.

        Raises:
            UnhealthyDatabaseException: If a health check raised, timed out,
                or was cancelled.
        """

        async def execute_with_timeout(health_check: HealthCheck):
            return await asyncio.wait_for(
                self._execute(health_check, database),
                timeout=health_check.health_check_timeout,
            )

        # return_exceptions=True: failures come back as values instead of
        # cancelling the sibling checks.
        results = await asyncio.gather(
            *[execute_with_timeout(hc) for hc in health_checks],
            return_exceptions=True,
        )

        for result in results:
            # gather() may also return BaseException instances such as
            # asyncio.CancelledError (not an Exception subclass on 3.8+).
            # Exception objects are truthy, so checking only for Exception
            # would count a cancelled check as healthy.
            if isinstance(result, BaseException):
                raise UnhealthyDatabaseException("Unhealthy database", database, result)
            if not result:
                # Health check returned False
                return False

        return True

    async def get_client(self, database) -> AsyncRedisClientT:
        """
        Get or create the health check client for *database*.

        One client is created per database and cached. It is built from the
        database client's connection settings, so it follows topology changes
        automatically. For clusters, the client handles node discovery and
        slot mapping internally.

        Raises:
            ValueError: If a cluster client exposes no startup nodes.
            TypeError: If ``database.client`` is of an unsupported type.
        """
        db_id = id(database)
        client = self._clients.get(db_id)
        if client is not None:
            return client

        source = database.client
        # Check for both sync and async standalone Redis clients
        if isinstance(source, (AsyncRedis, SyncRedis)):
            conn_kwargs = source.get_connection_kwargs()
            client = AsyncRedis(**_filter_kwargs(conn_kwargs, AsyncRedis))
        elif isinstance(source, (AsyncRedisCluster, SyncRedisCluster)):
            startup_nodes = source.startup_nodes
            if not startup_nodes:
                raise ValueError(
                    "Cluster client has no nodes - cannot create health check client"
                )
            # Use the first node as the startup node
            first_node = startup_nodes[0]
            nodes_manager = source.nodes_manager
            cluster_kwargs = _filter_kwargs(
                source.get_connection_kwargs(), AsyncRedisCluster
            )
            # Settings taken from the source cluster client are applied last,
            # so they replace any same-named entry from the connection kwargs.
            # Passing them both as explicit keywords and via **kwargs would
            # raise "got multiple values for keyword argument".
            cluster_kwargs.update(
                host=first_node.host,
                port=first_node.port,
                dynamic_startup_nodes=nodes_manager._dynamic_startup_nodes,
                address_remap=nodes_manager.address_remap,
                require_full_coverage=nodes_manager._require_full_coverage,
                retry=source.retry,
            )
            client = AsyncRedisCluster(**cluster_kwargs)
        else:
            raise TypeError(f"Unsupported client type: {type(source)}")

        self._clients[db_id] = client
        return client

    async def close(self) -> None:
        """Close all health check clients; errors while closing are ignored."""
        # Snapshot and clear first, so an interrupted close never leaves
        # half-closed clients in the cache for get_client() to hand out.
        clients = list(self._clients.values())
        self._clients.clear()

        if clients:
            await asyncio.gather(
                *(client.aclose() for client in clients),
                return_exceptions=True,
            )

    @abstractmethod
    async def _execute(self, health_check: HealthCheck, database) -> bool:
        """
        Executes health check against given database.
        """
        pass
|
|
|
|
|
|
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy under which a health check passes only when every probe passes.
    """

    async def _execute(self, health_check: HealthCheck, database) -> bool:
        """
        Run every probe of *health_check* against *database*.

        Returns False as soon as a probe reports a falsy result; exceptions
        raised by a probe propagate. Consecutive probes are separated by
        ``health_check_delay`` seconds, with no pause after the last one. The
        cached per-database client follows topology changes on its own.
        """
        client = await self.get_client(database)
        total_probes = health_check.health_check_probes
        last_index = total_probes - 1

        for index in range(total_probes):
            passed = await health_check.check_health(database, client)
            if not passed:
                return False
            if index != last_index:
                await asyncio.sleep(health_check.health_check_delay)

        return True
|
|
|
|
|
|
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy under which a health check passes when a strict majority of its
    probes pass.

    At most ``(probes - 1) // 2`` probes may fail:
    - 3 probes: 2+ must pass (1 failure allowed)
    - 4 probes: 3+ must pass (1 failure allowed, a tie is unhealthy)
    - 5 probes: 3+ must pass (2 failures allowed)
    """

    async def _execute(self, health_check: HealthCheck, database) -> bool:
        """
        Run the probes of *health_check* against *database*.

        A probe fails by returning a falsy result or by raising. Once the
        failure count exceeds the tolerated maximum, the method stops. It
        returns False if that last failure was a falsy result, or re-raises
        the exception if the last failure was an exception. The cached
        per-database client follows topology changes on its own.
        """
        probes = health_check.health_check_probes
        max_failures = (probes - 1) // 2
        client = await self.get_client(database)
        failures = 0

        for index in range(probes):
            try:
                if not await health_check.check_health(database, client):
                    failures += 1
                    if failures > max_failures:
                        return False
            except Exception as error:
                failures += 1
                if failures > max_failures:
                    raise error

            if index < probes - 1:
                await asyncio.sleep(health_check.health_check_delay)

        return True
|
|
|
|
|
|
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if at least one health check probe is successful.
    """

    async def _execute(self, health_check: HealthCheck, database) -> bool:
        """
        Executes health check against given database.

        Probes run until one reports a truthy result, with
        ``health_check_delay`` seconds between probes. Exceptions raised by a
        probe are remembered rather than propagated immediately. Uses a single
        client that handles topology changes automatically.

        Returns:
            True if any probe succeeded; False if every probe returned a
            falsy result and none raised.

        Raises:
            Exception: The last exception raised by a probe, when no probe
                succeeded and at least one raised.
        """
        probes = health_check.health_check_probes
        last_exception = None
        client = await self.get_client(database)

        for attempt in range(probes):
            try:
                if await health_check.check_health(database, client):
                    # At least one probe succeeded
                    return True
            except Exception as e:
                last_exception = e

            if attempt < probes - 1:
                await asyncio.sleep(health_check.health_check_delay)

        # All probes failed. Compare against None explicitly: an exception
        # object's truthiness is not guaranteed (it may define __bool__ or
        # __len__), and a falsy one must not be silently swallowed.
        if last_exception is not None:
            raise last_exception

        return False
|
|
|
|
|
|
class HealthCheckPolicies(Enum):
    """
    Built-in health check policies; each member's value is the policy class.
    """

    HEALTHY_ALL = HealthyAllPolicy
    HEALTHY_MAJORITY = HealthyMajorityPolicy
    HEALTHY_ANY = HealthyAnyPolicy


# Default policy member: every probe must pass.
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
|
|
|
|
|
|
class AbstractHealthCheck(HealthCheck):
    """
    Base health check that stores the probing configuration.

    Subclasses only need to provide :meth:`check_health`.
    """

    def __init__(
        self,
        health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES,
        health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY,
        health_check_timeout: float = DEFAULT_HEALTH_CHECK_TIMEOUT,
    ):
        """
        Args:
            health_check_probes: Number of probes per health check; at least 1.
            health_check_delay: Pause between consecutive probes (seconds).
            health_check_timeout: Timeout covering all probes of one check
                (seconds).

        Raises:
            ValueError: If ``health_check_probes`` is less than 1.
        """
        if health_check_probes < 1:
            raise ValueError("health_check_probes must be greater than 0")
        self._health_check_timeout = health_check_timeout
        self._health_check_delay = health_check_delay
        self._health_check_probes = health_check_probes

    @property
    def health_check_probes(self) -> int:
        """Configured number of probes."""
        return self._health_check_probes

    @property
    def health_check_delay(self) -> float:
        """Configured pause between probes."""
        return self._health_check_delay

    @property
    def health_check_timeout(self) -> float:
        """Configured timeout for the whole check."""
        return self._health_check_timeout

    @abstractmethod
    async def check_health(self, database, hc_client: AsyncRedisClientT) -> bool:
        """Run one probe; implemented by concrete health checks."""
|
|
|
|
|
|
class PingHealthCheck(AbstractHealthCheck):
    """
    Health check based on PING command.

    A standalone database is pinged through the health-check client. For a
    cluster, every node known to the health-check client is pinged and all
    of them must answer.
    """

    async def check_health(self, database, hc_client: AsyncRedisClientT) -> bool:
        """
        Return True if PING succeeds (on every node, for a cluster).

        An empty cluster node list is reported as unhealthy rather than
        vacuously healthy.
        """
        if isinstance(hc_client, AsyncRedis):
            return await hc_client.execute_command("PING")

        # For a cluster, check that all nodes are healthy.
        nodes = hc_client.get_nodes()
        if not nodes:
            # A freshly created async cluster client may not have discovered
            # its nodes yet; initialize() populates the node cache.
            await hc_client.initialize()
            nodes = hc_client.get_nodes()
        if not nodes:
            # Nothing was pinged, so health cannot be confirmed.
            return False

        for node in nodes:
            # Sync ClusterNode wraps a per-node client in ``redis_connection``.
            # NOTE(review): async ClusterNode appears to run commands itself;
            # fall back to the node when the attribute is absent.
            target = getattr(node, "redis_connection", None)
            if target is None:
                target = node
            if not await target.execute_command("PING"):
                return False

        return True
|
|
|
|
|
|
class LagAwareHealthCheck(AbstractHealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verify via REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        http_timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
        # Health check configuration
        health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES,
        health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY,
        health_check_timeout: float = DEFAULT_HEALTH_CHECK_TIMEOUT,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerance in lag between databases in milliseconds
                (default: DEFAULT_LAG_AWARE_TOLERANCE, i.e. 5000)
            http_timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
            health_check_probes: Number of probes per health check (must be >= 1)
            health_check_delay: Pause between consecutive probes, in seconds
            health_check_timeout: Timeout for the whole check, in seconds

        Raises:
            ValueError: If ``health_check_probes`` is less than 1.
        """
        # Retries are disabled at the HTTP layer; probes are repeated by the
        # health check policy instead.
        self._http_client = AsyncHTTPClientWrapper(
            HttpClient(
                timeout=http_timeout,
                auth_basic=auth_basic,
                retry=Retry(NoBackoff(), retries=0),
                verify_tls=verify_tls,
                ca_file=ca_file,
                ca_path=ca_path,
                ca_data=ca_data,
                client_cert_file=client_cert_file,
                client_key_file=client_key_file,
                client_key_password=client_key_password,
            )
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance
        super().__init__(
            health_check_probes=health_check_probes,
            health_check_delay=health_check_delay,
            health_check_timeout=health_check_timeout,
        )

    async def check_health(self, database, hc_client: AsyncRedisClientT) -> bool:
        """
        Check database health via Redis Enterprise REST API.

        Finds the bdb whose endpoint matches the database host, then queries
        its availability endpoint with the configured lag tolerance.

        Note: The client parameter is not used for this health check as it
        relies on the REST API instead of Redis protocol. The client is
        accepted for interface compatibility.

        Returns:
            True when the availability request succeeds.

        Raises:
            ValueError: If the database has no health check URL, or no bdb
                matches the database host.
            Exception: Whatever the HTTP client raises on a failed request
                (the availability status is checked there).
        """
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, (AsyncRedis, SyncRedis)):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            # Cluster client
            db_host = database.client.get_nodes()[0].host

        base_url = f"{database.health_check_url}:{self._rest_api_port}"
        # NOTE(review): base_url is stored on an HTTP client owned by this
        # health check instance and then used across two awaited requests. If
        # one instance serves several databases whose checks overlap, another
        # check can change base_url in between. Confirm how instances are
        # shared; passing absolute URLs would avoid the shared state.
        self._http_client.client.base_url = base_url

        # Find bdb matching to the current database host
        bdbs = await self._http_client.get("/v1/bdbs")
        matching_bdb = self._find_matching_bdb(bdbs, db_host)

        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        await self._http_client.get(url, expect_json=False)

        # Status checked in an http client, otherwise HttpError will be raised
        return True

    @staticmethod
    def _find_matching_bdb(bdbs, db_host):
        """
        Return the first bdb having an endpoint that matches ``db_host``, else None.

        An endpoint matches when its ``dns_name`` equals ``db_host``, or when
        ``db_host`` is one of its ``addr`` entries (host configured as an IP).
        Missing ``endpoints`` / ``dns_name`` / ``addr`` fields are treated as
        empty, so one incomplete bdb does not abort the search.
        """
        for bdb in bdbs:
            for endpoint in bdb.get("endpoints") or ():
                dns_name = endpoint.get("dns_name")
                if dns_name is not None and dns_name == db_host:
                    return bdb
                # In case the host was set as a public IP
                if any(addr == db_host for addr in endpoint.get("addr") or ()):
                    return bdb
        return None
|