Source code for gptcache.core

import atexit
import os
from typing import Optional, List, Any

from gptcache.config import Config
from gptcache.embedding.string import to_embeddings as string_embedding
from gptcache.manager import get_data_manager
from gptcache.manager.data_manager import DataManager
from gptcache.processor.post import temperature_softmax
from gptcache.processor.pre import last_content
from gptcache.report import Report
from gptcache.similarity_evaluation import ExactMatchEvaluation
from gptcache.similarity_evaluation import SimilarityEvaluation
from gptcache.utils import import_openai
from gptcache.utils.cache_func import cache_all
from gptcache.utils.log import gptcache_log


class Cache:
    """GPTCache core object.

    Example:
        .. code-block:: python

            from gptcache import cache
            from gptcache.adapter import openai

            cache.init()
            cache.set_openai_key()
    """

    # It should be called when the cache system starts.
    def __init__(self):
        self.has_init = False
        self.cache_enable_func = None
        self.pre_embedding_func = None
        self.embedding_func = None
        self.data_manager: Optional[DataManager] = None
        self.similarity_evaluation: Optional[SimilarityEvaluation] = None
        self.post_process_messages_func = None
        self.config = Config()
        self.report = Report()
        self.next_cache = None
    def init(
        self,
        cache_enable_func=cache_all,
        pre_embedding_func=last_content,
        pre_func=None,
        embedding_func=string_embedding,
        data_manager: DataManager = get_data_manager(),
        similarity_evaluation=ExactMatchEvaluation(),
        post_process_messages_func=temperature_softmax,
        post_func=None,
        config=Config(),
        next_cache=None,
    ):
        """Pass parameters to initialize GPTCache.

        Usage sketches follow the module-level ``cache`` instance at the end of this listing.

        :param cache_enable_func: a function that decides whether the cache is used for a request, defaults to ``cache_all``
        :param pre_embedding_func: a function that preprocesses the request to extract the text to embed, defaults to ``last_content``
        :param pre_func: same as ``pre_embedding_func``; takes precedence when both are given
        :param embedding_func: a function to extract embeddings from requests for similarity search, defaults to ``string_embedding``
        :param data_manager: a ``DataManager`` module, defaults to ``get_data_manager()``
        :param similarity_evaluation: a module to calculate embedding similarity, defaults to ``ExactMatchEvaluation()``
        :param post_process_messages_func: a function to post-process the cached answers before they are returned, defaults to ``temperature_softmax`` with a default temperature of 0.0
        :param post_func: same as ``post_process_messages_func``; takes precedence when both are given
        :param config: a module to pass configurations, defaults to ``Config()``
        :param next_cache: another ``Cache`` object to query on a cache miss, enabling multi-level caching
        """
        self.has_init = True
        self.cache_enable_func = cache_enable_func
        self.pre_embedding_func = pre_func if pre_func else pre_embedding_func
        self.embedding_func = embedding_func
        self.data_manager: DataManager = data_manager
        self.similarity_evaluation = similarity_evaluation
        self.post_process_messages_func = post_func if post_func else post_process_messages_func
        self.config = config
        self.next_cache = next_cache

        @atexit.register
        def close():
            try:
                self.data_manager.close()
            except Exception as e:  # pylint: disable=W0703
                if not os.getenv("IS_CI"):
                    gptcache_log.error(e)
    def import_data(self, questions: List[Any], answers: List[Any], session_ids: Optional[List[Optional[str]]] = None) -> None:
        """Import data into GPTCache.

        :param questions: list of preprocessed question data
        :param answers: list of answers to the questions
        :param session_ids: optional list of session ids, one per question
        :return: None
        """
        self.data_manager.import_data(
            questions=questions,
            answers=answers,
            embedding_datas=[self.embedding_func(question) for question in questions],
            session_ids=session_ids if session_ids else [None for _ in range(len(questions))],
        )
    def flush(self):
        """Flush the data managers to prevent accidental loss of in-memory data;
        useful when using the map cache manager or faiss/hnswlib vector storage.
        """
        self.data_manager.flush()
        if self.next_cache:
            self.next_cache.data_manager.flush()
    @staticmethod
    def set_openai_key():
        import_openai()
        import openai  # pylint: disable=C0415

        openai.api_key = os.getenv("OPENAI_API_KEY")
    @staticmethod
    def set_azure_openai_key():
        import_openai()
        import openai  # pylint: disable=C0415

        openai.api_type = "azure"
        openai.api_key = os.getenv("OPENAI_API_KEY")
        openai.api_base = os.getenv("OPENAI_API_BASE")
        openai.api_version = os.getenv("OPENAI_API_VERSION")
cache = Cache()
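

# Usage sketch (application code, not part of this module): initialize the global
# ``cache`` for semantic caching via ``init``. This is a minimal sketch assuming the
# optional ONNX embedding model (``gptcache.embedding.Onnx``), the sqlite/faiss data
# manager backends, and ``SearchDistanceEvaluation`` are available in your install;
# any other embedding function, data manager, or similarity evaluation can be swapped in.
from gptcache import cache
from gptcache.config import Config
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

onnx = Onnx()  # embedding model; exposes .dimension and .to_embeddings
data_manager = get_data_manager(
    CacheBase("sqlite"),
    VectorBase("faiss", dimension=onnx.dimension),
)
cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    config=Config(similarity_threshold=0.8),
)
cache.set_openai_key()  # reads OPENAI_API_KEY from the environment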
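

# Usage sketch (application code): warm the cache with known question/answer pairs
# via ``import_data`` and persist it with ``flush``. Assumes ``cache.init()`` has
# already been called, so ``embedding_func`` can embed each question; the strings
# below are illustrative.
from gptcache import cache

cache.import_data(
    questions=["what is github", "can you explain what GitHub is"],
    answers=[
        "GitHub is an online SaaS service for software development and version control.",
        "GitHub is an online SaaS service for software development and version control.",
    ],
)
cache.flush()  # writes in-memory state (map cache, faiss/hnswlib index) to disk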
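

# Usage sketch (application code): chain two caches with ``next_cache`` so that a
# miss in the primary cache falls through to a secondary one. The ``data_path``
# values are illustrative, and the map-based data manager is assumed here for brevity.
from gptcache import Cache, cache
from gptcache.manager import get_data_manager

secondary = Cache()
secondary.init(data_manager=get_data_manager(data_path="secondary_data_map.txt"))
cache.init(
    data_manager=get_data_manager(data_path="primary_data_map.txt"),
    next_cache=secondary,
)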
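

# Usage sketch (application code): point the OpenAI SDK at an Azure OpenAI endpoint
# through ``set_azure_openai_key``. The environment variable values are placeholders;
# use the key, endpoint, and API version of your own Azure resource.
import os

from gptcache import cache

os.environ.setdefault("OPENAI_API_KEY", "<azure-api-key>")
os.environ.setdefault("OPENAI_API_BASE", "https://<resource-name>.openai.azure.com/")
os.environ.setdefault("OPENAI_API_VERSION", "<api-version>")

cache.init()
cache.set_azure_openai_key()  # sets openai.api_type/key/base/version from the env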