Source code for gptcache.manager.data_manager

import pickle
from abc import abstractmethod, ABCMeta
from typing import List, Any, Optional, Union

import cachetools
import numpy as np
import requests

from gptcache.manager.eviction import EvictionBase
from gptcache.manager.eviction_manager import EvictionManager
from gptcache.manager.object_data.base import ObjectBase
from gptcache.manager.scalar_data.base import (
    CacheStorage,
    CacheData,
    DataType,
    Answer,
    Question,
)
from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils.error import CacheError, ParamError
from gptcache.utils.log import gptcache_log


class DataManager(metaclass=ABCMeta):
    """DataManager manages the cache data, including saving and searching."""

    @abstractmethod
    def save(self, question, answer, embedding_data, **kwargs):
        pass

    @abstractmethod
    def import_data(
        self,
        questions: List[Any],
        answers: List[Any],
        embedding_datas: List[Any],
        session_ids: List[Optional[str]],
    ):
        pass

    @abstractmethod
    def get_scalar_data(self, res_data, **kwargs) -> Optional[CacheData]:
        pass

    def hit_cache_callback(self, res_data, **kwargs):
        pass

    @abstractmethod
    def search(self, embedding_data, **kwargs):
        """Search the cache store according to the embedding data.

        :return: a list of search results, [[score, id], [score, id], ...]
        """
        pass

    def flush(self):
        pass

    @abstractmethod
    def add_session(self, res_data, session_id, pre_embedding_data):
        pass

    @abstractmethod
    def list_sessions(self, session_id, key):
        pass

    @abstractmethod
    def delete_session(self, session_id):
        pass

    def report_cache(
        self,
        user_question,
        cache_question,
        cache_question_id,
        cache_answer,
        similarity_value,
        cache_delta_time,
    ):
        pass

    @abstractmethod
    def close(self):
        pass

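# Typical lookup flow against any DataManager implementation -- an
# illustrative sketch, not part of the library; ``dm`` stands for a concrete
# manager and ``vec`` for a query embedding:
#
#     results = dm.search(vec)                   # [[score, id], ...]
#     if results:
#         cache_data = dm.get_scalar_data(results[0])
#         if cache_data is not None:
#             dm.hit_cache_callback(results[0])  # lets eviction record the hit
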
class MapDataManager(DataManager):
    """MapDataManager stores all data in a map data structure.

    :param data_path: the path to save the map data, defaults to 'data_map.txt'.
    :type data_path: str
    :param max_size: the max size for the cache, defaults to 1000.
    :type max_size: int
    :param get_data_container: a Callable to get the data container, defaults to None.
    :type get_data_container: Callable

    Example:
        .. code-block:: python

            from gptcache.manager import get_data_manager

            data_manager = get_data_manager("data_map.txt", 1000)
    """

    def __init__(self, data_path, max_size, get_data_container=None):
        if get_data_container is None:
            self.data = cachetools.LRUCache(max_size)
        else:
            self.data = get_data_container(max_size)
        self.data_path = data_path
        self.init()

    def init(self):
        try:
            with open(self.data_path, "rb") as f:
                self.data = pickle.load(f)
        except FileNotFoundError:
            return
        except PermissionError:
            raise CacheError(  # pylint: disable=W0707
                f"You don't have permission to access this file <{self.data_path}>."
            )

    def save(self, question, answer, embedding_data, **kwargs):
        if isinstance(question, Question):
            question = question.content
        session = kwargs.get("session", None)
        session_id = {session.name} if session else set()
        self.data[embedding_data] = (question, answer, embedding_data, session_id)

    def import_data(
        self,
        questions: List[Any],
        answers: List[Any],
        embedding_datas: List[Any],
        session_ids: List[Optional[str]],
    ):
        if (
            len(questions) != len(answers)
            or len(questions) != len(embedding_datas)
            or len(questions) != len(session_ids)
        ):
            raise ParamError("Make sure that all parameters have the same length")
        for i, embedding_data in enumerate(embedding_datas):
            self.data[embedding_data] = (
                questions[i],
                answers[i],
                embedding_datas[i],
                {session_ids[i]} if session_ids[i] else set(),
            )

    def get_scalar_data(self, res_data, **kwargs) -> Optional[CacheData]:
        session = kwargs.get("session", None)
        if session:
            answer = (
                res_data[1].answer if isinstance(res_data[1], Answer) else res_data[1]
            )
            if not session.check_hit_func(
                session.name, list(res_data[3]), [res_data[0]], answer
            ):
                return None
        return CacheData(question=res_data[0], answers=res_data[1])

    def search(self, embedding_data, **kwargs):
        try:
            return [self.data[embedding_data]]
        except KeyError:
            return []

    def flush(self):
        try:
            with open(self.data_path, "wb") as f:
                pickle.dump(self.data, f)
        except PermissionError:
            gptcache_log.error(
                "You don't have permission to access this file %s.", self.data_path
            )

    def add_session(self, res_data, session_id, pre_embedding_data):
        res_data[3].add(session_id)

    def list_sessions(self, session_id=None, key=None):
        session_ids = set()
        for k in self.data:
            if session_id:
                # collect only the keys whose entries belong to this session
                if session_id in self.data[k][3]:
                    session_ids.add(k)
            elif len(self.data[k][3]) > 0:
                session_ids.update(self.data[k][3])
        return list(session_ids)

    def delete_session(self, session_id):
        keys = self.list_sessions(session_id=session_id)
        for k in keys:
            self.data[k][3].remove(session_id)
            if len(self.data[k][3]) == 0:
                del self.data[k]

    def close(self):
        self.flush()

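# Round-trip sketch for MapDataManager (hedged). Note that ``embedding_data``
# is used directly as the map key, so it must be hashable; with the default
# ``get_data_manager("data_map.txt")`` setup the question string itself
# typically plays that role:
#
#     dm = MapDataManager("data_map.txt", 1000)
#     dm.save("hello", "hi", "hello")   # key == question string
#     dm.search("hello")                # [("hello", "hi", "hello", set())]
#     dm.close()                        # pickles the map to data_map.txt
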
def normalize(vec):
    # scale the vector to unit length; note that a zero vector
    # would cause a division by zero here
    magnitude = np.linalg.norm(vec)
    normalized_v = vec / magnitude
    return normalized_v

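# Why normalize: on unit-length vectors the inner product equals the cosine
# similarity, so inner-product vector indexes rank results by angle.
# Quick check:
#
#     v = normalize(np.array([3.0, 4.0]))   # -> array([0.6, 0.8])
#     np.linalg.norm(v)                     # -> 1.0
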
class SSDataManager(DataManager):
    """SSDataManager manages the scalar and vector data.

    :param s: CacheStorage to manage the scalar data, it can be generated with :meth:`gptcache.manager.CacheBase`.
    :type s: CacheStorage
    :param v: VectorBase to manage the vector data, it can be generated with :meth:`gptcache.manager.VectorBase`.
    :type v: VectorBase
    :param o: ObjectBase to manage object data such as images, it can be generated with :meth:`gptcache.manager.ObjectBase`.
    :type o: Optional[ObjectBase]
    :param max_size: the max size for the cache, defaults to 1000.
    :type max_size: int
    :param clean_size: the size to clean up, defaults to `max_size * 0.2`.
    :type clean_size: int
    :param policy: the eviction policy, "LRU" and "FIFO" are supported now, defaults to "LRU".
    :type policy: str
    """

    def __init__(
        self,
        s: CacheStorage,
        v: VectorBase,
        o: Optional[ObjectBase],
        max_size,
        clean_size,
        policy="LRU",
    ):
        self.max_size = max_size
        self.clean_size = clean_size
        self.s = s
        self.v = v
        self.o = o
        self.eviction_manager = EvictionManager(self.s, self.v)
        self.eviction_base = EvictionBase(
            name="memory",
            policy=policy,
            maxsize=max_size,
            clean_size=clean_size,
            on_evict=self._clear,
        )
        self.eviction_base.put(self.s.get_ids(deleted=False))

    def _clear(self, marked_keys):
        self.eviction_manager.soft_evict(marked_keys)
        if self.eviction_manager.check_evict():
            self.eviction_manager.delete()

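    # Construction sketch (hedged): the ``get_data_manager`` factory wires up
    # an SSDataManager when given cache and vector bases, e.g.:
    #
    #     from gptcache.manager import get_data_manager, CacheBase, VectorBase
    #
    #     dm = get_data_manager(
    #         CacheBase("sqlite"),
    #         VectorBase("faiss", dimension=128),
    #         max_size=1000,
    #     )
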
    def save(self, question, answer, embedding_data, **kwargs):
        """Save the data and vectors to cache and vector storage.

        :param question: question data.
        :type question: str
        :param answer: answer data.
        :type answer: str, Answer or (Any, DataType)
        :param embedding_data: vector data.
        :type embedding_data: np.ndarray

        Example:
            .. code-block:: python

                import numpy as np
                from gptcache.manager import get_data_manager, CacheBase, VectorBase

                data_manager = get_data_manager(CacheBase('sqlite'), VectorBase('faiss', dimension=128))
                data_manager.save('hello', 'hi', np.random.random((128, )).astype('float32'))
        """
        session = kwargs.get("session", None)
        session_id = session.name if session else None
        self.import_data([question], [answer], [embedding_data], [session_id])

    def _process_answer_data(self, answers: Union[Answer, List[Answer]]):
        if isinstance(answers, Answer):
            answers = [answers]
        new_ans = []
        for ans in answers:
            if ans.answer_type != DataType.STR:
                # non-string payloads go to the object store; only the
                # object reference is kept in the answer
                new_ans.append(Answer(self.o.put(ans.answer), ans.answer_type))
            else:
                new_ans.append(ans)
        return new_ans

    def _process_question_data(self, question: Union[str, Question]):
        if isinstance(question, Question):
            if question.deps is None:
                return question
            for dep in question.deps:
                if dep.dep_type == DataType.IMAGE_URL:
                    # download the referenced image and store it in the
                    # object store, replacing the URL with the reference
                    dep.data = self.o.put(requests.get(dep.data).content)
            return question
        return Question(question)

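    # Object-data sketch (hedged; assumes an object store was configured,
    # e.g. via ``gptcache.manager.ObjectBase``): non-string answers are
    # written to the object store and only a reference is kept in the
    # scalar storage:
    #
    #     img_answer = Answer(image_bytes, DataType.IMAGE_BASE64)
    #     data_manager.save("draw a cat", img_answer, vec)
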
    def import_data(
        self,
        questions: List[Any],
        answers: List[Answer],
        embedding_datas: List[Any],
        session_ids: List[Optional[str]],
    ):
        if (
            len(questions) != len(answers)
            or len(questions) != len(embedding_datas)
            or len(questions) != len(session_ids)
        ):
            raise ParamError("Make sure that all parameters have the same length")
        cache_datas = []
        embedding_datas = [
            normalize(embedding_data) for embedding_data in embedding_datas
        ]
        for i, embedding_data in enumerate(embedding_datas):
            if self.o is not None and not isinstance(answers[i], str):
                ans = self._process_answer_data(answers[i])
            else:
                ans = answers[i]
            cache_datas.append(
                CacheData(
                    question=self._process_question_data(questions[i]),
                    answers=ans,
                    embedding_data=embedding_data.astype("float32"),
                    session_id=session_ids[i],
                )
            )
        ids = self.s.batch_insert(cache_datas)
        self.v.mul_add(
            [
                VectorData(id=ids[i], data=embedding_data)
                for i, embedding_data in enumerate(embedding_datas)
            ]
        )
        self.eviction_base.put(ids)

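    # Batch-import sketch (hedged; reuses the ``data_manager`` from the
    # ``save`` docstring example above):
    #
    #     vecs = [np.random.random((128,)).astype("float32") for _ in range(2)]
    #     data_manager.import_data(
    #         questions=["hello", "how are you"],
    #         answers=["hi", "fine"],
    #         embedding_datas=vecs,
    #         session_ids=[None, None],
    #     )
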
    def get_scalar_data(self, res_data, **kwargs) -> Optional[CacheData]:
        session = kwargs.get("session", None)
        cache_data = self.s.get_data_by_id(res_data[1])
        if cache_data is None:
            return None
        if session:
            cache_answer = (
                cache_data.answers[0].answer
                if isinstance(cache_data.answers[0], Answer)
                else cache_data.answers[0]
            )
            res_list = self.list_sessions(key=res_data[1])
            cache_session_ids = [r.session_id for r in res_list]
            cache_questions = [r.session_question for r in res_list]
            if not session.check_hit_func(
                session.name, cache_session_ids, cache_questions, cache_answer
            ):
                return None
        for ans in cache_data.answers:
            if ans.answer_type != DataType.STR:
                # resolve object-store references back into the raw payload
                ans.answer = self.o.get(ans.answer)
        return cache_data

    def hit_cache_callback(self, res_data, **kwargs):
        # refresh the entry in the eviction policy on every cache hit
        self.eviction_base.get(res_data[1])

    def search(self, embedding_data, **kwargs):
        embedding_data = normalize(embedding_data)
        top_k = kwargs.get("top_k", -1)
        return self.v.search(data=embedding_data, top_k=top_k)

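    # Search sketch (hedged): the query is normalized before lookup, so
    # inner-product scores behave like cosine similarity; each hit is
    # typically a ``(score, id)`` pair that feeds ``get_scalar_data``:
    #
    #     for hit in data_manager.search(vec, top_k=3):
    #         cache_data = data_manager.get_scalar_data(hit)
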
    def flush(self):
        self.s.flush()
        self.v.flush()

    def add_session(self, res_data, session_id, pre_embedding_data):
        self.s.add_session(res_data[1], session_id, pre_embedding_data)

    def list_sessions(self, session_id=None, key=None):
        res = self.s.list_sessions(session_id, key)
        if key:
            return res
        if session_id:
            return [r.id for r in res]
        return list({r.session_id for r in res})

    def delete_session(self, session_id):
        keys = self.list_sessions(session_id=session_id)
        self.s.delete_session(keys)

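    # Session sketch (hedged; assumes ``gptcache.session.Session``):
    #
    #     from gptcache.session import Session
    #
    #     session = Session(name="user-42")
    #     data_manager.save("hello", "hi", vec, session=session)
    #     data_manager.list_sessions()            # distinct session ids
    #     data_manager.delete_session("user-42")
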
    def report_cache(
        self,
        user_question,
        cache_question,
        cache_question_id,
        cache_answer,
        similarity_value,
        cache_delta_time,
    ):
        self.s.report_cache(
            user_question,
            cache_question,
            cache_question_id,
            cache_answer,
            similarity_value,
            cache_delta_time,
        )

    def close(self):
        self.s.close()
        self.v.close()
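
if __name__ == "__main__":
    # Minimal end-to-end demo, a hedged sketch rather than part of the
    # library: it assumes the sqlite and faiss backends are installed and
    # it creates local sqlite/faiss files in the working directory.
    from gptcache.manager import CacheBase, VectorBase, get_data_manager

    dm = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=8))
    vec = np.random.random((8,)).astype("float32")
    dm.save("hello", "hi", vec)
    hits = dm.search(vec, top_k=1)  # typically [(score, id)]
    if hits:
        print(dm.get_scalar_data(hits[0]))
    dm.close()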