"""OpenAI API request validators"""
from smarter.common.exceptions import SmarterValueError
from smarter.lib import json
from .const import OpenAIEndPoint, OpenAIMessageKeys, OpenAIObjectTypes
####################################################################################################
# Legacy openai validators
####################################################################################################
[docs]
def validate_temperature(temperature: any) -> None:
    """Ensure that temperature is a number between 0 and 1 (inclusive).

    The value is coerced with ``float()``, so anything ``float()`` accepts
    (ints, floats, numeric strings such as ``"0.5"``) is allowed.

    Raises:
        SmarterValueError: if the value cannot be converted to a float
            ("Temperature must be a float"), or if the converted value is
            NaN or lies outside the range 0..1
            ("temperature should be between 0 and 1").
    """
    # Only the conversion lives inside the try body, so the range error raised
    # below can never be mistaken for (and re-labelled as) a conversion failure.
    try:
        float_temperature = float(temperature)
    except (TypeError, ValueError, OverflowError) as exc:
        raise SmarterValueError("Temperature must be a float") from exc

    # Positive chained comparison: NaN compares False against everything, so
    # it fails this test and is rejected rather than slipping through.
    if not 0 <= float_temperature <= 1:
        raise SmarterValueError("temperature should be between 0 and 1")
[docs]
def validate_max_completion_tokens(max_completion_tokens: any) -> None:
    """Ensure that max_completion_tokens is an int between 1 and 2048 (inclusive).

    Raises:
        TypeError: if the value is not an int. Booleans are rejected as well.
        SmarterValueError: if the value is an int outside the range 1..2048.
    """
    # bool is a subclass of int, so without the explicit bool check ``True``
    # would be silently accepted as a token count of 1.
    if isinstance(max_completion_tokens, bool) or not isinstance(max_completion_tokens, int):
        raise TypeError("max_completion_tokens should be an int")
    if not 1 <= max_completion_tokens <= 2048:
        raise SmarterValueError("max_completion_tokens should be between 1 and 2048")
[docs]
def validate_endpoint(end_point: any) -> None:
    """Check that end_point is a string listed in OpenAIEndPoint.all_endpoints.

    Raises:
        TypeError: if end_point is not a string.
        SmarterValueError: if end_point is a string that is not listed.
    """
    if isinstance(end_point, str):
        known_endpoints = OpenAIEndPoint.all_endpoints
        if end_point in known_endpoints:
            return
        raise SmarterValueError(f"Invalid end_point {end_point}. Should be one of {known_endpoints}")
    raise TypeError(f"Invalid end_point '{end_point}'. end_point should be a string.")
[docs]
def validate_object_types(object_type: any) -> None:
    """Check that object_type is a string listed in OpenAIObjectTypes.all_object_types.

    Raises:
        TypeError: if object_type is not a string.
        SmarterValueError: if object_type is a string that is not listed.
    """
    is_text = isinstance(object_type, str)
    if not is_text:
        raise TypeError(f"Invalid object_type '{object_type}'. object_type should be a string.")

    recognised_types = OpenAIObjectTypes.all_object_types
    if object_type in recognised_types:
        return
    raise SmarterValueError(f"Invalid object_type {object_type}. Should be one of {recognised_types}")
[docs]
def validate_request_body(request_body) -> None:
    """Check that the request body is a dict.

    See openai.chat.completion.request.json

    Raises:
        TypeError: if request_body is anything other than a dict.
    """
    if isinstance(request_body, dict):
        return
    raise TypeError("request body should be a dict")
[docs]
def validate_messages(request_body):
    """Check the 'messages' entry of a request body.

    See openai.chat.completion.request.json

    The body must contain a 'messages' key holding a list. Every item in
    that list must be a dict with a 'role' found in
    OpenAIMessageKeys.all_roles and a 'content' key.

    Raises:
        SmarterValueError: on the first rule that is violated.
    """
    if "messages" not in request_body:
        raise SmarterValueError("dict key 'messages' was not found in request body object")

    entries = request_body["messages"]
    if not isinstance(entries, list):
        raise SmarterValueError("dict key 'messages' should be a JSON list")

    for entry in entries:
        if not isinstance(entry, dict):
            raise SmarterValueError(f"invalid object type {type(entry)} {entry} found in messages list {entries}")

        # Work out the first complaint (if any) for this entry, in the order:
        # missing role, unknown role, missing content.
        complaint = None
        if "role" not in entry:
            complaint = f"dict key 'role' not found in message {json.dumps(entry)}"
        elif entry["role"] not in OpenAIMessageKeys.all_roles:
            complaint = (
                f"invalid role {entry['role']} found in message {json.dumps(entry)}. "
                f"Should be one of {OpenAIMessageKeys.all_roles}"
            )
        elif "content" not in entry:
            complaint = f"dict key 'content' not found in message {json.dumps(entry)}"

        if complaint is not None:
            raise SmarterValueError(complaint)
[docs]
def validate_completion_request(request_body, version: str = "v1") -> None:
    """Check a chat completion request body.

    See openai.chat.completion.request.json

    Runs validate_request_body and validate_messages on the body. When
    version is "v1" it also requires the keys 'model', 'temperature' and
    'max_completion_tokens' to be present; only their presence is checked.

    Raises:
        SmarterValueError: if one of the required "v1" keys is missing.
    """
    validate_request_body(request_body=request_body)
    validate_messages(request_body=request_body)

    if version != "v1":
        return

    # Checked in this order so the first missing key is the one reported.
    for required_key in ("model", "temperature", "max_completion_tokens"):
        if required_key not in request_body:
            raise SmarterValueError(f"dict key '{required_key}' not found in request body object")
[docs]
def validate_embedding_request(request_body) -> None:
    """Check an embedding request body.

    See openai.embedding.request.json

    Runs validate_request_body on the body, then requires the key
    'input_text' to be present.

    Raises:
        SmarterValueError: if 'input_text' is missing from the body.
    """
    validate_request_body(request_body=request_body)

    if "input_text" in request_body:
        return
    raise SmarterValueError("dict key 'input_text' not found in request body object")