Source code for smarter.apps.llm_client.middleware.security

"""This module is used to suppress DisallowedHost exception and return HttpResponseBadRequest instead."""

import fnmatch
from typing import Optional
from urllib.parse import urlparse

from django.http import HttpRequest
from django.middleware.security import SecurityMiddleware as DjangoSecurityMiddleware

from smarter.common.conf import smarter_settings
from smarter.common.mixins import SmarterHelperMixin
from smarter.lib import logging
from smarter.lib.django import waffle
from smarter.lib.django.http.shortcuts import (
    SmarterHttpResponseBadRequest,
    SmarterHttpResponseServerError,
)
from smarter.lib.django.validators import SmarterValidator
from smarter.lib.django.waffle import SmarterWaffleSwitches

from ..models import LLMClient, get_cached_llm_client_by_request

logger = logging.getSmarterLogger(__name__, any_switches=[SmarterWaffleSwitches.MIDDLEWARE_LOGGING])
logger.debug(
    "%s is %s",
    logging.formatted_text(__name__ + ".SmarterSecurityMiddleware"),
    SmarterHelperMixin().formatted_state_ready,
)


[docs] class SmarterSecurityMiddleware(DjangoSecurityMiddleware, SmarterHelperMixin): """ This middleware overrides Django’s built-in ``SecurityMiddleware`` to provide custom host validation logic for the Smarter platform. **Key Features:** - **Custom Host Validation:** Instead of relying solely on Django’s ``ALLOWED_HOSTS``, this middleware introduces ``smarter_settings.allowed_hosts``. It checks incoming requests against both the traditional allowed hosts and a dynamic list of domains associated with deployed LLMClients. - **LLMClient Domain Support:** If the request’s host matches a domain for a deployed LLMClient, the request is allowed to pass through. This enables flexible multi-tenant deployments where each LLMClient can have its own domain. - **Friendly Error Handling:** The middleware suppresses Django’s default ``DisallowedHost`` exception. Instead, it returns a ``HttpResponseBadRequest`` (400) response, which is not logged and is more user-friendly for clients. - **Health Check Short-Circuiting:** Requests from internal IP addresses or for health/readiness endpoints are allowed to pass through without further validation. This ensures that infrastructure health checks do not get blocked by host validation. - **Logging:** Uses a custom logger that respects feature flags (waffle switches) for granular control over middleware and llm_client logging. **Request Validation Steps:** 1. **Internal IPs:** Requests from internal IP addresses (e.g., load balancer health checks) are allowed. 2. **Local Hosts:** Requests from local hosts (e.g., ``localhost``, ``127.0.0.1``) are allowed. 3. **Health/Readiness URLs:** Requests to health or readiness endpoints are allowed. 4. **Allowed Hosts:** Requests matching any pattern in ``smarter_settings.allowed_hosts`` are allowed. 5. **LLMClient Domains:** Requests where the host matches a deployed LLMClient’s domain are allowed. 6. **Fallback:** All other requests are rejected with a ``400 Bad Request`` response. **Example Usage:** .. code-block:: python MIDDLEWARE = [ ... 'smarter.apps.llm_client.middleware.security.SmarterSecurityMiddleware', ... ] """ def __call__(self, request): if self.async_mode: return self.__acall__(request) if self.deserves_amnesty(request.path): return self.get_response(request) logger.debug("%s.__call__(): Request received: %s %s", self.formatted_class_name, request.method, request.path) return super().__call__(request) async def __acall__(self, request): """ Async version of __call__ that is swapped in when an async request. is running. """ logger.debug("%s.__acall__(): Request received: %s %s", self.formatted_class_name, request.method, request.path) return await super().__acall__(request) @property def formatted_class_name(self) -> str: class_name = f"{__name__}.{self.__class__.__name__}[{id(self)}]" return self.formatted_text(class_name) def process_request(self, request: HttpRequest): if not waffle.switch_is_active(SmarterWaffleSwitches.ENABLE_MIDDLEWARE_SECURITY): return None logger_prefix = logging.formatted_text(__name__ + "." + self.__class__.__name__) logger.debug( "%s.process_request() called for %s", logger_prefix, self.smarter_build_absolute_uri(request), ) # 1.) If the request is from an internal ip address, allow it to pass through # these typically originate from health checks from load balancers. # --------------------------------------------------------------------- # Short-circuit for health checks if request.path.replace("/", "") in self.amnesty_urls: logger.debug( "%s %s found in amnesty_urls: %s", logger_prefix, self.smarter_build_absolute_uri(request), self.amnesty_urls, ) return None host = request.get_host() if not host: return SmarterHttpResponseServerError( request=request, error_message="Internal error (500) - could not parse request.", ) # Short-circuit for any requests born from internal IP address hosts # This is unlikely, but not impossible. if any(host.startswith(prefix) for prefix in smarter_settings.internal_ip_prefixes): logger.debug( "%s %s identified as an internal IP address, exiting.", logger_prefix, self.smarter_build_absolute_uri(request), ) return None url = self.smarter_build_absolute_uri(request) # 2.) If the request is from a local host, allow it to pass through # --------------------------------------------------------------------- host_no_port = host.split(":")[0] base_host = host_no_port.split(".")[-1] if base_host in [h.rsplit(".", maxsplit=1)[-1] for h in SmarterValidator.LOCAL_HOSTS]: logger.debug( "%s %s base host matched in SmarterValidator.LOCAL_HOSTS: %s", logger_prefix, host, SmarterValidator.LOCAL_HOSTS, ) return None if host in SmarterValidator.LOCAL_HOSTS: logger.debug( "%s %s found in SmarterValidator.LOCAL_HOSTS: %s", logger_prefix, host, SmarterValidator.LOCAL_HOSTS, ) return None parsed_url = urlparse(url) # 3.) readiness and liveness checks # --------------------------------------------------------------------- path_parts = list(filter(None, parsed_url.path.split("/"))) # type: ignore[assignment] # if the entire path is healthz or readiness then we don't need to check if len(path_parts) == 1 and path_parts[0] in self.amnesty_urls: logger.debug( "%s %s found in amnesty_urls: %s", logger_prefix, host, path_parts, ) return None # 4.) Acme challenge requests should be allowed through # http://platform.example.com/.well-known/acme-challenge/RYdbP7-MUXbQRZI1CZj-KKySBkHwHze8z04cjyN18Bk # http://stackademy-sql.3141-5926-5359.api.example.com/.well-known/acme-challenge/QrRzO7QE7y6DhV8UqhfdD4_OoQ3Yh6XLR1qbJCRGcls # --------------------------------------------------------------------- if ".well-known/acme-challenge" in str(parsed_url.path): logger.debug( "%s %s identified as an ACME challenge request, exiting.", logger_prefix, url, ) return None # 5.) If the host is in the list of allowed hosts for # our environment then allow it to pass through # --------------------------------------------------------------------- for allowed_host in smarter_settings.allowed_hosts: if fnmatch.fnmatch(host, allowed_host): logger.debug( "%s %s matched with smarter_settings.allowed_hosts: %s", logger_prefix, host, allowed_host, ) return None # 6.) If the host is a domain for a deployed LLMClient, allow it to pass through # FIX NOTE: this is ham fisted and should be refactored. we shouldn't need # to instantiate a LLMClientHelper object just to check if the host is a domain # for a deployed LLMClient. # --------------------------------------------------------------------- logger.debug("%s instantiating LLMClientHelper() for url: %s", logger_prefix, url) llm_client: Optional[LLMClient] = get_cached_llm_client_by_request(request=request) if llm_client is not None: logger.info("%s LLMClientHelper() verified that %s is an llm_client.", logger_prefix, url) return None # --------------------------------------------------------------------- # End of the road. Reject the request with a 400 Bad Request response. # --------------------------------------------------------------------- logger.error("%s %s failed security tests.", logger_prefix, url) return SmarterHttpResponseBadRequest( request=request, error_message="SecurityMiddleware() Bad Request (400) - Invalid Hostname." )