107 lines
3.5 KiB
Python
107 lines
3.5 KiB
Python
import ipaddress
|
|
from functools import wraps as _wraps
|
|
|
|
|
|
class _Forwarders:
|
|
def __init__(self, trusted_hosts: list[str] | str) -> None:
|
|
self.always_trust: bool = trusted_hosts in ('*', ['*'])
|
|
self.literals: set[str] = set()
|
|
self.hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
|
self.networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
|
|
|
|
if self.always_trust:
|
|
return
|
|
|
|
if isinstance(trusted_hosts, str):
|
|
trusted_hosts = _parse_raw_hosts(trusted_hosts)
|
|
|
|
for host in trusted_hosts:
|
|
try:
|
|
if '/' in host:
|
|
self.networks.add(ipaddress.ip_network(host))
|
|
continue
|
|
self.hosts.add(ipaddress.ip_address(host))
|
|
except ValueError:
|
|
self.literals.add(host)
|
|
|
|
def __contains__(self, host: str | None) -> bool:
|
|
if self.always_trust:
|
|
return True
|
|
|
|
if not host:
|
|
return False
|
|
|
|
try:
|
|
ip = ipaddress.ip_address(host)
|
|
if ip in self.hosts:
|
|
return True
|
|
return any(ip in net for net in self.networks)
|
|
except ValueError:
|
|
return host in self.literals
|
|
|
|
def get_client_host(self, x_forwarded_for: str) -> str:
|
|
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
|
|
|
|
if self.always_trust:
|
|
return x_forwarded_for_hosts[0]
|
|
|
|
for host in reversed(x_forwarded_for_hosts):
|
|
if host not in self:
|
|
return host
|
|
|
|
return x_forwarded_for_hosts[0]
|
|
|
|
|
|
def _parse_raw_hosts(value: str) -> list[str]:
    """Split a comma-separated host list and trim whitespace around each entry.

    Empty entries are kept as ``''``; for example ``'a,,b'`` gives
    ``['a', '', 'b']``.
    """
    pieces = value.split(',')
    return list(map(str.strip, pieces))
|
|
|
|
|
|
def wrap_asgi_with_proxy_headers(app, trusted_hosts: list[str] | str = '127.0.0.1'):
    """Wrap an ASGI *app* so it honours ``X-Forwarded-Proto`` and ``X-Forwarded-For``.

    The headers are applied only when the direct peer (``scope['client']``)
    is trusted. Otherwise the scope reaches *app* untouched. ``lifespan``
    scopes are always passed straight through.

    Args:
        app: the ASGI application to wrap.
        trusted_hosts: trusted proxy entries as accepted by ``_Forwarders``
            (a comma-separated string or a list; ``'*'`` trusts every peer).

    Returns:
        An ``(scope, receive, send)`` callable. It updates ``scope['scheme']``
        and ``scope['client']`` in place (the client port becomes ``0``) and
        returns ``app(scope, receive, send)``.
    """
    forwarders = _Forwarders(trusted_hosts)
    # Scheme to store for each accepted X-Forwarded-Proto value. HTTP scopes
    # take 'http'/'https' and websocket scopes take 'ws'/'wss' (ASGI spec).
    # The forwarded value is translated to the family matching the scope type.
    http_schemes = {'http': 'http', 'https': 'https', 'ws': 'http', 'wss': 'https'}
    websocket_schemes = {'http': 'ws', 'https': 'wss', 'ws': 'ws', 'wss': 'wss'}

    @_wraps(app)
    def wrapped(scope, receive, send):
        if scope['type'] == 'lifespan':
            return app(scope, receive, send)

        client_addr = scope.get('client')
        client_host = client_addr[0] if client_addr else None

        if client_host in forwarders:
            x_forwarded_proto = ''
            x_forwarded_for_lines = []
            for name, value in scope['headers']:
                if name == b'x-forwarded-proto':
                    # Last occurrence wins.
                    x_forwarded_proto = value.decode('latin1')
                elif name == b'x-forwarded-for':
                    x_forwarded_for_lines.append(value.decode('latin1'))

            # URI schemes are case-insensitive, so normalise before lookup.
            proto = x_forwarded_proto.strip().lower()
            schemes = websocket_schemes if scope['type'] == 'websocket' else http_schemes
            if scheme := schemes.get(proto):
                scope['scheme'] = scheme

            # Repeated field lines form one comma-separated list, in order
            # (RFC 9110 section 5.3). Join them rather than keeping only the
            # last line.
            if x_forwarded_for := ', '.join(x_forwarded_for_lines):
                if host := forwarders.get_client_host(x_forwarded_for):
                    scope['client'] = (host, 0)

        return app(scope, receive, send)

    return wrapped
|
|
|
|
|
|
def wrap_wsgi_with_proxy_headers(app, trusted_hosts: list[str] | str = '127.0.0.1'):
    """Wrap a WSGI *app* so it honours ``X-Forwarded-Proto`` and ``X-Forwarded-For``.

    The headers are applied only when ``REMOTE_ADDR`` is trusted. Otherwise
    the environ reaches *app* untouched.

    Args:
        app: the WSGI application to wrap.
        trusted_hosts: trusted proxy entries as accepted by ``_Forwarders``
            (a comma-separated string or a list; ``'*'`` trusts every peer).

    Returns:
        An ``(environ, start_response)`` callable. It updates
        ``wsgi.url_scheme`` and ``REMOTE_ADDR`` in place and returns
        ``app(environ, start_response)``.
    """
    forwarders = _Forwarders(trusted_hosts)

    @_wraps(app)
    def wrapped(scope, resp):
        client_host = scope.get('REMOTE_ADDR')

        if client_host in forwarders:
            # Normalise like the ASGI wrapper does, so ' https' or 'HTTPS'
            # still matches (URI schemes are case-insensitive).
            x_forwarded_proto = (scope.get('HTTP_X_FORWARDED_PROTO') or '').strip().lower()
            if x_forwarded_proto in {'http', 'https'}:
                scope['wsgi.url_scheme'] = x_forwarded_proto

            if x_forwarded_for := scope.get('HTTP_X_FORWARDED_FOR'):
                if host := forwarders.get_client_host(x_forwarded_for):
                    scope['REMOTE_ADDR'] = host

        return app(scope, resp)

    return wrapped
|