# Source code for gptcache.adapter.api

# pylint: disable=wrong-import-position
from typing import Any, Optional, Callable

import gptcache.processor.post
import gptcache.processor.pre
from gptcache import Cache, cache, Config
from gptcache.adapter.adapter import adapt
from gptcache.embedding import (
    Onnx,
    Huggingface,
    SBERT,
    FastText,
    Data2VecAudio,
    Timm,
    ViT,
    OpenAI,
    Cohere,
    Rwkv,
    PaddleNLP,
    UForm,
)
from gptcache.embedding.base import BaseEmbedding
from gptcache.manager import manager_factory
from gptcache.manager.data_manager import DataManager
from gptcache.processor.context import (
    SummarizationContextProcess,
    SelectiveContextProcess,
    ConcatContextProcess,
)
from gptcache.processor.post import temperature_softmax
from gptcache.processor.pre import get_prompt
from gptcache.similarity_evaluation import (
    SearchDistanceEvaluation,
    NumpyNormEvaluation,
    OnnxModelEvaluation,
    ExactMatchEvaluation,
    KReciprocalEvaluation,
    SimilarityEvaluation,
    CohereRerankEvaluation,
    SequenceMatchEvaluation,
    TimeEvaluation,
    SbertCrossencoderEvaluation
)
from gptcache.utils import import_ruamel


def _cache_data_converter(cache_data):
    """Identity converter for cache hits.

    :param cache_data: the value read from the cache
    :return: `cache_data`, exactly as it was received
    """
    return cache_data


def _update_cache_callback_none(
    llm_data, update_cache_func, *args, **kwargs  # pylint: disable=W0613
) -> None:
    """No-op cache-update callback.

    Every argument is accepted and ignored; in particular `update_cache_func`
    is never called, so nothing is written. Used by the read-only query path.

    :return: always None
    """


def _llm_handle_none(*llm_args, **llm_kwargs) -> None:  # pylint: disable=W0613
    """Placeholder LLM handler: ignore every argument and produce no data.

    :return: always None
    """


def _update_cache_callback(
    llm_data, update_cache_func, *args, **kwargs
):  # pylint: disable=W0613
    """Hand `llm_data` to `update_cache_func` so it gets saved.

    Extra positional and keyword arguments are accepted but unused, and the
    return value of `update_cache_func` is discarded.

    :return: always None
    """
    update_cache_func(llm_data)
    return None


def put(prompt: str, data: Any, **kwargs) -> None:
    """put api, store a question/answer pair in GPTCache

    Please make sure that the `pre_embedding_func` param is `get_prompt` when initializing the cache

    :param prompt: the cache data key, usually question text
    :type prompt: str
    :param data: the cache data value, usually answer text
    :type data: Any
    :param kwargs: user-defined parameters, forwarded to `adapt`
    :type kwargs: Dict

    Example:
        .. code-block:: python

            from gptcache.adapter.api import put
            from gptcache.processor.pre import get_prompt

            cache.init(pre_embedding_func=get_prompt)
            put("hello", "foo")
    """

    def _supply_data(*_llm_args, **_llm_kwargs):  # pylint: disable=W0613
        # Stands in for the LLM call: the "answer" is the caller-provided data.
        return data

    adapt(
        _supply_data,
        _cache_data_converter,
        _update_cache_callback,
        cache_skip=True,
        prompt=prompt,
        **kwargs,
    )


def get(prompt: str, **kwargs) -> Any:
    """get api, look up the cache data that matches `prompt`

    Please make sure that the `pre_embedding_func` param is `get_prompt` when initializing the cache

    :param prompt: the cache data key, usually question text
    :type prompt: str
    :param kwargs: user-defined parameters, forwarded to `adapt`
    :type kwargs: Dict
    :return: whatever `adapt` returns for this prompt

    Example:
        .. code-block:: python

            from gptcache.adapter.api import put, get
            from gptcache.processor.pre import get_prompt

            cache.init(pre_embedding_func=get_prompt)
            put("hello", "foo")
            print(get("hello"))
    """
    # Both the LLM handler and the update callback are no-ops here.
    return adapt(
        _llm_handle_none,
        _cache_data_converter,
        _update_cache_callback_none,
        prompt=prompt,
        **kwargs,
    )


def init_similar_cache(
    data_dir: str = "api_cache",
    cache_obj: Optional[Cache] = None,
    pre_func: Callable = get_prompt,
    embedding: Optional[BaseEmbedding] = None,
    data_manager: Optional[DataManager] = None,
    evaluation: Optional[SimilarityEvaluation] = None,
    post_func: Callable = temperature_softmax,
    config: Optional[Config] = None,
):
    """Provide a quick way to initialize cache for api service

    :param data_dir: cache data storage directory, only used when `data_manager` is not given
    :type data_dir: str
    :param cache_obj: specify to initialize the Cache object, if not specified, initialize the global object
    :type cache_obj: Optional[Cache]
    :param pre_func: pre-processing of the cache input text
    :type pre_func: Callable
    :param embedding: embedding object, defaults to a new `Onnx()`
    :type embedding: BaseEmbedding
    :param data_manager: data manager object, defaults to a "sqlite,faiss" manager
        created in `data_dir` with the embedding's dimension
    :type data_manager: DataManager
    :param evaluation: similarity evaluation object, defaults to a new `SearchDistanceEvaluation()`
    :type evaluation: SimilarityEvaluation
    :param post_func: post-processing of the cached result list, defaults to `temperature_softmax`
    :type post_func: Callable[[List[Any]], Any]
    :param config: cache configuration, the core is similar threshold;
        a fresh `Config()` is created for each call when omitted
    :type config: Optional[Config]
    :return: None

    Example:
        .. code-block:: python

            from gptcache.adapter.api import put, get, init_similar_cache

            init_similar_cache()
            put("hello", "foo")
            print(get("hello"))
    """
    # A `Config()` default in the signature would be built once at definition
    # time and shared by every call; build it per call instead.
    if config is None:
        config = Config()
    if embedding is None:
        embedding = Onnx()
    if data_manager is None:
        data_manager = manager_factory(
            "sqlite,faiss",
            data_dir=data_dir,
            vector_params={"dimension": embedding.dimension},
        )
    if evaluation is None:
        evaluation = SearchDistanceEvaluation()
    # Fall back to the module-level global cache object.
    cache_obj = cache if cache_obj is None else cache_obj
    cache_obj.init(
        pre_embedding_func=pre_func,
        embedding_func=embedding.to_embeddings,
        data_manager=data_manager,
        similarity_evaluation=evaluation,
        post_process_messages_func=post_func,
        config=config,
    )


def init_similar_cache_from_config(config_dir: str, cache_obj: Optional[Cache] = None):
    """Initialize a cache from a YAML configuration file

    Keys read from the configuration (all optional):

    - `embedding` (legacy name `model_source`): embedding model name, default "onnx"
    - `embedding_config` (legacy name `model_config`): kwargs for the embedding model
    - `storage_config`: kwargs for `manager_factory`; `manager` defaults to
      "sqlite,faiss", `data_dir` to "gptcache_data", and
      `vector_params["dimension"]` is always set from the embedding model
    - `evaluation`: evaluation strategy name, default "distance"
    - `evaluation_config` (legacy name `evaluation_kws`): kwargs for the evaluation
    - `pre_context_function` / `pre_context_config`: context processor name and
      kwargs; when given, its `pre_process` method is the pre-embedding function
    - `pre_function`: name of a function in `gptcache.processor.pre`, default
      "get_prompt"; only used when `pre_context_function` is not set
    - `post_function`: name of a function in `gptcache.processor.post`, default "first"
    - `config`: kwargs for `Config`

    :param config_dir: path of the YAML file; if empty, only defaults are used
    :type config_dir: str
    :param cache_obj: specify to initialize the Cache object, if not specified, initialize the global object
    :type cache_obj: Optional[Cache]
    :return: the loaded configuration mapping (an empty dict when there is nothing
        to load); a `storage_config` mapping present in the file is updated in
        place with the defaults described above
    :raises ValueError: if the embedding model name or the pre-context function
        name is not recognized
    """
    import_ruamel()
    from ruamel.yaml import YAML  # pylint: disable=C0415

    if config_dir:
        with open(config_dir, "r", encoding="utf-8") as f:
            # NOTE(security): the "unsafe" loader can construct arbitrary Python
            # objects from tagged YAML nodes. Only load configuration files from
            # trusted sources; kept as-is because existing configs may rely on it.
            yaml = YAML(typ="unsafe", pure=True)
            # An empty YAML document loads as None; treat it as "no settings".
            init_conf = yaml.load(f) or {}
    else:
        init_conf = {}

    # Due to the problem with the first naming, it is reserved to ensure compatibility
    embedding = init_conf.get("model_source", "")
    if not embedding:
        embedding = init_conf.get("embedding", "onnx")
    # ditto
    embedding_config = init_conf.get("model_config", {})
    if not embedding_config:
        embedding_config = init_conf.get("embedding_config", {})
    embedding_model = _get_model(embedding, embedding_config)
    if embedding_model is None:
        # Fail with a clear message instead of an AttributeError on None below.
        raise ValueError(f"Unsupported embedding model: {embedding!r}")

    # `storage_config:` with no value loads as None; fall back to an empty mapping.
    storage_config = init_conf.get("storage_config") or {}
    storage_config.setdefault("manager", "sqlite,faiss")
    storage_config.setdefault("data_dir", "gptcache_data")
    storage_config.setdefault("vector_params", {})
    storage_config["vector_params"] = storage_config["vector_params"] or {}
    storage_config["vector_params"]["dimension"] = embedding_model.dimension
    data_manager = manager_factory(**storage_config)

    eval_strategy = init_conf.get("evaluation", "distance")
    # Due to the problem with the first naming, it is reserved to ensure compatibility
    eval_config = init_conf.get("evaluation_kws", {})
    if not eval_config:
        eval_config = init_conf.get("evaluation_config", {})
    evaluation = _get_eval(eval_strategy, eval_config)

    cache_obj = cache_obj if cache_obj else cache

    pre_process = init_conf.get("pre_context_function")
    if pre_process:
        pre_context = _get_pre_context_function(
            pre_process, init_conf.get("pre_context_config")
        )
        if pre_context is None:
            raise ValueError(f"Unsupported pre context function: {pre_process!r}")
        pre_func = pre_context.pre_process
    else:
        pre_process = init_conf.get("pre_function", "get_prompt")
        pre_func = _get_pre_func(pre_process)

    post_process = init_conf.get("post_function", "first")
    post_func = _get_post_func(post_process)

    config_kws = init_conf.get("config", {}) or {}
    config = Config(**config_kws)

    cache_obj.init(
        pre_embedding_func=pre_func,
        embedding_func=embedding_model.to_embeddings,
        data_manager=data_manager,
        similarity_evaluation=evaluation,
        post_process_messages_func=post_func,
        config=config,
    )

    return init_conf


def _get_model(model_src, model_config=None):
    """Create the embedding model named by `model_src`.

    :param model_src: model name, compared case-insensitively; one of "onnx",
        "huggingface", "sbert", "fasttext", "data2vecaudio", "timm", "vit",
        "openai", "cohere", "rwkv", "paddlenlp", "uform"
    :param model_config: keyword arguments for the model constructor
    :return: the embedding object, or None if `model_src` is not recognized
    """
    model_src = model_src.lower()
    model_config = model_config or {}

    if model_src == "onnx":
        return Onnx(**model_config)
    if model_src == "huggingface":
        return Huggingface(**model_config)
    if model_src == "sbert":
        return SBERT(**model_config)
    if model_src == "fasttext":
        return FastText(**model_config)
    if model_src == "data2vecaudio":
        return Data2VecAudio(**model_config)
    if model_src == "timm":
        return Timm(**model_config)
    if model_src == "vit":
        return ViT(**model_config)
    if model_src == "openai":
        return OpenAI(**model_config)
    if model_src == "cohere":
        return Cohere(**model_config)
    if model_src == "rwkv":
        return Rwkv(**model_config)
    if model_src == "paddlenlp":
        return PaddleNLP(**model_config)
    if model_src == "uform":
        return UForm(**model_config)
    return None


def _get_eval(strategy, kws=None):
    """Create the similarity evaluation selected by `strategy`.

    Matching is by substring on the lower-cased strategy and the checks run in
    the fixed order below, so the first keyword contained in `strategy` wins.

    :param strategy: strategy name containing one of "distance", "np", "exact",
        "onnx", "kreciprocal", "cohere", "sequence_match", "time",
        "sbert_crossencoder"
    :param kws: keyword arguments for the evaluation constructor (not passed
        to the "exact" evaluation)
    :return: the evaluation object, or None if no keyword matches
    """
    strategy = strategy.lower()
    kws = kws or {}

    if "distance" in strategy:
        return SearchDistanceEvaluation(**kws)
    if "np" in strategy:
        return NumpyNormEvaluation(**kws)
    if "exact" in strategy:
        return ExactMatchEvaluation()
    if "onnx" in strategy:
        return OnnxModelEvaluation(**kws)
    if "kreciprocal" in strategy:
        return KReciprocalEvaluation(**kws)
    if "cohere" in strategy:
        return CohereRerankEvaluation(**kws)
    if "sequence_match" in strategy:
        return SequenceMatchEvaluation(**kws)
    if "time" in strategy:
        return TimeEvaluation(**kws)
    if "sbert_crossencoder" in strategy:
        return SbertCrossencoderEvaluation(**kws)
    return None


def _get_pre_func(pre_process):
    """Look up the pre-processing function `pre_process` in `gptcache.processor.pre`.

    :raises AttributeError: if the module has no attribute with that name
    """
    return getattr(gptcache.processor.pre, pre_process)


def _get_pre_context_function(pre_context_process, kws=None):
    """Create the context processor selected by `pre_context_process`.

    A known name ("summarization", "selective", "concat") matches when it is
    contained in the lower-cased input, consistent with `_get_eval`. The
    original test was written the other way round (input contained in the known
    name); that direction is still accepted so inputs that used to resolve keep
    resolving to the same processor.

    :param pre_context_process: context processor name
    :param kws: keyword arguments for the processor constructor (not passed to
        the "concat" processor)
    :return: the context processor, or None if the name is not recognized
    """
    name = pre_context_process.lower()
    kws = kws or {}

    if "summarization" in name or name in "summarization":
        return SummarizationContextProcess(**kws)
    if "selective" in name or name in "selective":
        return SelectiveContextProcess(**kws)
    if "concat" in name or name in "concat":
        return ConcatContextProcess()
    return None


def _get_post_func(post_process):
    """Look up the post-processing function `post_process` in `gptcache.processor.post`.

    :raises AttributeError: if the module has no attribute with that name
    """
    return getattr(gptcache.processor.post, post_process)