import datetime
from typing import List, Optional

import numpy as np

from gptcache.manager.scalar_data.base import (
    CacheStorage,
    CacheData,
    Question,
    QuestionDep,
)
from gptcache.utils import import_redis


# pylint: disable=C0413
from redis import Redis
from redis.client import Pipeline
from redis_om import get_redis_connection
from redis_om import JsonModel, EmbeddedJsonModel, NotFoundError, Field, Migrator

def get_models(global_key: str, redis_connection: Redis):
    """
    Get all the models for the given global key and redis connection.

    The model classes are created inside this function so that each of them
    is bound to ``global_key`` (used as key prefix) and ``redis_connection``.

    :param global_key: Global key will be used as a prefix for all the keys
    :type global_key: str

    :param redis_connection: Redis connection to use for all the models.
        Note: This needs to be explicitly mentioned in `Meta` class for each Object Model,
        otherwise it will use the default connection from the pool.
    :type redis_connection: Redis

    :return: the tuple ``(Questions, Embedding, Answers, QuestionDeps, Sessions, Counter, Report)``,
        in exactly that order
    """

    class Counter:
        """
        counter collection
        """

        key_name = global_key + ":counter"
        database = redis_connection

        @classmethod
        def incr(cls):
            """Increment the counter stored under ``key_name`` and return the reply of ``INCR``."""
            # Returning the reply is backward compatible (the method used to
            # return None) and lets callers obtain the new value in one step.
            return cls.database.incr(cls.key_name)

        @classmethod
        def get(cls):
            """Return the current counter value as returned by the connection."""
            return cls.database.get(cls.key_name)

    class Embedding:
        """
        Custom class for storing embedding result.
        An embedding of type ``bytes`` is stored against Hash record type for the provided key.

        :param pk: Primary key against which hash data for embedding would be stored
        :type pk: str

        :param embedding: Embedding information to store
        :type embedding: bytes

        Note: As of this implementation, redis-om doesn't have a good compatibility to store bytes data
        and successfully retrieve it without corruption.
        In addition to that, decoding while getting the response is disabled as well.
        """

        prefix = global_key + ":embedding"

        def __init__(self, pk: str, embedding: bytes):
            self.pk = pk
            self.embedding = embedding

        def save(self, pipeline: Pipeline):
            """Queue an ``HSET`` of the embedding under ``<prefix>:<pk>`` on ``pipeline``."""
            pipeline.hset(self.prefix + ":" + str(self.pk), "embedding", self.embedding)

        @classmethod
        def get(cls, key: int, db: Redis):
            """
            Returns embedding stored against the ``key``.
            Decode only key value while creating a response

            :param key: redis key to fetch embedding
            :type key: str

            :param db: connection used for the lookup; the hash field names are
                decoded with ``.decode("utf-8")``, so it must return ``bytes``
            :type db: Redis
            """
            result = db.hgetall(cls.prefix + ":" + str(key))
            # Only the field names are decoded; the values stay raw bytes.
            return {k.decode("utf-8"): v for k, v in result.items()}

    class Answers(EmbeddedJsonModel):
        """
        answer collection
        """

        answer: str
        answer_type: int

        class Meta:
            database = redis_connection

    class Questions(JsonModel):
        """
        questions collection
        """

        question: str = Field(index=True)
        create_on: datetime.datetime
        last_access: datetime.datetime
        deleted: int = Field(index=True)
        answers: List[Answers]

        class Meta:
            global_key_prefix = global_key
            model_key_prefix = "questions"
            database = redis_connection

    class Sessions(JsonModel):
        """
        session collection
        """

        class Meta:
            global_key_prefix = global_key
            model_key_prefix = "sessions"
            database = redis_connection

        session_id: str = Field(index=True)
        session_question: str
        question_id: str = Field(index=True)

    class QuestionDeps(JsonModel):
        """
        Question Dep collection
        """

        class Meta:
            global_key_prefix = global_key
            model_key_prefix = "ques_deps"
            database = redis_connection

        question_id: str = Field(index=True)
        dep_name: str
        dep_data: str
        dep_type: int

    class Report(JsonModel):
        """
        Report collection
        """

        class Meta:
            global_key_prefix = global_key
            model_key_prefix = "report"
            database = redis_connection

        user_question: str
        cache_question_id: int = Field(index=True)
        cache_question: str
        cache_answer: str
        similarity: float = Field(index=True)
        cache_delta_time: float = Field(index=True)
        cache_time: datetime.datetime = Field(index=True)
        extra: Optional[str]

    return Questions, Embedding, Answers, QuestionDeps, Sessions, Counter, Report
class RedisCacheStorage(CacheStorage):
    """
    Using redis-om as OM to store data in redis cache storage

    :param host: redis host, default value 'localhost'
    :type host: str
    :param port: redis port, default value 6379
    :type port: int
    :param global_key_prefix: A global prefix for keys against which data is stored.
        For example, for a global_key_prefix ='gptcache', keys would be constructed would look like this:
        gptcache:questions:abc123
    :type global_key_prefix: str
    :param kwargs: Additional parameters to provide in order to create redis om connection

    Example:
        .. code-block:: python

            from gptcache.manager import CacheBase, manager_factory

            cache_store = CacheBase('redis',
                redis_host="localhost",
                redis_port=6379,
                global_key_prefix="gptcache",
            )
            # or
            data_manager = manager_factory("redis,faiss", data_dir="./workspace",
                scalar_params={
                    "redis_host": "localhost",
                    "redis_port": 6379,
                    "global_key_prefix": "gptcache",
                },
                vector_params={"dimension": 128},
            )
    """

    def __init__(
        self,
        global_key_prefix="gptcache",
        host: str = "localhost",
        port: int = 6379,
        **kwargs
    ):
        # Connection used by all redis-om models.
        self.con = get_redis_connection(host=host, port=port, **kwargs)
        # Second connection with response decoding disabled; it is used to
        # read the raw embedding bytes back without decoding them.
        self.con_encoded = get_redis_connection(
            host=host, port=port, decode_responses=False, **kwargs
        )

        (
            self._ques,
            self._embedding,
            self._answer,
            self._ques_dep,
            self._session,
            self._counter,
            self._report,
        ) = get_models(global_key_prefix, redis_connection=self.con)

        Migrator().run()

    def create(self):
        """Nothing to create up front; kept to satisfy the storage interface."""
        pass

    def _insert(self, data: CacheData, pipeline: Pipeline = None):
        """
        Store one cache entry (question, answers, embedding, deps, session).

        :param data: the entry to store
        :type data: CacheData
        :param pipeline: pipeline on which the writes are queued; when ``None``
            the models are saved directly and the embedding is written through ``self.con``
        :type pipeline: Pipeline

        :return: the primary key allocated for the question
        :rtype: int
        """
        # Allocate the next primary key from the shared counter.
        self._counter.incr()
        pk = str(self._counter.get())

        answers = data.answers if isinstance(data.answers, list) else [data.answers]
        answer_models = [
            self._answer(answer=answer.answer, answer_type=int(answer.answer_type))
            for answer in answers
        ]
        question_text = (
            data.question if isinstance(data.question, str) else data.question.content
        )
        now = datetime.datetime.utcnow()

        # The question must reference the converted answer models, not the
        # raw answer objects they were built from.
        ques_data = self._ques(
            pk=pk,
            question=question_text,
            create_on=now,
            last_access=now,
            deleted=0,
            answers=answer_models,
        )
        ques_data.save(pipeline)

        # A missing embedding is simply not stored: redis cannot encode None
        # as a hash value, so queuing it would make the whole batch fail.
        if data.embedding_data is not None:
            raw_embedding = data.embedding_data.astype(np.float32).tobytes()
            writer = pipeline if pipeline is not None else self.con
            self._embedding(pk=ques_data.pk, embedding=raw_embedding).save(writer)

        if isinstance(data.question, Question) and data.question.deps is not None:
            # NOTE(review): ``dep.name`` / ``dep.data`` are assumed to be the
            # QuestionDep attribute names, mirroring the positional order
            # QuestionDep(dep_name, dep_data, dep_type) used in get_data_by_id.
            all_deps = [
                self._ques_dep(
                    question_id=ques_data.pk,
                    dep_name=dep.name,
                    dep_data=dep.data,
                    dep_type=dep.dep_type,
                )
                for dep in data.question.deps
            ]
            self._ques_dep.add(all_deps, pipeline=pipeline)

        if data.session_id:
            session_data = self._session(
                question_id=ques_data.pk,
                session_id=data.session_id,
                session_question=question_text,
            )
            session_data.save(pipeline)

        return int(ques_data.pk)

    def batch_insert(self, all_data: List[CacheData]):
        """
        Store several cache entries using a single pipeline.

        :param all_data: entries to store
        :type all_data: List[CacheData]

        :return: the primary keys of the stored questions, in input order
        :rtype: List[int]
        """
        ids = []
        with self.con.pipeline() as pipeline:
            for data in all_data:
                ids.append(self._insert(data, pipeline=pipeline))
            pipeline.execute()
        return ids

    def get_data_by_id(self, key: str):
        """
        Load a cache entry by question id and refresh its ``last_access``.

        :param key: question primary key (converted with ``str``)
        :type key: str

        :return: the stored entry, or ``None`` when no question has this key
        """
        key = str(key)
        try:
            qs = self._ques.get(pk=key)
        except NotFoundError:
            return None

        qs.update(last_access=datetime.datetime.utcnow())
        res_ans = [(item.answer, item.answer_type) for item in qs.answers]

        deps = self._ques_dep.find(self._ques_dep.question_id == key).all()
        res_deps = [
            QuestionDep(item.dep_name, item.dep_data, item.dep_type) for item in deps
        ]

        session_ids = [
            obj.session_id
            for obj in self._session.find(self._session.question_id == key).all()
        ]

        # Read through the non-decoding connection so the bytes come back intact.
        raw_embedding = self._embedding.get(qs.pk, self.con_encoded).get("embedding")
        embedding = (
            np.frombuffer(raw_embedding, dtype=np.float32)
            if raw_embedding is not None
            else None
        )

        return CacheData(
            question=qs.question if not deps else Question(qs.question, res_deps),
            answers=res_ans,
            embedding_data=embedding,
            session_id=session_ids,
            create_on=qs.create_on,
            last_access=qs.last_access,
        )

    def mark_deleted(self, keys):
        """
        Flag the questions with the given ids as deleted (``deleted == -1``).

        :param keys: question primary keys
        """
        # Primary keys are stored as strings; normalise before querying.
        keys = [str(key) for key in keys]
        if not keys:
            # Nothing to match; avoid issuing an IN query with an empty list.
            return
        result = self._ques.find(self._ques.pk << keys).all()
        for qs in result:
            qs.update(deleted=-1)

    def clear_deleted_data(self):
        """Physically remove every question flagged as deleted, with its sessions, deps and embedding."""
        with self.con.pipeline() as pipeline:
            qs_to_delete = self._ques.find(self._ques.deleted == -1).all()
            if not qs_to_delete:
                return
            self._ques.delete_many(qs_to_delete, pipeline)

            q_ids = [qs.pk for qs in qs_to_delete]

            sessions_to_delete = self._session.find(
                self._session.question_id << q_ids
            ).all()
            self._session.delete_many(sessions_to_delete, pipeline)

            deps_to_delete = self._ques_dep.find(
                self._ques_dep.question_id << q_ids
            ).all()
            self._ques_dep.delete_many(deps_to_delete, pipeline)

            # The embedding hashes are keyed by question pk and would be
            # unreachable once the question is gone, so drop them as well.
            for q_id in q_ids:
                pipeline.delete(self._embedding.prefix + ":" + str(q_id))

            pipeline.execute()

    def get_ids(self, deleted=True):
        """
        List question ids by deletion state.

        :param deleted: ``True`` selects questions with ``deleted == -1``,
            ``False`` those with ``deleted == 0``
        :type deleted: bool

        :rtype: List[int]
        """
        state = -1 if deleted else 0
        return [
            int(obj.pk) for obj in self._ques.find(self._ques.deleted == state).all()
        ]

    def count(self, state: int = 0, is_all: bool = False):
        """
        Count stored questions.

        :param state: value of the ``deleted`` field to count, ignored when ``is_all`` is set
        :type state: int
        :param is_all: count every question regardless of its state
        :type is_all: bool
        """
        if is_all:
            return self._ques.find().count()
        return self._ques.find(self._ques.deleted == state).count()

    def add_session(self, question_id, session_id, session_question):
        """Store a session record linking ``session_id`` to ``question_id``."""
        self._session(
            # Stored as str so it matches the str(key) lookups used elsewhere.
            question_id=str(question_id),
            session_id=session_id,
            session_question=session_question,
        ).save()

    def list_sessions(self, session_id=None, key=None):
        """
        List session records, optionally filtered.

        :param session_id: only sessions with this session id
        :param key: only sessions attached to this question id

        :return: the matching session records; all of them when no filter is given
        """
        if key:
            key = str(key)
        if session_id and key:
            # Combine both filters with the expression operator ``&``; a
            # Python ``and`` would keep only the second condition.
            return self._session.find(
                (self._session.session_id == session_id)
                & (self._session.question_id == key)
            ).all()
        if key:
            return self._session.find(self._session.question_id == key).all()
        if session_id:
            return self._session.find(self._session.session_id == session_id).all()
        return self._session.find().all()

    def delete_session(self, keys: List[str]):
        """
        Delete every session record attached to the given question ids.

        :param keys: question primary keys
        :type keys: List[str]
        """
        keys = [str(key) for key in keys]
        if not keys:
            # Nothing to match; avoid issuing an IN query with an empty list.
            return
        with self.con.pipeline() as pipeline:
            sessions_to_delete = self._session.find(
                self._session.question_id << keys
            ).all()
            self._session.delete_many(sessions_to_delete, pipeline)
            pipeline.execute()

    def report_cache(
        self,
        user_question,
        cache_question,
        cache_question_id,
        cache_answer,
        similarity_value,
        cache_delta_time,
    ):
        """Store one report record for a cache hit, timestamped with the current UTC time."""
        self._report(
            user_question=user_question,
            cache_question=cache_question,
            cache_question_id=cache_question_id,
            cache_answer=cache_answer,
            similarity=similarity_value,
            cache_delta_time=cache_delta_time,
            cache_time=datetime.datetime.utcnow(),
        ).save()

    def close(self):
        """Close both redis connections opened by the constructor."""
        self.con.close()
        self.con_encoded.close()