Source code for smarter.lib.unittest.base_classes

"""
Project level base classes for unit tests.
"""

import csv
import logging
import unittest
from typing import Union

import yaml
from django.core.cache import cache
from django.http import HttpRequest
from django.test import RequestFactory

from smarter.common.helpers.console_helpers import formatted_text, formatted_text_red
from smarter.common.mixins import SmarterHelperMixin
from smarter.common.utils import hash_factory, to_snake_case
from smarter.lib import json

logger = logging.getLogger(__name__)
HERE = __name__
logger_prefix = formatted_text(f"{HERE}.SmarterTestBase()")


[docs] class SmarterTestBase(unittest.TestCase, SmarterHelperMixin): """Base class for all unit tests.""" name: str smarter_test_base_logger_prefix = formatted_text(f"{HERE}.SmarterTestBase()") line_width = 150
[docs] @classmethod def setUpClass(cls) -> None: """Set up the test class.""" super().setUpClass() title = f" {logger_prefix}.setUpClass() " msg = "*" * ((cls.line_width - len(title)) // 2) + title + "*" * ((cls.line_width - len(title)) // 2) logger.debug(msg) cls.hash_suffix = SmarterTestBase.generate_hash_suffix() cls.name = str(to_snake_case("smarterTestBase_" + cls.hash_suffix)) cls.uid = SmarterTestBase.generate_uid() cache.clear() logger.debug( "%s.setUpClass() Setting up test class with hash suffix: %s", cls.smarter_test_base_logger_prefix, cls.hash_suffix, ) logger.debug( "%s.setUpClass() Setting up test class with name: %s", cls.smarter_test_base_logger_prefix, cls.name ) logger.debug("%s.setUpClass() Setting up test class with uid: %s", cls.smarter_test_base_logger_prefix, cls.uid) logger.debug( "%s.setUpClass() %s", cls.smarter_test_base_logger_prefix, formatted_text_red("Django cache has been cleared"), )
[docs] @classmethod def tearDownClass(cls) -> None: """Tear down the test class.""" super().tearDownClass()
[docs] def setUp(self) -> None: """SetUp the test case.""" super().setUp() title = f" {logger_prefix}.{self._testMethodName}() " msg = "-" * ((self.line_width - len(title)) // 2) + title + "-" * ((self.line_width - len(title)) // 2) logger.debug(msg)
[docs] def tearDown(self) -> None: """Tear down the test case.""" title = f" {logger_prefix}.tearDown() {self._testMethodName} " msg = "-" * ((self.line_width - len(title)) // 2) + title + "-" * ((self.line_width - len(title)) // 2) logger.debug(msg) super().tearDown()
[docs] @classmethod def generate_uid(cls) -> str: """Generate a unique identifier for the test.""" return hash_factory(length=64)
[docs] @classmethod def get_readonly_yaml_file(cls, file_path) -> dict: """Read a YAML file in read-only mode.""" with open(file_path, encoding="utf-8") as file: return yaml.safe_load(file)
[docs] @classmethod def get_readonly_csv_file(cls, file_path) -> Union[dict, list[dict]]: """Read a CSV file in read-only mode.""" with open(file_path, encoding="utf-8") as file: reader = csv.DictReader(file) return list(reader)
[docs] @classmethod def get_readonly_json_file(cls, file_path) -> Union[dict, list]: """Read a JSON file in read-only mode.""" with open(file_path, encoding="utf-8") as file: return json.load(file)
[docs] @staticmethod def generate_hash_suffix(length: int = 16) -> str: """Generate a unique hash suffix for test data.""" return hash_factory(length=length)
[docs] def create_generic_request(self, url="http://example.com") -> HttpRequest: """Create a generic HTTP request for testing purposes.""" factory = RequestFactory() json_data = { "session_key": "6f3bdd1981e0cac2de5fdc7afc2fb4e565826473a124153220e9f6bf49bca67b", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { "role": "assistant", "content": "Welcome to Smarter!. Following are some example prompts: blah blah blah", }, {"role": "smarter", "content": 'Tool call: function_calling_plugin_0002({"inquiry_type":"about"})'}, {"role": "user", "content": "Hello, World!"}, ], } json_data = json.dumps(json_data).encode("utf-8") headers = {} data = {} request: HttpRequest = factory.post(path=url, data=data, content_type="application/json", headers=headers) return request