Source code for smarter.apps.provider.services.text_completion.providers

"""
Smarter.apps.provider.services.text_completion.providers
==================================================================

Service-level entry point for text completions supporting multiple LLM provider companies.
This module provides a unified interface for accessing and managing prompt completion providers,
enabling seamless integration with a variety of large language model (LLM) backends.

**Protocols Supported:**

1. **Smarter Prompt Protocol**
    - Implements: SmarterChatHandlerProtocol
    - Indirect service layer for /api/v1/prompts/smarter/<str:provider>/
    - Returns: SmarterChatCompletionResponseType
    - Used for native Smarter prompt API requests, supporting Smarter's extensibility model.

2. **OpenAI-Compatible Passthrough Protocol**
    - Implements: OpenAICompatiblePassthroughProtocol
    - Indirect service layer for /api/v1/prompts/passthrough/<str:provider>/
    - Returns: OpenAICompatibleChatCompletionResponseType
    - Used for OpenAI-compatible API passthrough, enabling direct proxying to third-party LLM providers.

**Key Features:**

- Centralized access to all configured prompt providers and their handlers.
- Supports both Smarter-native and OpenAI-compatible request/response formats.
- Provides default provider selection and handler resolution.
- Abstracts provider-specific complexities, including authentication and model selection.
- Enables dynamic handler retrieval for both protocols, facilitating flexible integration patterns.

**Common Features:**

- Both protocols support dynamic provider selection based on the incoming request and user context.
- Handlers for both protocols are designed to abstract away provider-specific details, such as authentication and model selection, allowing for flexible integration patterns.
- The factory class provides caching for provider ORM retrieval and client instantiation to optimize performance and reduce redundant database queries.
- Internal billing records are generated in a consistent manner regardless of the protocol used, ensuring accurate usage tracking and billing across all providers.
- Application-level logging is fully integrated into both protocols, with support for logging based on waffle switches to facilitate debugging and monitoring in production environments.

**Singletons:**

.. py:data:: smarter_compatible_client
   :type: OpenAICompatibleClientFactory

   Singleton instance of :class:`OpenAICompatibleClientFactory` configured for the Smarter-native protocol.
   This is the main entry point for consumers needing Smarter-native prompt completion handling.

.. py:data:: openai_compatible_client
   :type: OpenAICompatibleClientFactory

   Singleton instance of :class:`OpenAICompatibleClientFactory` configured for the OpenAI-compatible passthrough protocol.
   This is the main entry point for consumers needing OpenAI-compatible prompt completion handling and passthrough.
"""

import logging
from functools import cached_property
from typing import Any, List, Optional, Union

from django.core.handlers.asgi import ASGIRequest
from django.http import HttpRequest
from pydantic import SecretStr
from rest_framework.request import Request

from smarter.apps.account.models import User, UserProfile
from smarter.apps.plugin.plugin.base import PluginBase
from smarter.apps.prompt.models import Prompt
from smarter.apps.provider.clients import OpenAIPassthroughClient
from smarter.apps.provider.models import Provider
from smarter.apps.provider.services.text_completion.lib.openai_compatible_chat_provider import (
    OpenAISmarterClient,
)
from smarter.common.enum import SmarterEnumAbstract
from smarter.common.exceptions import SmarterValueError
from smarter.common.mixins import SmarterHelperMixin
from smarter.lib.cache import cache_results
from smarter.lib.django import waffle
from smarter.lib.django.waffle import SmarterWaffleSwitches
from smarter.lib.logging import WaffleSwitchedLoggerWrapper

from .lib.protocols import (
    OpenAICompatibleChatCompletionResponseType,
    OpenAICompatiblePassthroughProtocol,
    SmarterChatCompletionResponseType,
    SmarterChatHandlerProtocol,
)

ProviderRequestType = Union[ASGIRequest, Request, HttpRequest]


# pylint: disable=W0613
[docs] def should_log(level): """Check if logging should be done based on the waffle switch.""" return waffle.switch_is_active(SmarterWaffleSwitches.PROMPT_LOGGING)
base_logger = logging.getLogger(__name__) logger = WaffleSwitchedLoggerWrapper(base_logger, should_log)
[docs] class ClientTypeEnum(SmarterEnumAbstract): """ Client type distinguishes between the kind of handler we want. from the provider. """ SMARTER = OpenAISmarterClient.__name__ PASSTHROUGH = OpenAIPassthroughClient.__name__
[docs] class OpenAICompatibleClientFactory(SmarterHelperMixin): """ Service-level factory for OpenAI-compatible prompt completion clients. This class provides a unified interface for instantiating and managing prompt completion clients that support both Smarter-native and OpenAI-compatible passthrough protocols. It enables seamless integration with multiple LLM provider backends, abstracting provider-specific complexities such as authentication, model selection, and handler resolution. **Key Features:** - Centralized access to all configured prompt providers and their handlers. - Supports both Smarter-native and OpenAI-compatible request/response formats. - Provides default provider selection and handler resolution. - Abstracts provider-specific details, including authentication and model selection. - Enables dynamic handler retrieval for both protocols, facilitating flexible integration patterns. :param client_type: The type of client to instantiate (``SMARTER`` or ``PASSTHROUGH``). :type client_type: ClientTypeEnum, optional """ _client_type: ClientTypeEnum
[docs] def __init__(self, client_type: Optional[ClientTypeEnum] = ClientTypeEnum.SMARTER): super().__init__() if client_type is not None and client_type not in list(ClientTypeEnum): raise ValueError(f"Invalid client type: {client_type}. Must be one of {list(ClientTypeEnum)}") self._client_type = client_type or ClientTypeEnum.SMARTER
@property def client_type(self) -> ClientTypeEnum: """ Returns the client type of this factory instance. :return: The client type (``SMARTER`` or ``PASSTHROUGH``). :rtype: ClientTypeEnum """ return self._client_type
[docs] @cached_property def default_handler_name(self) -> str: """ Returns the name of the platform-wide default provider. If no default provider is found, it raises a SmarterValueError. :return: The name of the default provider. :rtype: str :raises SmarterValueError: If no default provider is found. """ provider = Provider.objects.filter(is_default=True, is_active=True).first() # type: ignore if not provider: raise SmarterValueError("Default provider not found") return provider.name
[docs] def get_client_orm_by_provider_name_and_user(self, provider_name: str, user: User) -> Provider: """ Retrieves the Provider ORM instance for the given provider name and user. :param provider_name: The name of the provider to retrieve. :param user: The user for whom to retrieve the provider. :return: The Provider ORM instance. :rtype: Provider :raises SmarterValueError: If the provider is not found for the user. """ @cache_results() def get_cached_provider_orm_by_name_and_username(provider_name: str, username: str) -> Provider: try: provider_orm = ( Provider.objects.filter(name=provider_name, is_active=True) .with_read_permission_for(user) # type: ignore .only("name", "base_url", "api_key") .first() ) # type: ignore if not provider_orm: raise Provider.DoesNotExist except Provider.DoesNotExist as e: raise SmarterValueError(f"Provider {provider_name} not found for user {user}.") from e except Provider.MultipleObjectsReturned as e: provider_orm = Provider.objects.filter(is_default=True, is_active=True).first() # type: ignore if not provider_orm: raise SmarterValueError(f"Provider {provider_name} not found for user {user}.") from e logger.warning( f"Multiple default providers found for user {username}. Choosing the first one: {provider_orm}." ) logger.debug( "%s.get_client_orm_by_provider_name_and_user() fetched and cached provider ORM for provider_name: %s, username: %s: %s", self.formatted_class_name + ".get_client_orm_by_provider_name_and_user()", provider_name, username, provider_orm, ) return provider_orm return get_cached_provider_orm_by_name_and_username(provider_name, user.username) # type: ignore
[docs] def get_openai_client_for_provider(self, provider_name: str, user: User) -> OpenAIPassthroughClient: """ Instantiates an OpenAIPassthroughClient for the given provider name and user. :param provider_name: The name of the provider for which to instantiate the client. :param user: The user for whom to instantiate the client. :return: An instance of OpenAIPassthroughClient configured for the specified provider and user. :rtype: OpenAIPassthroughClient """ @cache_results() def get_cached_openai_client_for_provider(provider_name: str, username: str) -> OpenAIPassthroughClient: provider_orm = self.get_client_orm_by_provider_name_and_user(provider_name, user) api_key = SecretStr(provider_orm.api_key.get_secret()) if provider_orm.api_key else None retval = OpenAIPassthroughClient( provider=provider_orm.name, base_url=provider_orm.base_url, api_key=api_key.get_secret_value() if api_key else "", ) logger.debug( "%s.get_openai_client_for_provider() instantiated and cached OpenAIPassthroughClient for provider_name: %s, username: %s: %s", self.formatted_class_name + ".get_openai_client_for_provider()", provider_name, username, retval, ) return retval return get_cached_openai_client_for_provider(provider_name, user.username) # type: ignore
[docs] def get_passthrough_handler( self, request: ProviderRequestType, provider_name: Optional[str] = None, **kwargs ) -> OpenAICompatiblePassthroughProtocol: """ A convenience method to get an OpenAI-compatible passthrough handler by provider name. :param request: The incoming HTTP request object. :param provider_name: The name of the provider for which to retrieve the handler. If not provided, the default provider will be used. :return: An OpenAI-compatible passthrough handler function that can be used to process prompt completion requests. :rtype: OpenAICompatiblePassthroughProtocol """ def get_handler( request: ProviderRequestType, user_profile: UserProfile, data: dict[str, Any], *args, **kwargs, ) -> OpenAICompatibleChatCompletionResponseType: """Expose the handler method of the default provider.""" client = self.get_openai_client_for_provider(provider_name=provider_name or self.default_handler_name, user=request.user) # type: ignore handler = client.handler(request, user_profile, data, *args, **kwargs) # type: ignore return handler provider_name = provider_name or self.default_handler_name return get_handler
[docs] def get_smarter_handler( self, request: ProviderRequestType, provider_name: Optional[str] = None, **kwargs ) -> SmarterChatHandlerProtocol: """ A convenience method to get a handler by provider name. :param request: The incoming HTTP request object. :param provider_name: The name of the provider for which to retrieve the handler. If not provided, the default provider will be used. :return: A handler function that can be used to process prompt completion requests according to the Smarter prompt protocol. :rtype: SmarterChatHandlerProtocol """ def get_handler( user_profile: UserProfile, prompt: Prompt, data: Union[dict[str, Any], list], plugins: Optional[List[PluginBase]] = None, functions: Optional[list[str]] = None, ) -> SmarterChatCompletionResponseType: """Expose the handler method of the default provider.""" client_orm = self.get_client_orm_by_provider_name_and_user( provider_name=provider_name or self.default_handler_name, user=request.user # type: ignore ) api_key = SecretStr(client_orm.api_key.get_secret()) if client_orm.api_key else None smarter_openai_compatible_provider = OpenAISmarterClient( provider=client_orm, provider_name=client_orm.name, base_url=client_orm.base_url, api_key=api_key, default_model=client_orm.default_model, ) handler = smarter_openai_compatible_provider.handler( user_profile, prompt, data, plugins=plugins, functions=functions ) return handler return get_handler
[docs] @cached_property def all(self) -> List[str]: """ Returns a list of all provider names. :return: A list of all active provider names. :rtype: List[str] """ return list(Provider.objects.filter(is_active=True).values_list("name", flat=True)) # type: ignore
[docs] def handler( self, request: ProviderRequestType, provider_name: Optional[str] = None, **kwargs ) -> Union[SmarterChatHandlerProtocol, OpenAICompatiblePassthroughProtocol]: """ A convenience method to get a handler by provider name. :param request: The incoming HTTP request object. :param provider_name: The name of the provider for which to retrieve the handler. If not provided, the default provider will be used. :return: A handler function that can be used to process prompt completion requests according to the specified protocol. :rtype: Union[SmarterChatHandlerProtocol, OpenAICompatiblePassthroughProtocol] """ if self.client_type == ClientTypeEnum.PASSTHROUGH: return self.get_passthrough_handler(request=request, provider_name=provider_name, **kwargs) return self.get_smarter_handler(request=request, provider_name=provider_name, **kwargs)
[docs] def default_handler( self, request: ProviderRequestType, **kwargs ) -> Union[SmarterChatHandlerProtocol, OpenAICompatiblePassthroughProtocol]: """ A convenience method to get the default handler. :param request: The incoming HTTP request object. :return: A handler function that can be used to process prompt completion requests according to the specified protocol. :rtype: Union[SmarterChatHandlerProtocol, OpenAICompatiblePassthroughProtocol] """ return self.handler(request=request, provider_name=self.default_handler_name, **kwargs)
smarter_compatible_client = OpenAICompatibleClientFactory(ClientTypeEnum.SMARTER) openai_compatible_client = OpenAICompatibleClientFactory(ClientTypeEnum.PASSTHROUGH)