# pylint: disable=W0613,C0115
"""All models for the Provider app."""
import datetime
import logging
import os
import urllib.parse
from collections.abc import Sequence
from typing import Optional, TypedDict
import requests
from django.conf import settings
from django.db import models
from smarter.apps.account.models import (
Account,
MetaDataWithOwnershipModel,
Secret,
User,
UserProfile,
)
from smarter.apps.account.utils import (
get_cached_account_for_user,
get_cached_admin_user_for_account,
get_cached_smarter_admin_user_profile,
smarter_cached_objects,
)
from smarter.common.exceptions import (
SmarterBusinessRuleViolation,
SmarterConfigurationError,
SmarterValueError,
)
from smarter.common.helpers.logger_helpers import formatted_text
from smarter.common.utils import rfc1034_compliant_str
from smarter.lib.cache import cache_results
from smarter.lib.django import waffle
from smarter.lib.django.models import TimestampedModel
from smarter.lib.django.waffle import SmarterWaffleSwitches
from smarter.lib.logging import WaffleSwitchedLoggerWrapper
from .const import VERIFICATION_LEAD_TIME, VERIFICATION_LIFETIME
from .manifest.enum import ProviderModelEnum
from .signals import (
provider_activated,
provider_deactivated,
provider_deprecated,
provider_flagged,
provider_suspended,
provider_undeprecated,
provider_unflagged,
provider_unsuspended,
provider_verification_requested,
)
[docs]
def should_log(level):
"""Check if logging should be done based on the waffle switch."""
return waffle.switch_is_active(SmarterWaffleSwitches.PROVIDER_LOGGING) and waffle.switch_is_active(
SmarterWaffleSwitches.PLUGIN_LOGGING
)
base_logger = logging.getLogger(__name__)
logger = WaffleSwitchedLoggerWrapper(base_logger, should_log)
CACHE_TIMEOUT = int(60 / 2) # 30 seconds
[docs]
class ProviderModelTypedDict(TypedDict):
"""TypedDict for provider model information."""
api_key: str
provider_name: str
provider_id: int
base_url: str
model: str
max_completion_tokens: int
temperature: float
top_p: float
supports_streaming: bool
supports_tools: bool
supports_text_input: bool
supports_image_input: bool
supports_audio_input: bool
supports_embedding: bool
supports_fine_tuning: bool
supports_search: bool
supports_code_interpreter: bool
supports_image_generation: bool
supports_audio_generation: bool
supports_text_generation: bool
supports_translation: bool
supports_summarization: bool
[docs]
class ProviderStatus(models.TextChoices):
UNVERIFIED = "unverified", "Unverified"
VERIFYING = "verifying", "Verifying"
FAILED = "failed", "Verification Failed"
VERIFIED = "verified", "Verified"
SUSPENDED = "suspended", "Suspended"
DEPRECATED = "deprecated", "Deprecated"
[docs]
class ProviderVerificationTypes(models.TextChoices):
API_CONNECTIVITY = "api_connectivity", "Api Connectivity"
LOGO = "logo", "Logo"
CONTACT_EMAIL = "contact_email", "Contact Email"
SUPPORT_EMAIL = "support_email", "Support Email"
WEBSITE_URL = "website_url", "Website URL"
TOS_URL = "tos_url", "Terms of Service URL"
PRIVACY_POLICY_URL = "privacy_policy_url", "Privacy Policy URL"
TOS_ACCEPTANCE = "tos_acceptance", "Terms of Service Acceptance"
PRODUCTION_API_KEY = "production_api_key", "Production API Key"
[docs]
class ProviderModelVerificationTypes(models.TextChoices):
STREAMING = "streaming", "Streaming"
TOOLS = "tools", "Tools"
TEXT_INPUT = "text_input", "Text Input"
IMAGE_INPUT = "image_input", "Image Input"
AUDIO_INPUT = "audio_input", "Audio Input"
FINE_TUNING = "fine_tuning", "Fine Tuning"
SEARCH = "search", "Search"
CODE_INTERPRETER = "code_interpreter", "Code Interpreter"
TEXT_TO_IMAGE = "text_to_image", "Text to Image"
TEXT_TO_AUDIO = "text_to_audio", "Text to Audio"
TEXT_TO_TEXT = "text_to_text", "Text to Text"
TRANSLATION = "translation", "Translation"
SUMMARIZATION = "summarization", "Summarization"
[docs]
class Provider(MetaDataWithOwnershipModel):
"""Chat model."""
class Meta:
verbose_name = "Provider"
verbose_name_plural = "Providers"
status = models.CharField(
max_length=32,
choices=ProviderStatus.choices,
default=ProviderStatus.UNVERIFIED,
blank=False,
null=False,
)
# good things
is_active = models.BooleanField(default=False, blank=False, null=False)
is_verified = models.BooleanField(default=False, blank=False, null=False)
is_featured = models.BooleanField(default=False, blank=False, null=False)
# bad things
is_deprecated = models.BooleanField(default=False, blank=False, null=False)
is_flagged = models.BooleanField(default=False, blank=False, null=False)
is_suspended = models.BooleanField(default=False, blank=False, null=False)
# connectivity
base_url = models.URLField(max_length=255, blank=True, null=True, help_text="The base URL for the provider's API.")
api_key = models.ForeignKey(
Secret,
on_delete=models.SET_NULL,
blank=True,
null=True,
related_name="provider_api_key",
help_text="The API key for the provider.",
)
connectivity_test_path = models.CharField(
max_length=255,
default="",
blank=True,
null=True,
help_text="The URL to test connectivity with the provider's API.",
)
# Provider metadata
logo = models.ImageField(
upload_to="provider/provider_logos/",
blank=True,
null=True,
help_text="The logo of the provider.",
)
website_url = models.URLField(
max_length=255, blank=True, null=True, help_text="The website_url URL of the provider."
)
ownership_requested = models.EmailField(
max_length=255,
blank=True,
null=True,
help_text="The email address of an alternative contact who has requested to take ownership of the provider.",
)
contact_email = models.EmailField(
max_length=255, blank=True, null=True, help_text="The contact email of the provider."
)
contact_email_verified = models.DateTimeField(
blank=True,
null=True,
help_text="The date and time when the contact email was verified.",
)
support_email = models.EmailField(
max_length=255, blank=True, null=True, help_text="The support email of the provider."
)
support_email_verified = models.DateTimeField(
blank=True,
null=True,
help_text="The date and time when the support email was verified.",
)
docs_url = models.URLField(
max_length=255, blank=True, null=True, help_text="The documentation URL of the provider."
)
terms_of_service_url = models.URLField(
max_length=255, blank=True, null=True, help_text="The terms of service URL of the provider."
)
privacy_policy_url = models.URLField(
max_length=255, blank=True, null=True, help_text="The privacy policy URL of the provider."
)
tos_accepted_at = models.DateTimeField(
blank=True, null=True, help_text="The date and time when the terms of service were accepted."
)
tos_accepted_by = models.ForeignKey(
settings.AUTH_USER_MODEL,
on_delete=models.SET_NULL,
blank=True,
null=True,
related_name="tos_accepted_by",
help_text="The user who accepted the terms of service.",
)
@property
def is_official_provider(self) -> bool:
"""Check if the provider is an official provider."""
smarter_admin = get_cached_smarter_admin_user_profile()
return self.user_profile == smarter_admin.user
@property
def tos_accepted(self) -> bool:
"""Check if the terms of service have been accepted."""
return self.tos_accepted_at is not None and self.tos_accepted_by is not None
[docs]
def production_api_key(self, mask: bool = True) -> str:
"""Return the production API key for the provider."""
api_key_name = f"{self.name.upper()}_API_KEY"
api_key = os.environ.get(api_key_name)
if api_key is None:
raise SmarterConfigurationError(
f"Production API key for provider {self.name} was accessed but is not set in environment variables."
)
return api_key if not mask else "********"
@property
def authorization_header(self) -> dict:
"""Return the authorization header for the provider."""
if self.production_api_key(mask=False) is not None:
return {"Authorization": f"Bearer {self.production_api_key(mask=False)}"}
if self.api_key:
return {"Authorization": f"Bearer {self.api_key.get_secret()}"}
return {}
@property
def can_activate(self) -> bool:
"""Check if the provider can be activated."""
return (
self.status == ProviderStatus.VERIFIED
and not self.is_deprecated
and not self.is_suspended
and not self.is_flagged
and self.tos_accepted
and self.tos_accepted_at is not None
and self.tos_accepted_by is not None
)
@property
def rfc1034_compliant_name(self) -> Optional[str]:
"""
Returns a URL-friendly name for the chatbot.
This property returns an RFC 1034-compliant name for the chatbot, suitable for use in URLs and DNS labels.
**Example:**
.. code-block:: python
self.name = 'Example ChatBot 1'
self.rfc1034_compliant_name # 'example-chatbot-1'
:return: The RFC 1034-compliant name, or None if ``self.name`` is not set.
:rtype: Optional[str]
"""
if self.name:
return rfc1034_compliant_str(self.name)
return None
[docs]
def test_connectivity(self) -> bool:
"""
Test connectivity to the provider's API.
This method should be overridden by subclasses to implement specific connectivity tests.
"""
if not self.base_url:
raise SmarterValueError("base_url is not set for this provider.")
url = urllib.parse.urljoin(self.base_url, self.connectivity_test_path)
try:
if self.api_key is not None:
logger.info(
"%s verifying API connectivity and key for %s with URL: %s",
self.formatted_class_name,
self.name,
url,
)
response = requests.get(url, headers=self.authorization_header, timeout=10)
else:
logger.info(
"%s verifying API connectivity for %s with URL: %s", self.formatted_class_name, self.name, url
)
response = requests.get(url, timeout=10)
if response.status_code == 200:
return True
else:
logger.error(
"%s API URL and key verification for %s failed with status code: %s",
self.formatted_class_name,
self.name,
response.status_code,
)
return False
except requests.RequestException as exc:
logger.error(
"%s Got an unexpected error testing API URL and key verification for %s failed: %s",
self.formatted_class_name,
self.name,
exc,
)
return False
[docs]
def verify(self):
"""
Request a batch of acceptance tests.
Set the status but don't change the is_verified flag.
This is used to indicate that the provider is being verified but has not yet been activated.
"""
self.status = ProviderStatus.VERIFYING
self.save()
provider_verification_requested.send(
sender=self.__class__,
instance=self,
)
[docs]
def activate(self):
"""Activate the provider."""
if not self.can_activate:
if self.is_active:
self.deactivate()
if self.is_deprecated:
raise SmarterValueError("Provider is deprecated and cannot be activated.")
if self.is_suspended:
raise SmarterValueError("Provider is suspended and cannot be activated.")
if self.is_flagged:
raise SmarterValueError("Provider is flagged and cannot be activated.")
if not self.tos_accepted:
raise SmarterValueError("Terms of service must be accepted before activation.")
if self.tos_accepted_at is None:
raise SmarterValueError("Terms of service acceptance date must be set before activation.")
if self.tos_accepted_by is None:
raise SmarterValueError("Terms of service acceptance user must be set before activation.")
if self.status != ProviderStatus.VERIFIED:
raise SmarterValueError("Provider must be verified before activation.")
if not self.is_active:
self.is_active = True
self.save()
provider_activated.send(
sender=self.__class__,
instance=self,
)
[docs]
def deactivate(self):
"""Deactivate the provider."""
self.is_active = False
self.save()
provider_deactivated.send(
sender=self.__class__,
instance=self,
)
[docs]
def suspend(self):
"""Suspend the provider."""
self.status = ProviderStatus.SUSPENDED
self.is_suspended = True
self.save()
self.deactivate()
provider_suspended.send(
sender=self.__class__,
instance=self,
)
[docs]
def unsuspend(self):
"""Unsuspend the provider."""
self.reset()
provider_unsuspended.send(
sender=self.__class__,
instance=self,
)
[docs]
def deprecate(self):
"""Deprecate the provider."""
self.status = ProviderStatus.DEPRECATED
self.is_deprecated = True
self.save()
self.deactivate()
provider_deprecated.send(
sender=self.__class__,
instance=self,
)
[docs]
def undeprecate(self):
"""Undeprecate the provider."""
self.reset()
provider_undeprecated.send(
sender=self.__class__,
instance=self,
)
[docs]
def flag(self):
"""Flag the provider."""
self.is_flagged = True
self.save()
self.deactivate()
provider_flagged.send(
sender=self.__class__,
instance=self,
)
[docs]
def unflag(self):
"""Unflag the provider."""
self.is_flagged = False
self.save()
if self.can_activate:
self.activate()
else:
self.reset()
provider_unflagged.send(
sender=self.__class__,
instance=self,
)
[docs]
def reset(self):
"""Reset the provider to its initial state."""
self.status = ProviderStatus.UNVERIFIED
self.is_active = False
self.is_verified = False
self.is_deprecated = False
self.is_flagged = False
self.is_suspended = False
self.save()
[docs]
@classmethod
def get_cached_provider_by_account_id_and_name(
cls, invalidate: Optional[bool] = False, account_id: Optional[int] = None, name: Optional[str] = None
) -> Optional["Provider"]:
"""Get a cached provider by account ID and name."""
logger_prefix = formatted_text(
__name__ + "." + Provider.__name__ + ".get_cached_provider_by_account_id_and_name()"
)
@cache_results()
def cached_provider_by_account_id_and_name(account_id: int, name: str) -> Optional["Provider"]:
try:
logger.debug(
"%s.cached_provider_by_account_id_and_name() cache miss for account_id: %s, name: %s",
logger_prefix,
account_id,
name,
)
return cls.objects.get(user_profile__account__id=account_id, name=name)
except cls.DoesNotExist:
logger.debug(
"%s.cached_provider_by_account_id_and_name() no provider found for account_id: %s, name: %s",
logger_prefix,
account_id,
name,
)
return None
if invalidate:
cached_provider_by_account_id_and_name.invalidate(account_id, name)
provider = cached_provider_by_account_id_and_name(account_id, name)
return provider
[docs]
@classmethod
def get_cached_providers_for_user(
cls, invalidate: Optional[bool] = False, user: Optional[User] = None
) -> Sequence["Provider"]:
"""Get cached providers for a user."""
@cache_results()
def cached_providers_by_account_id(account_id: int) -> Sequence["Provider"]:
if not user_profile:
logger.debug(
"%s: No user profile found for user %s, returning empty list", cls.formatted_class_name, user
)
return []
admin_user = get_cached_admin_user_for_account(invalidate=invalidate, account=user_profile.cached_account) # type: ignore[arg-type]
admin_user_profile = UserProfile.get_cached_object(invalidate=invalidate, user=admin_user) # type: ignore[arg-type]
account_providers = (
Provider.objects.filter(user_profile=admin_user_profile)
.select_related(
"user_profile",
"user_profile__account",
"user_profile__user",
)
.order_by("name")
)
smarter_providers = (
Provider.objects.filter(user_profile=smarter_cached_objects.smarter_admin_user_profile)
.select_related(
"user_profile",
"user_profile__account",
"user_profile__user",
)
.order_by("name")
)
retval = list((account_providers | smarter_providers).distinct()) or []
logger.debug(
"%s.cached_providers_by_account_id() retrieved %s providers for account %s",
cls.formatted_class_name,
retval,
user_profile.account,
)
return retval
user_profile = UserProfile.get_cached_object(invalidate=invalidate, user=user)
if not user_profile:
logger.debug("%s: No user profile found for user %s, returning empty list", cls.formatted_class_name, user)
return []
if invalidate and user_profile and user_profile.account:
cached_providers_by_account_id.invalidate(user_profile.account.id)
if user_profile and user_profile.account:
providers = cached_providers_by_account_id(user_profile.account.id)
return list(providers) or []
return []
[docs]
@classmethod
def get_cached_provider_by_user_and_name(
cls, invalidate: Optional[bool] = False, user: Optional[User] = None, name: Optional[str] = ""
) -> Optional["Provider"]:
"""
Return a single instance of Provider by name for the given user.
This method caches the results to improve performance.
:param user: The user whose provider should be retrieved.
:type user: User
:param name: The name of the provider to retrieve.
:type name: str
:return: A Provider instance if found, otherwise None.
:rtype: Optional[Provider]
"""
account = get_cached_account_for_user(invalidate=invalidate, user=user)
if not account:
return None
return cls.get_cached_provider_by_account_id_and_name(invalidate=invalidate, account_id=account.id, name=name)
[docs]
def validate(self) -> None:
"""Validate the provider before saving."""
def __str__(self):
"""String representation of the provider."""
return f"{self.name} ({self.user_profile}) - {self.status}"
[docs]
class ProviderModel(TimestampedModel):
"""Provider completion models for a provider."""
class Meta:
verbose_name = "Provider Model"
verbose_name_plural = "Provider Models"
unique_together = (("provider", "name"),)
provider = models.ForeignKey(Provider, on_delete=models.CASCADE, blank=False, null=False)
name = models.CharField(max_length=255, blank=False, null=False)
description = models.TextField(blank=True, null=True)
# good things
is_default = models.BooleanField(default=False, blank=False, null=False)
is_active = models.BooleanField(default=False, blank=False, null=False)
# bad things
is_deprecated = models.BooleanField(default=False, blank=False, null=False)
is_flagged = models.BooleanField(default=False, blank=False, null=False)
is_suspended = models.BooleanField(default=False, blank=False, null=False)
# model configuration
max_completion_tokens = models.PositiveIntegerField(default=4096, blank=False, null=False)
temperature = models.FloatField(default=0.7, blank=False, null=False)
top_p = models.FloatField(default=1.0, blank=False, null=False)
# verifiable features - defaults to True
supports_text_input = models.BooleanField(default=True, blank=False, null=False)
supports_text_generation = models.BooleanField(default=True, blank=False, null=False)
supports_translation = models.BooleanField(default=True, blank=False, null=False)
supports_summarization = models.BooleanField(default=True, blank=False, null=False)
# verifiable features - defaults to False
supports_streaming = models.BooleanField(default=False, blank=False, null=False)
supports_tools = models.BooleanField(default=False, blank=False, null=False)
supports_image_input = models.BooleanField(default=False, blank=False, null=False)
supports_audio_input = models.BooleanField(default=False, blank=False, null=False)
supports_embedding = models.BooleanField(default=False, blank=False, null=False)
supports_fine_tuning = models.BooleanField(default=False, blank=False, null=False)
supports_search = models.BooleanField(default=False, blank=False, null=False)
supports_code_interpreter = models.BooleanField(default=False, blank=False, null=False)
supports_image_generation = models.BooleanField(default=False, blank=False, null=False)
supports_audio_generation = models.BooleanField(default=False, blank=False, null=False)
def __str__(self):
"""String representation of the model."""
return f"{self.provider.name} - {self.name}"
# ------------------------------------------------------------------------------
# Verification history for providers and provider models
# ------------------------------------------------------------------------------
[docs]
class ProviderVerification(TimestampedModel):
"""Provider completion model verifications for a provider."""
class Meta:
verbose_name = "Provider Verification"
verbose_name_plural = "Provider Verifications"
unique_together = (("provider", "verification_type"),)
provider = models.ForeignKey(Provider, on_delete=models.CASCADE, blank=False, null=False)
verification_type = models.CharField(
max_length=32,
choices=ProviderVerificationTypes.choices,
default=ProviderVerificationTypes.API_CONNECTIVITY,
blank=False,
null=False,
)
is_successful = models.BooleanField(default=False, blank=False, null=False)
error_message = models.TextField(blank=True, null=True)
@property
def is_valid(self) -> bool:
"""Check if the verification is valid."""
if not self.elapsed_updated:
return False
return self.is_successful and self.elapsed_updated < VERIFICATION_LIFETIME
@property
def next_verification(self) -> datetime.datetime:
"""Get the next verification time."""
return self.updated_at + VERIFICATION_LIFETIME - VERIFICATION_LEAD_TIME
def __str__(self):
"""String representation of the verification."""
return f"{self.provider.name} - {self.verification_type}: {'Success' if self.is_successful else 'Failed'}"
[docs]
class ProviderModelVerification(TimestampedModel):
"""Provider completion model verifications for a provider."""
class Meta:
verbose_name = "Provider Model Verification"
verbose_name_plural = "Provider Model Verifications"
unique_together = (("provider_model", "verification_type"),)
provider_model = models.ForeignKey(ProviderModel, on_delete=models.CASCADE, blank=False, null=False)
verification_type = models.CharField(
max_length=32,
choices=ProviderModelVerificationTypes.choices,
default=ProviderModelVerificationTypes.TEXT_INPUT,
blank=False,
null=False,
)
is_successful = models.BooleanField(default=False, blank=False, null=False)
error_message = models.TextField(blank=True, null=True)
@property
def is_valid(self) -> bool:
"""Check if the verification is valid."""
if not self.elapsed_updated:
return False
return self.is_successful and self.elapsed_updated < VERIFICATION_LIFETIME
@property
def next_verification(self) -> datetime.datetime:
"""Get the next verification time."""
return self.updated_at + VERIFICATION_LIFETIME - VERIFICATION_LEAD_TIME
def __str__(self):
"""String representation of the verification."""
return f"{self.provider_model.name} - {self.verification_type}: {'Success' if self.is_successful else 'Failed'}"
[docs]
@cache_results(timeout=CACHE_TIMEOUT)
def get_provider(provider_name: str) -> Provider:
"""
Get the provider by name and account number. This is the primary way to
retrieve a provider. Raises a Smarter error if anything goes wrong.
"""
try:
provider = Provider.objects.get(name=provider_name)
except Provider.DoesNotExist as e:
raise SmarterValueError(f"Provider {provider_name} does not exist.") from e
if not provider.account.is_active:
raise SmarterBusinessRuleViolation(f"Provider account {provider.account.account_number} is not active.")
# the Provider might be inactive for a variety of reasons: suspended, flagged, deprecated, or something else.
# We don't care why we just want to know if it is active or not.
if not provider.is_active:
raise SmarterBusinessRuleViolation(f"Provider {provider_name} is not active.")
return provider
[docs]
@cache_results(timeout=CACHE_TIMEOUT)
def get_providers() -> list[Provider]:
"""
Get all active providers. This is the primary way to retrieve all providers.
Raises a Smarter error if anything goes wrong.
"""
try:
providers = Provider.objects.filter(is_active=True).select_related(
"user_profile", "user_profile__account", "user_profile__user"
)
except Provider.DoesNotExist as e:
raise SmarterValueError("No active providers found.") from e
return list(providers)
[docs]
@cache_results(timeout=CACHE_TIMEOUT)
def get_model_for_provider(provider_name: str, model_name: Optional[str] = None) -> ProviderModelTypedDict:
"""
Get the model for a provider by name and account number. This is the
primary way to retrieve a model for a provider. Raises a Smarter error if
anything goes wrong.
"""
provider = get_provider(provider_name=provider_name)
# the Provider might be inactive for a variety of reasons: suspended, flagged, deprecated, or something else.
# We don't care why we just want to know if it is active or not.
if not provider.is_active:
raise SmarterBusinessRuleViolation(f"Provider {provider_name} is not active.")
# 3.) get the model for the provider
if model_name is not None:
try:
model = ProviderModel.objects.get(provider=provider, name=model_name)
except ProviderModel.DoesNotExist as e:
raise SmarterValueError(f"Model {model_name} for provider {provider_name} does not exist.") from e
else:
try:
model = ProviderModel.objects.get(provider=provider, is_default=True)
except ProviderModel.DoesNotExist as e:
raise SmarterValueError(f"No default model found for provider {provider_name}.") from e
# The model is periodically re-verified and is therefore subject to being inactived if any of
# it's verification tests fail.
# Again, we don't care why it is inactive, we just want to know if it is active or not.
if not model.is_active:
raise SmarterBusinessRuleViolation(f"Model {model_name} for provider {provider_name} is not active.")
return {
ProviderModelEnum.API_KEY.value: provider.production_api_key(mask=False),
ProviderModelEnum.PROVIDER_NAME.value: provider.name,
ProviderModelEnum.PROVIDER_ID.value: provider.id, # type: ignore[union-attr]
ProviderModelEnum.BASE_URL.value: provider.base_url,
ProviderModelEnum.MODEL.value: model.name,
ProviderModelEnum.MAX_TOKENS.value: model.max_completion_tokens,
ProviderModelEnum.TEMPERATURE.value: model.temperature,
ProviderModelEnum.TOP_P.value: model.top_p,
ProviderModelEnum.SUPPORTS_STREAMING.value: model.supports_streaming,
ProviderModelEnum.SUPPORTS_TOOLS.value: model.supports_tools,
ProviderModelEnum.SUPPORTS_TEXT_INPUT.value: model.supports_text_input,
ProviderModelEnum.SUPPORTS_IMAGE_INPUT.value: model.supports_image_input,
ProviderModelEnum.SUPPORTS_AUDIO_INPUT.value: model.supports_audio_input,
ProviderModelEnum.SUPPORTS_EMBEDDING.value: model.supports_embedding,
ProviderModelEnum.SUPPORTS_FINE_TUNING.value: model.supports_fine_tuning,
ProviderModelEnum.SUPPORTS_SEARCH.value: model.supports_search,
ProviderModelEnum.SUPPORTS_CODE_INTERPRETER.value: model.supports_code_interpreter,
ProviderModelEnum.SUPPORTS_IMAGE_GENERATION.value: model.supports_image_generation,
ProviderModelEnum.SUPPORTS_AUDIO_GENERATION.value: model.supports_audio_generation,
ProviderModelEnum.SUPPORTS_TEXT_GENERATION.value: model.supports_text_generation,
ProviderModelEnum.SUPPORTS_TRANSLATION.value: model.supports_translation,
ProviderModelEnum.SUPPORTS_SUMMARIZATION.value: model.supports_summarization,
}
[docs]
@cache_results(timeout=CACHE_TIMEOUT)
def get_models_for_provider(provider_name: str) -> list[ProviderModelTypedDict]:
"""
Get all models for a provider by name and account number. This is the
primary way to retrieve all models for a provider. Raises a Smarter error if
anything goes wrong.
"""
provider = get_provider(provider_name=provider_name)
provider_models = ProviderModel.objects.filter(provider=provider, is_active=True)
return [
get_model_for_provider(provider_name=provider_name, model_name=provider_model.name)
for provider_model in provider_models
]