"""
SmarterMiddlewareMixin: A mixin for middleware classes with helper functions.
"""
import ipaddress
import re
from collections.abc import Awaitable
from typing import Optional
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.http import HttpRequest
from django.http.response import HttpResponseBase
from smarter.common.conf import smarter_settings
from smarter.common.utils import (
is_authenticated_request,
)
from .helper_mixin import SmarterHelperMixin
from .logger import logger
# Matches string representations of unittest.mock objects, e.g. "<MagicMock id=...>",
# "<Mock ...>", or a dotted path such as "unittest.mock.MagicMock". It is used when an
# IP value fails to parse, to tell a test double apart from a genuinely malformed address.
# In a raw string a single backslash escapes the dot; a doubled backslash would instead
# demand a literal "\" in the input and the dotted alternatives would never match.
MOCK_REGEX = re.compile(r"<MagicMock|<Mock|mock\.MagicMock|mock\.Mock", re.IGNORECASE)
class SmarterMiddlewareMixin(SmarterHelperMixin):
    """
    A mixin for middleware classes with helper functions.

    The initialization is a blatant copy of the Django 6.x CORS middleware.
    This is our base case for working with ASGI and sync middlewares.

    This mixin provides utilities for extracting client IP addresses,
    checking authentication indicators, and other middleware-related helpers.
    Inherits from :class:`SmarterHelperMixin`.

    Adds a ``smarter_process_id`` attribute to the request which is helpful for
    caching and logging to correlate logs across the same request flow, especially
    in async contexts where thread-local storage is not reliable.
    """

    # Identifier stamped onto requests; assigned in __init__ from id(self).
    smarter_process_id: int

    def __init__(self, get_response, *args, **kwargs):
        """
        Store ``get_response`` and detect whether the middleware chain is async.

        :param get_response: The next callable in the middleware chain. If it
            is a coroutine function, this instance is marked as a coroutine
            function too and requests are routed through :meth:`__acall__`.
        """
        super().__init__(*args, **kwargs)
        self.async_mode = iscoroutinefunction(get_response)
        if self.async_mode:
            # Mark the class as async-capable, but do the actual switch
            # inside __call__ to avoid swapping out dunder methods
            markcoroutinefunction(self)
        self.get_response = get_response
        self.smarter_process_id = id(self)

    def __call__(self, request: HttpRequest) -> HttpResponseBase | Awaitable[HttpResponseBase]:
        """
        Handle a request, dispatching to the async path when in async mode.

        :param request: The incoming request.
        :return: The response in sync mode, or the awaitable produced by
            :meth:`__acall__` in async mode.
        """
        if self.async_mode:
            return self.__acall__(request)
        self.set_smarter_process_id(request)
        result = self.get_response(request)
        # Sync chain: get_response is expected to hand back a response object
        # directly (this also narrows the type for static checkers).
        assert isinstance(result, HttpResponseBase)
        response = result
        return response

    async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
        """
        Async counterpart of :meth:`__call__`: tag the request, then await the
        result of ``get_response``.

        :param request: The incoming request.
        :return: The awaited response.
        """
        self.set_smarter_process_id(request)
        result = self.get_response(request)
        # Async chain: get_response is expected to return an awaitable, not a
        # finished response.
        assert not isinstance(result, HttpResponseBase)
        response = await result
        return response

    def set_smarter_process_id(self, request) -> None:
        """
        Set smarter_process_id on request if not already set.

        This has to consider the pipeline of middlewares. Each middleware class
        sets its own unique smarter_process_id, but if a previous middleware has
        already set the request object's smarter_process_id then we should not
        overwrite it. This allows all middlewares in the pipeline to share the
        same process ID for logging and caching purposes.

        :param request: The request object to tag.
        """
        smarter_process_id = getattr(request, "smarter_process_id", None)
        if not smarter_process_id:
            setattr(request, "smarter_process_id", self.smarter_process_id)

    def get_client_ip(self, request) -> Optional[str]:
        """
        Get client IP address from request.

        This method attempts to determine the original client IP address,
        accounting for proxies, load balancers, and CDNs. It checks common
        headers set by proxies and falls back to ``REMOTE_ADDR``. Candidates
        are tried in this order, and the first one that is a valid,
        non-private address wins:

        1. leftmost entry of ``X-Forwarded-For``
        2. ``X-Real-IP``
        3. ``CF-Connecting-IP``
        4. ``REMOTE_ADDR`` (treated as ``127.0.0.1`` when absent)

        Notes
        -----
        - In AWS CLB → Kubernetes Nginx setups, the client IP flow is:
          Client → CLB → Nginx Ingress → Django.
        - CLB adds ``X-Forwarded-For`` with the original client IP.
        - Nginx may add ``X-Real-IP`` or modify ``X-Forwarded-For``.
        - Django sees ``REMOTE_ADDR`` as the Nginx IP (not the client IP).
        - If using Cloudflare, it adds the ``CF-Connecting-IP`` header.
        - Always validate IPs to avoid trusting spoofed headers.

        :param request: The Django request object.
        :type request: HttpRequest
        :return: The detected client IP address, or None if not found.
        :rtype: Optional[str]
        """
        meta = request.META

        # First check X-Forwarded-For (most reliable for CLB)
        # set by Nginx ingress controller and Traefik.
        # Contains the original client IP and any proxy IPs.
        forwarded_for = meta.get("HTTP_X_FORWARDED_FOR")
        if forwarded_for:
            # X-Forwarded-For format: "client_ip, proxy1_ip, proxy2_ip"
            # The leftmost IP is the original client IP
            client_ip = forwarded_for.split(",")[0].strip()
            # Validate it's not a private IP (load balancer/proxy IP)
            if not self._is_private_ip(client_ip):
                logger.debug(
                    "%s.get_client_ip() - Using X-Forwarded-For: %s",
                    self.formatted_class_name,
                    client_ip,
                )
                return client_ip

        # Check X-Real-IP (set by Nginx ingress controller and Traefik)
        real_ip = meta.get("HTTP_X_REAL_IP")
        if real_ip:
            real_ip = real_ip.strip()
            if not self._is_private_ip(real_ip):
                logger.debug(
                    "%s.get_client_ip() - Using X-Real-IP: %s",
                    self.formatted_class_name,
                    real_ip,
                )
                return real_ip

        # Check Cloudflare connecting IP if using Cloudflare
        cf_ip = meta.get("HTTP_CF_CONNECTING_IP")
        if cf_ip:
            cf_ip = cf_ip.strip()
            if not self._is_private_ip(cf_ip):
                logger.debug(
                    "%s.get_client_ip() - Using CF-Connecting-IP: %s",
                    self.formatted_class_name,
                    cf_ip,
                )
                return cf_ip

        # Fallback to REMOTE_ADDR (will be load balancer IP in AWS)
        remote_addr = meta.get("REMOTE_ADDR", "127.0.0.1")
        logger.debug(
            "%s.get_client_ip() - Falling back to REMOTE_ADDR: %s",
            self.formatted_class_name,
            remote_addr,
        )
        if not self._is_private_ip(remote_addr):
            logger.debug(
                "%s.get_client_ip() - Using REMOTE_ADDR: %s",
                self.formatted_class_name,
                remote_addr,
            )
            return remote_addr

        # Nothing usable was found. Stay quiet for amnesty URLs and local
        # environments; otherwise emit a warning for diagnosis.
        if request.path.replace("/", "") not in self.amnesty_urls and not smarter_settings.environment_is_local:
            # Log only the IP-related entries: the full META mapping can carry
            # cookies, Authorization headers and other secrets that must not
            # end up in log files.
            ip_meta = {
                key: meta.get(key)
                for key in (
                    "HTTP_X_FORWARDED_FOR",
                    "HTTP_X_REAL_IP",
                    "HTTP_CF_CONNECTING_IP",
                    "REMOTE_ADDR",
                )
            }
            logger.warning(
                "%s.get_client_ip() - Could not determine client IP: %s, Meta: %s",
                self.formatted_class_name,
                self.smarter_build_absolute_uri(request=request),
                ip_meta,
            )
        return None

    def _is_private_ip(self, ip) -> bool:
        """
        Check if IP is in private/internal ranges.

        :param ip: Candidate address (normally a string).
        :return: True if the address is private, loopback or link-local, and
            also True if it cannot be parsed as an IP address at all, so that
            callers never accept an invalid value as a client IP.
        """
        try:
            ip_obj = ipaddress.ip_address(ip)
            return ip_obj.is_private or ip_obj.is_loopback or ip_obj.is_link_local
        except ValueError as e:
            # Distinguish mock objects (seen in unit tests) from genuinely
            # malformed addresses so the log message is not misleading.
            ip_str = str(ip)
            if MOCK_REGEX.search(ip_str) or "Mock" in getattr(ip, "__class__", type(ip)).__name__:
                logger.warning(
                    "%s._is_private_ip() - Mock object detected as IP: %s", self.formatted_class_name, ip_str
                )
            else:
                logger.warning(
                    "%s._is_private_ip() - Invalid IP address: %s, error: %s", self.formatted_class_name, ip_str, e
                )
            return True

    def has_auth_indicators(self, request):
        """
        Check if request has authentication indicators (cookies, headers, etc.).

        This method inspects the request for common authentication signals,
        such as session cookies, CSRF tokens, authorization headers, API keys,
        or Django's built-in authentication.

        :param request: The Django request object.
        :type request: HttpRequest
        :return: True if authentication indicators are present, False otherwise.
        :rtype: bool
        """
        cookies = request.COOKIES
        # Django session cookie, or CSRF token (indicates active session)
        if cookies.get("sessionid") or cookies.get("csrftoken"):
            return True

        meta = request.META
        # Authorization header or API key header
        if meta.get("HTTP_AUTHORIZATION") or meta.get("HTTP_X_API_KEY"):
            return True

        # Finally, check if user is authenticated (Django built-in)
        return bool(is_authenticated_request(request))
# Public API of this module: only the mixin class is exported.
__all__ = ["SmarterMiddlewareMixin"]