"""
This module contains the middleware for handling CORS headers for the application.
It adds chatbot urls to the CORS_ALLOWED_ORIGINS list at run-time.
"""
import logging
import re
from collections.abc import Awaitable
from functools import cached_property, lru_cache
from typing import Optional, Pattern, Sequence, Union
from urllib.parse import SplitResult, urlsplit
from corsheaders.conf import conf
from corsheaders.middleware import CorsMiddleware
from django.http import HttpRequest
from django.http.response import HttpResponseBase
from smarter.apps.chatbot.models import ChatBot, get_cached_chatbot_by_request
from smarter.common.conf import smarter_settings
from smarter.common.const import SMARTER_LOCAL_PORT, SmarterEnvironments
from smarter.common.helpers.console_helpers import formatted_text
from smarter.common.mixins import SmarterHelperMixin
from smarter.lib.django import waffle
from smarter.lib.django.http.shortcuts import SmarterHttpResponseServerError
from smarter.lib.django.waffle import SmarterWaffleSwitches
from smarter.lib.logging import WaffleSwitchedLoggerWrapper
# pylint: disable=W0613
def should_log(level):
    """
    Decide whether a middleware log record should be emitted.

    :param level: Log level of the pending record. It is not consulted here;
        it is presumably present to satisfy the callback signature expected by
        ``WaffleSwitchedLoggerWrapper`` (to which this function is passed below).
    :returns: True when the ``MIDDLEWARE_LOGGING`` waffle switch is active.
    """
    switch_name = SmarterWaffleSwitches.MIDDLEWARE_LOGGING
    is_active = waffle.switch_is_active(switch_name)
    return is_active
# Module logger. Records go through WaffleSwitchedLoggerWrapper with
# ``should_log`` as its gate, so output presumably appears only while the
# MIDDLEWARE_LOGGING waffle switch is active (wrapper internals not shown here).
base_logger = logging.getLogger(__name__)
logger = WaffleSwitchedLoggerWrapper(base_logger, should_log)
class SmarterCorsMiddleware(CorsMiddleware, SmarterHelperMixin):
    """
    Middleware for handling Cross-Origin Resource Sharing (CORS) headers in the application.

    This middleware extends the default CORS handling to dynamically add chatbot URLs to the
    allowed origins at runtime. It ensures that requests from valid chatbot origins are permitted
    by updating the CORS allowed origins list based on the current request context.

    The middleware also provides additional logic to handle internal IP addresses, health check
    endpoints, and logging for debugging and auditing purposes.

    :cvar _url: The parsed URL (as a :class:`urllib.parse.SplitResult`) for the current request, or None.
    :vartype _url: Optional[SplitResult]
    :cvar _chatbot: The chatbot instance associated with the current request, or None.
    :vartype _chatbot: Optional[ChatBot]
    :cvar request: The current Django HTTP request object, or None.
    :vartype request: Optional[HttpRequest]

    **Key Features**

    - Dynamically adds chatbot URLs to the CORS allowed origins list.
    - Handles requests from internal IP addresses and health check endpoints.
    - Provides detailed logging for CORS-related events and decisions.
    - Integrates with Django and the `django-cors-headers` package.

    .. note::

        - The chatbot URL is only added to the allowed origins if a chatbot is associated
          with the request.
        - Requests that are short-circuited (middleware switch disabled, health checks,
          internal hosts) get no chatbot lookup; only the statically configured origins apply.
        - Logging is controlled via a waffle switch and the application's log level.
        - Origin lists are recomputed per request rather than cached, because a single
          middleware instance serves many requests.

    .. note::

        NOTE(review): per-request state (``request``, ``_url``, ``_chatbot``) is stored on the
        middleware instance. Django creates middleware once per process, so under a
        multi-threaded or async server concurrent requests may interleave. Consider moving
        this state to request-scoped storage.

    **Example**

    To enable this middleware, add it to your Django project's middleware settings::

        MIDDLEWARE = [
            ...
            'smarter.lib.django.middleware.cors.SmarterCorsMiddleware',
            ...
        ]

    :param request: The incoming HTTP request object.
    :type request: django.http.HttpRequest
    :returns: The HTTP response object, potentially with CORS headers added.
    :rtype: django.http.response.HttpResponseBase or Awaitable[HttpResponseBase]
    """

    _url: Optional[SplitResult] = None
    _chatbot: Optional[ChatBot] = None
    request: Optional[HttpRequest] = None

    @property
    def formatted_class_name(self) -> str:
        """Return the formatted class name for logging purposes."""
        return formatted_text(f"{__name__}.{SmarterCorsMiddleware.__name__}")

    def __call__(self, request: HttpRequest) -> Union[HttpResponseBase, Awaitable[HttpResponseBase]]:
        """
        Handle ``request`` and return the response produced by ``CorsMiddleware``.

        :param request: The incoming HTTP request.
        :returns: The response, potentially with CORS headers added, or a 500
            response when the request host cannot be determined.
        """
        # Reset per-request state before any branching. The whitelist hooks on this
        # class read it (``self.request`` drives the chatbot lookup; ``self.chatbot``
        # short-circuits to "allowed"). They are invoked by django-cors-headers while
        # it builds the response. Resetting here keeps short-circuited requests from
        # being judged against the previous request's chatbot. With ``request`` left
        # as None, no chatbot lookup happens and only configured origins apply.
        self._url = None
        self._chatbot = None
        self.request = None

        if not waffle.switch_is_active(SmarterWaffleSwitches.ENABLE_MIDDLEWARE_CORS):
            return super().__call__(request)

        host = request.get_host()
        if not host:
            return SmarterHttpResponseServerError(
                request=request,
                error_message="Internal error (500) - could not parse request.",
            )

        # Short-circuit for health checks
        if request.path.replace("/", "") in self.amnesty_urls:
            return super().__call__(request)

        # Short-circuit for any requests born from internal IP address hosts.
        # This is unlikely, but not impossible.
        if any(host.startswith(prefix) for prefix in smarter_settings.internal_ip_prefixes):
            logger.debug(
                "%s %s identified as an internal IP address, exiting.",
                self.formatted_class_name,
                self.smarter_build_absolute_uri(request),
            )
            return super().__call__(request)

        url = self.smarter_build_absolute_uri(request)
        logger.debug("%s.__call__() - url=%s", self.formatted_class_name, url)
        # Only fully processed requests enable the dynamic chatbot lookup.
        self.request = request
        return super().__call__(request)  # Ensure the response is returned

    @property
    def chatbot(self) -> Optional[ChatBot]:
        """The chatbot resolved for the current request, or None."""
        return self._chatbot

    @property
    def url(self) -> Optional[SplitResult]:
        """The URL recorded for the current request, or None if none has been set."""
        return self._url if isinstance(self._url, SplitResult) else None

    @url.setter
    def url(self, url: Optional[SplitResult] = None):
        """
        Record ``url`` for the current request and try to resolve its chatbot.

        If ``url`` is already a configured allowed origin, nothing is recorded and
        no lookup is done. Otherwise the chatbot is looked up from ``self.request``
        (when set). If a chatbot is found, the chatbot's own URL is stored so that
        only its base URL and protocol are used. If not, ``url`` is stored as given.
        """
        url_string = url.geturl() if isinstance(url, SplitResult) else None
        display_url = url_string if url_string is not None else "(Missing URL)"

        if url_string in conf.CORS_ALLOWED_ORIGINS:
            logger.debug("%s url: %s is an allowed origin", self.formatted_class_name, display_url)
            return

        logger.debug(
            "%s instantiating ChatBotHelper() for url: %s",
            self.formatted_class_name,
            display_url,
        )
        if self.request is not None:
            self._chatbot = get_cached_chatbot_by_request(request=self.request)

        # If the chatbot is found, use the chatbot's url. This ensures we only work
        # with the chatbot's base url and that the protocol remains consistent.
        if self.chatbot:
            self._url = urlsplit(self.chatbot.url)  # type: ignore[assignment]
        else:
            self._url = url
        logger.debug(
            "%s.url() set url: %s", self.formatted_class_name, self._url.geturl() if self._url else "(Missing URL)"
        )

    @property
    def CORS_ALLOWED_ORIGINS(self) -> list[str]:
        """
        Return the allowed origins for the current request.

        Starts from a copy of ``conf.CORS_ALLOWED_ORIGINS`` (the setting itself is
        never mutated), then:

        - appends the chatbot URL when a chatbot was resolved for this request, and
        - appends ``localhost`` / ``127.0.0.1`` origins on ``SMARTER_LOCAL_PORT``
          when the request host contains the local API domain
          (``<api_subdomain>.<LOCAL>.<root_domain>``, e.g. api.local.smarter.sh)
          for development pairing.

        This is a plain ``property``, not ``cached_property``. The middleware
        instance serves many requests, so a cached value would freeze whatever the
        first request computed.
        """
        retval = list(conf.CORS_ALLOWED_ORIGINS)

        # Add chatbot url if present
        if self.chatbot is not None:
            url = self.url.geturl() if isinstance(self.url, SplitResult) else None
            if url is not None and url not in retval:
                retval.append(url)
                logger.info("%s.CORS_ALLOWED_ORIGINS() added origin: %s", self.formatted_class_name, url)

        # Allow localhost if host is the local api domain (for dev pairing)
        request = getattr(self, "request", None)
        if request is not None:
            host = request.get_host()
            local_api_domain = (
                f"{smarter_settings.api_subdomain}.{SmarterEnvironments.LOCAL}.{smarter_settings.root_domain}"
            )
            if host and local_api_domain in host:
                localhost_origins = [
                    f"http://localhost:{SMARTER_LOCAL_PORT}",
                    f"http://127.0.0.1:{SMARTER_LOCAL_PORT}",
                ]
                for origin in localhost_origins:
                    if origin not in retval:
                        retval.append(origin)
                logger.info(
                    "%s.CORS_ALLOWED_ORIGINS() added localhost origins for dev: %s",
                    self.formatted_class_name,
                    localhost_origins,
                )
        return retval

    @property
    def CORS_ALLOWED_ORIGIN_REGEXES(self) -> Sequence[str | Pattern[str]]:
        """Return the configured origin regexes, read on every access (not cached)."""
        # TODO: ADD CHATBOT URL
        return conf.CORS_ALLOWED_ORIGIN_REGEXES

    def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool:
        """
        Return True if ``origin`` may receive CORS headers for the current request.

        Assigning ``self.url`` resolves the chatbot for the current request. Any
        origin is accepted when one is found. Otherwise the stock checks apply:
        the literal ``"null"`` origin, the scheme/netloc whitelist, and the regexes.

        Not memoized: the answer depends on the current request (via the chatbot
        lookup), not only on ``origin``/``url``. A cache keyed on those arguments
        would replay one request's decision for another.
        """
        self.url = url
        if self.chatbot is not None:
            logger.debug("%s.origin_found_in_white_lists() returning True: %s", self.formatted_class_name, url)
            return True
        return (
            (origin == "null" and origin in self.CORS_ALLOWED_ORIGINS)
            or self._url_in_whitelist(url)
            or self.regex_domain_match(origin)
        )

    def regex_domain_match(self, origin: str) -> bool:
        """
        Return True if a chatbot is resolved for this request, or if ``origin``
        matches any configured origin regex (``re.match``, anchored at the start).

        Not memoized, because the result depends on per-request chatbot state.
        """
        if self.chatbot is not None:
            logger.debug("%s.regex_domain_match() returning True: %s", self.formatted_class_name, self.url)
            return True
        return any(re.match(domain_pattern, origin) for domain_pattern in self.CORS_ALLOWED_ORIGIN_REGEXES)

    def _url_in_whitelist(self, url: SplitResult) -> bool:
        """
        Return True if a chatbot is resolved for this request, or if ``url`` has
        the same scheme and netloc as one of the request-aware allowed origins.

        Not memoized, because the result depends on per-request state.
        """
        self.url = url
        if self.chatbot is not None:
            logger.debug("%s._url_in_whitelist() returning True: %s", self.formatted_class_name, url)
            return True
        origins = [urlsplit(o) for o in self.CORS_ALLOWED_ORIGINS]
        return any(origin.scheme == url.scheme and origin.netloc == url.netloc for origin in origins)