Source code for gptcache.manager.vector_data.chroma
from typing import List
import numpy as np
from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_chromadb, import_torch
import_torch()
import_chromadb()
import chromadb # pylint: disable=C0413
class Chromadb(VectorBase):
    """Vector store backed by a Chromadb collection.

    :param client_settings: the settings for the Chromadb client. When given,
        they are used as-is and ``persist_directory`` is NOT applied.
    :type client_settings: chromadb.config.Settings
    :param persist_directory: the directory to persist to. Only used when
        ``client_settings`` is not given; it then builds settings with the
        ``duckdb+parquet`` implementation rooted at this directory. When
        ``None`` (the default), chromadb's default ``Settings()`` are used.
    :type persist_directory: str
    :param collection_name: the name of the collection in Chromadb, defaults to 'gptcache'.
    :type collection_name: str
    :param top_k: the default number of vector results returned by
        :meth:`search`, defaults to 1.
    :type top_k: int
    """

    def __init__(
        self,
        client_settings=None,
        persist_directory=None,
        collection_name: str = "gptcache",
        top_k: int = 1,
    ):
        self.top_k = top_k
        if client_settings:
            # Caller-supplied settings take precedence; persist_directory is not applied.
            self._client_settings = client_settings
        else:
            self._client_settings = chromadb.config.Settings()
            if persist_directory is not None:
                # NOTE(review): "duckdb+parquet" is the legacy (pre-0.4) chromadb
                # persistent backend; confirm against the pinned chromadb version.
                self._client_settings = chromadb.config.Settings(
                    chroma_db_impl="duckdb+parquet", persist_directory=persist_directory
                )
        self._client = chromadb.Client(self._client_settings)
        # Stored for reference; not read anywhere else in this class.
        self._persist_directory = persist_directory
        self._collection = self._client.get_or_create_collection(name=collection_name)

    def mul_add(self, datas: List[VectorData]):
        """Add a batch of vectors to the collection.

        :param datas: items whose ``data`` is a numpy vector; each item's ``id``
            is stored as a string chromadb id.
        :type datas: List[VectorData]
        """
        # Materialise once so an iterator argument is not consumed twice below.
        datas = list(datas)
        if not datas:
            # Nothing to add. Without this guard, unpacking the empty zip() used
            # previously raised ValueError.
            return
        embeddings = [data.data.tolist() for data in datas]
        ids = [str(data.id) for data in datas]
        self._collection.add(embeddings=embeddings, ids=ids)

    def search(self, data, top_k: int = -1):
        """Return the stored vectors nearest to ``data``.

        :param data: the query vector (converted with ``tolist()``).
        :param top_k: the number of results to return; ``-1`` means use ``self.top_k``.
        :type top_k: int
        :return: a list of ``(distance, id)`` tuples, with ids converted back to ``int``;
            an empty list when the collection is empty.
        """
        count = self._collection.count()
        if count == 0:
            return []
        if top_k == -1:
            top_k = self.top_k
        # Older chromadb releases raise when n_results exceeds the number of
        # stored embeddings; requesting at most `count` yields the same hits.
        results = self._collection.query(
            query_embeddings=[data.tolist()],
            n_results=min(top_k, count),
            include=["distances"],
        )
        distances = results["distances"][0]
        ids = [int(x) for x in results["ids"][0]]
        return list(zip(distances, ids))

    def delete(self, ids):
        """Delete the vectors with the given ids.

        :param ids: iterable of ids; each is converted to ``str`` to match how
            :meth:`mul_add` stores them. An empty iterable is a no-op.
        """
        str_ids = [str(x) for x in ids]
        if not str_ids:
            # NOTE(review): avoid handing chromadb an empty id list; its meaning
            # (no-op vs. unfiltered delete) has varied across versions. Confirm.
            return
        self._collection.delete(str_ids)

    def rebuild(self, ids=None):  # pylint: disable=unused-argument
        """No-op for this backend; always returns ``True``."""
        return True

    def get_embeddings(self, data_id: str):
        """Fetch the stored embedding for ``data_id``.

        :param data_id: the id to look up (converted to ``str``).
        :return: the embedding as a float32 numpy array, or ``None`` if chromadb
            returned no embedding for that id.
        """
        vec_emb = self._collection.get(
            str(data_id),
            include=["embeddings"],
        )["embeddings"]
        if vec_emb is None or len(vec_emb) < 1:
            return None
        return np.asarray(vec_emb[0], dtype="float32")

    def update_embeddings(self, data_id: str, emb: np.ndarray):
        """Replace the stored embedding for ``data_id`` with ``emb``.

        :param data_id: the id whose embedding is updated (converted to ``str``).
        :param emb: the new embedding vector (converted with ``tolist()``).
        """
        self._collection.update(
            ids=str(data_id),
            embeddings=emb.tolist(),
        )