eptm_dashboard/.venv/lib/python3.12/site-packages/granian/utils/proxies.py

107 lines
3.5 KiB
Python

import ipaddress
from functools import wraps as _wraps
class _Forwarders:
def __init__(self, trusted_hosts: list[str] | str) -> None:
self.always_trust: bool = trusted_hosts in ('*', ['*'])
self.literals: set[str] = set()
self.hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
self.networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
if self.always_trust:
return
if isinstance(trusted_hosts, str):
trusted_hosts = _parse_raw_hosts(trusted_hosts)
for host in trusted_hosts:
try:
if '/' in host:
self.networks.add(ipaddress.ip_network(host))
continue
self.hosts.add(ipaddress.ip_address(host))
except ValueError:
self.literals.add(host)
def __contains__(self, host: str | None) -> bool:
if self.always_trust:
return True
if not host:
return False
try:
ip = ipaddress.ip_address(host)
if ip in self.hosts:
return True
return any(ip in net for net in self.networks)
except ValueError:
return host in self.literals
def get_client_host(self, x_forwarded_for: str) -> str:
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
if self.always_trust:
return x_forwarded_for_hosts[0]
for host in reversed(x_forwarded_for_hosts):
if host not in self:
return host
return x_forwarded_for_hosts[0]
def _parse_raw_hosts(value: str) -> list[str]:
    """Split a comma-separated host list, trimming whitespace around each item.

    Empty items are preserved, e.g. ``"a,,b"`` -> ``['a', '', 'b']`` and
    ``""`` -> ``['']``.
    """
    pieces = value.split(',')
    return list(map(str.strip, pieces))
def wrap_asgi_with_proxy_headers(app, trusted_hosts: list[str] | str = '127.0.0.1'):
    """Wrap an ASGI *app* so proxy headers from trusted peers are honoured.

    For non-lifespan scopes whose ``client`` host is trusted:

    * ``X-Forwarded-Proto`` overrides ``scope['scheme']``. The value is trimmed,
      case-folded and normalised to the scope's scheme family: ``ws``/``wss``
      for websocket scopes and ``http``/``https`` for HTTP scopes.
    * ``X-Forwarded-For`` overrides ``scope['client']`` with the originating
      client host. The port is reported as ``0``.

    Lifespan scopes are passed through untouched. The scope is mutated in place.
    """
    forwarders = _Forwarders(trusted_hosts)

    @_wraps(app)
    def wrapped(scope, receive, send):
        if scope['type'] == 'lifespan':
            return app(scope, receive, send)
        client_addr = scope.get('client')
        client_host = client_addr[0] if client_addr else None
        if client_host in forwarders:
            headers = dict(scope['headers'])
            # URI schemes are case-insensitive (RFC 3986 section 3.1), so accept e.g. "HTTPS".
            x_forwarded_proto = headers.get(b'x-forwarded-proto', b'').decode('latin1').strip().lower()
            if x_forwarded_proto in {'http', 'https', 'ws', 'wss'}:
                if scope['type'] == 'websocket':
                    scope['scheme'] = x_forwarded_proto.replace('http', 'ws')
                else:
                    # ASGI HTTP scopes only allow "http"/"https": map ws->http, wss->https.
                    scope['scheme'] = x_forwarded_proto.replace('ws', 'http')
            if x_forwarded_for := headers.get(b'x-forwarded-for', b'').decode('latin1'):
                if host := forwarders.get_client_host(x_forwarded_for):
                    scope['client'] = (host, 0)
        return app(scope, receive, send)

    return wrapped
def wrap_wsgi_with_proxy_headers(app, trusted_hosts: list[str] | str = '127.0.0.1'):
    """Wrap a WSGI *app* so proxy headers from trusted peers are honoured.

    When the environ's ``REMOTE_ADDR`` is trusted:

    * ``X-Forwarded-Proto`` overrides ``wsgi.url_scheme``. The value is trimmed
      and case-folded. ``ws``/``wss`` are mapped to ``http``/``https`` because
      WSGI only knows HTTP schemes.
    * ``X-Forwarded-For`` overrides ``REMOTE_ADDR`` with the originating client.

    The environ dict is mutated in place.
    """
    forwarders = _Forwarders(trusted_hosts)

    @_wraps(app)
    def wrapped(scope, resp):
        client_host = scope.get('REMOTE_ADDR')
        if client_host in forwarders:
            # Normalise the same way the ASGI wrapper does: trim and case-fold.
            x_forwarded_proto = (scope.get('HTTP_X_FORWARDED_PROTO') or '').strip().lower()
            if x_forwarded_proto in {'http', 'https', 'ws', 'wss'}:
                scope['wsgi.url_scheme'] = x_forwarded_proto.replace('ws', 'http')
            if x_forwarded_for := scope.get('HTTP_X_FORWARDED_FOR'):
                if host := forwarders.get_client_host(x_forwarded_for):
                    scope['REMOTE_ADDR'] = host
        return app(scope, resp)

    return wrapped