# Source code for smarter.lib.django.middleware.cors

"""
This module contains the middleware for handling CORS headers for the application.
It adds chatbot urls to the CORS_ALLOWED_ORIGINS list at run-time.
"""

import logging
import re
from collections.abc import Awaitable
from functools import cached_property, lru_cache
from typing import Optional, Pattern, Sequence, Union
from urllib.parse import SplitResult, urlsplit

from corsheaders.conf import conf
from corsheaders.middleware import CorsMiddleware
from django.http import HttpRequest
from django.http.response import HttpResponseBase

from smarter.apps.chatbot.models import ChatBot, get_cached_chatbot_by_request
from smarter.common.conf import smarter_settings
from smarter.common.const import SMARTER_LOCAL_PORT, SmarterEnvironments
from smarter.common.helpers.console_helpers import formatted_text
from smarter.common.mixins import SmarterHelperMixin
from smarter.lib.django import waffle
from smarter.lib.django.http.shortcuts import SmarterHttpResponseServerError
from smarter.lib.django.waffle import SmarterWaffleSwitches
from smarter.lib.logging import WaffleSwitchedLoggerWrapper


# pylint: disable=W0613
def should_log(level):
    """Report whether middleware logging is currently enabled.

    The ``level`` argument is accepted but not consulted: the decision rests
    solely on the ``MIDDLEWARE_LOGGING`` waffle switch.
    """
    switch = SmarterWaffleSwitches.MIDDLEWARE_LOGGING
    return waffle.switch_is_active(switch)


# Standard-library logger named after this module.
base_logger = logging.getLogger(__name__)
# Module logger used throughout this file: the base logger paired with the
# should_log predicate. Presumably the wrapper consults should_log before
# emitting a record -- confirm against WaffleSwitchedLoggerWrapper.
logger = WaffleSwitchedLoggerWrapper(base_logger, should_log)


class SmarterCorsMiddleware(CorsMiddleware, SmarterHelperMixin):
    """
    Middleware for handling Cross-Origin Resource Sharing (CORS) headers in the application.

    This middleware extends the default CORS handling to dynamically add chatbot URLs
    to the allowed origins at runtime. It ensures that requests from valid chatbot
    origins are permitted by updating the CORS allowed origins list based on the
    current request context.

    The middleware also provides additional logic to handle internal IP addresses,
    health check endpoints, and logging for debugging and auditing purposes.

    :cvar _url: The parsed URL (as a :class:`urllib.parse.SplitResult`) for the current request, or None.
    :vartype _url: Optional[SplitResult]
    :cvar _chatbot: The chatbot instance associated with the current request, or None.
    :vartype _chatbot: Optional[ChatBot]
    :cvar request: The current Django HTTP request object, or None.
    :vartype request: Optional[HttpRequest]

    **Key Features**

    - Dynamically adds chatbot URLs to the CORS allowed origins list.
    - Handles requests from internal IP addresses and health check endpoints.
    - Provides detailed logging for CORS-related events and decisions.
    - Integrates with Django and the `django-cors-headers` package.

    .. note::

        - The chatbot URL is only added to the allowed origins if a chatbot is
          associated with the request.
        - Internal requests and health checks are short-circuited for efficiency.
        - Logging is controlled via a waffle switch and the application's log level.
        - Per-request state (``request``, ``_chatbot``, ``_url``) is stored on the
          middleware instance and reset at the start of every call. Nothing derived
          from that state is memoized across requests.

    **Example**

    To enable this middleware, add it to your Django project's middleware settings::

        MIDDLEWARE = [
            ...
            'smarter.lib.django.middleware.cors.SmarterCorsMiddleware',
            ...
        ]

    :param request: The incoming HTTP request object.
    :type request: django.http.HttpRequest
    :returns: The HTTP response object, potentially with CORS headers added.
    :rtype: django.http.response.HttpResponseBase or Awaitable[HttpResponseBase]
    """

    _url: Optional[SplitResult] = None
    _chatbot: Optional[ChatBot] = None
    request: Optional[HttpRequest] = None

    @property
    def formatted_class_name(self) -> str:
        """Return the formatted class name for logging purposes."""
        return formatted_text(f"{__name__}.{SmarterCorsMiddleware.__name__}")

    def __call__(self, request: HttpRequest) -> Union[HttpResponseBase, Awaitable[HttpResponseBase]]:
        """
        Process a request, recording it for the chatbot-aware CORS hooks.

        Delegates to the parent middleware in every case except when the host
        cannot be determined, in which case a server-error response is returned.
        ``self.request`` is only populated on the full (non short-circuited)
        path, so the chatbot lookup only happens for those requests.

        :param request: The incoming HTTP request.
        :returns: The response (or awaitable response) produced by the parent
            middleware, or a server-error response when the host is missing.
        """
        # Always start from a clean slate. The overridden hooks below still run
        # on the short-circuit paths, and must not see the previous request's
        # chatbot/url/request.
        # NOTE(review): this state lives on a shared instance; concurrent
        # requests can still interleave -- confirm the deployment model.
        self._url = None
        self._chatbot = None
        self.request = None

        if not waffle.switch_is_active(SmarterWaffleSwitches.ENABLE_MIDDLEWARE_CORS):
            return super().__call__(request)

        host = request.get_host()
        if not host:
            return SmarterHttpResponseServerError(
                request=request,
                error_message="Internal error (500) - could not parse request.",
            )

        # Short-circuit for health checks
        if request.path.replace("/", "") in self.amnesty_urls:
            return super().__call__(request)

        # Short-circuit for any requests born from internal IP address hosts
        # This is unlikely, but not impossible.
        if any(host.startswith(prefix) for prefix in smarter_settings.internal_ip_prefixes):
            logger.debug(
                "%s %s identified as an internal IP address, exiting.",
                self.formatted_class_name,
                self.smarter_build_absolute_uri(request),
            )
            return super().__call__(request)

        url = self.smarter_build_absolute_uri(request)
        logger.debug("%s.__call__() - url=%s", self.formatted_class_name, url)
        self.request = request

        return super().__call__(request)  # Ensure the response is returned

    @property
    def chatbot(self) -> Optional[ChatBot]:
        """Return the chatbot resolved for the current request, or None."""
        return self._chatbot

    @property
    def url(self) -> Optional[SplitResult]:
        """Return the URL stored for the current request, or None if unset."""
        if isinstance(self._url, SplitResult):
            return self._url
        return None

    @url.setter
    def url(self, url: Optional[SplitResult] = None) -> None:
        """
        Record ``url`` for the current request and resolve its chatbot.

        If ``url`` is already listed in the configured allowed origins, nothing
        is stored and no chatbot lookup is made. Otherwise the chatbot is looked
        up from ``self.request`` (when one is set); if a chatbot is found, its
        own URL replaces ``url``.

        :param url: The parsed origin URL, or None.
        """
        url_string = url.geturl() if isinstance(url, SplitResult) else None
        url_label = url_string if url_string is not None else "(Missing URL)"

        if url_string in conf.CORS_ALLOWED_ORIGINS:
            logger.debug("%s url: %s is an allowed origin", self.formatted_class_name, url_label)
            return

        logger.debug("%s instantiating ChatBotHelper() for url: %s", self.formatted_class_name, url_label)
        if self.request is not None:
            self._chatbot = get_cached_chatbot_by_request(request=self.request)

        # If the chatbot is found, update the chatbot url
        # which ensures that we'll only be working with the
        # base url for the chatbot and that the protocol
        # will remain consistent.
        if self.chatbot:
            self._url = urlsplit(self.chatbot.url)  # type: ignore[assignment]
        else:
            self._url = url

        logger.debug(
            "%s.url() set url: %s",
            self.formatted_class_name,
            self._url.geturl() if self._url else "(Missing URL)",
        )

    # Deliberately a plain property rather than cached_property: the value
    # depends on per-request state (chatbot, request host), and the middleware
    # instance outlives a single request, so a cached value would go stale.
    @property
    def CORS_ALLOWED_ORIGINS(self) -> list[str] | tuple[str]:
        """
        Returns the list of allowed origins for the application.

        If the request is from a chatbot, the chatbot url is added to the list.
        If the host is an api.local.smarter.sh domain, allow localhost for development.

        :returns: A new list built from the configured origins plus any
            request-specific additions; the configured value is never mutated.
        """
        retval = list(conf.CORS_ALLOWED_ORIGINS)

        # Add chatbot url if present
        if self.chatbot is not None:
            url = self.url.geturl() if isinstance(self.url, SplitResult) else None
            if url is not None and url not in retval:
                retval.append(url)
                logger.info("%s.CORS_ALLOWED_ORIGINS() added origin: %s", self.formatted_class_name, url)

        # Allow localhost if host is api.local.smarter.sh (for dev pairing)
        request = self.request
        if request is not None:
            host = request.get_host()
            local_api_host = (
                f"{smarter_settings.api_subdomain}.{SmarterEnvironments.LOCAL}.{smarter_settings.root_domain}"
            )
            if host and local_api_host in host:
                localhost_origins = [
                    f"http://localhost:{SMARTER_LOCAL_PORT}",
                    f"http://127.0.0.1:{SMARTER_LOCAL_PORT}",
                ]
                for origin in localhost_origins:
                    if origin not in retval:
                        retval.append(origin)
                logger.info(
                    "%s.CORS_ALLOWED_ORIGINS() added localhost origins for dev: %s",
                    self.formatted_class_name,
                    localhost_origins,
                )

        return retval

    @cached_property
    def CORS_ALLOWED_ORIGIN_REGEXES(self) -> Sequence[str | Pattern[str]]:
        """Return the configured allowed-origin regexes, unchanged."""
        # TODO: ADD CHATBOT URL
        return conf.CORS_ALLOWED_ORIGIN_REGEXES

    # The three hooks below are intentionally NOT memoized with lru_cache.
    # Their result depends on self.chatbot (per-request state) and they assign
    # self.url as a side effect; a cache keyed only on (origin, url) would
    # replay one request's decision for another and skip the side effect.

    def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool:
        """
        Decide whether ``origin`` is allowed for the current request.

        Returns True whenever a chatbot is resolved for the current request;
        otherwise applies the null-origin, whitelist and regex checks.

        :param origin: The raw Origin value.
        :param url: The parsed form of the origin.
        :returns: True if the origin is allowed.
        """
        self.url = url
        if self.chatbot is not None:
            logger.debug("%s.origin_found_in_white_lists() returning True: %s", self.formatted_class_name, url)
            return True
        return (
            (origin == "null" and origin in self.CORS_ALLOWED_ORIGINS)
            or self._url_in_whitelist(url)
            or self.regex_domain_match(origin)
        )

    def regex_domain_match(self, origin: str) -> bool:
        """
        Return True if a chatbot is resolved for the current request or if
        ``origin`` matches any of the allowed-origin regexes.

        :param origin: The raw Origin value.
        """
        if self.chatbot is not None:
            logger.debug("%s.regex_domain_match() returning True: %s", self.formatted_class_name, self.url)
            return True
        return any(re.match(domain_pattern, origin) for domain_pattern in self.CORS_ALLOWED_ORIGIN_REGEXES)

    def _url_in_whitelist(self, url: SplitResult) -> bool:
        """
        Return True if a chatbot is resolved for the current request or if
        ``url`` matches the scheme and netloc of any allowed origin.

        :param url: The parsed form of the origin.
        """
        self.url = url
        if self.chatbot is not None:
            logger.debug("%s._url_in_whitelist() returning True: %s", self.formatted_class_name, url)
            return True
        origins = [urlsplit(o) for o in self.CORS_ALLOWED_ORIGINS]
        return any(origin.scheme == url.scheme and origin.netloc == url.netloc for origin in origins)