from typing import List, Optional
import numpy as np
from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_qdrant
from gptcache.utils.log import gptcache_log
import_qdrant()
# pylint: disable=C0413
from qdrant_client import QdrantClient
from qdrant_client.models import (
PointStruct,
HnswConfigDiff,
VectorParams,
OptimizersConfigDiff,
Distance,
)
class QdrantVectorStore(VectorBase):
    """Qdrant Vector Store.

    Stores cache embeddings in a Qdrant collection created with cosine distance.

    :param url: remote Qdrant URL; only used when ``location`` is ``None``.
    :param port: remote REST port; only used when ``location`` is ``None``.
    :param grpc_port: remote gRPC port; only used when ``location`` is ``None``.
    :param prefer_grpc: prefer gRPC for the remote client.
    :param https: whether the remote client uses HTTPS.
    :param api_key: API key for the remote client.
    :param prefix: URL path prefix for the remote client.
    :param timeout: request timeout for the remote client.
    :param host: remote host name; only used when ``location`` is ``None``.
    :param collection_name: name of the collection holding the vectors.
    :param location: ``":memory:"`` for an in-process store, a directory for an
        on-disk embedded store, or ``None`` to connect to a remote server using
        the connection parameters above. Defaults to ``"./qdrant"``.
    :param dimension: vector dimension; must be positive.
    :param top_k: default number of results returned by :meth:`search`.
    :param flush_interval_sec: optimizer flush interval used when the
        collection is created by this class.
    :param index_params: keyword arguments forwarded to ``HnswConfigDiff``.
    :raises ValueError: if ``dimension`` is not positive.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        collection_name: Optional[str] = "gptcache",
        location: Optional[str] = "./qdrant",
        dimension: int = 0,
        top_k: int = 1,
        flush_interval_sec: int = 5,
        index_params: Optional[dict] = None,
    ):
        if dimension <= 0:
            raise ValueError(
                f"invalid `dim` param: {dimension} in the Qdrant vector store."
            )
        self._client: QdrantClient
        self._collection_name = collection_name
        self._in_memory = location == ":memory:"
        self.dimension = dimension
        self.top_k = top_k
        # Any non-None location (":memory:" included) selects the embedded
        # client; the remote connection arguments only take effect when the
        # caller explicitly passes location=None.
        if location is not None:
            self._create_local(location)
        else:
            self._create_remote(
                url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https
            )
        self._create_collection(collection_name, flush_interval_sec, index_params)

    def _create_local(self, location: str):
        """Create a client for ``location`` without the remote connection args.

        ``":memory:"`` keeps all data in process memory. Any other value is used
        as the on-disk storage directory and passed as ``path``: qdrant-client
        interprets a plain ``location`` string as a server URL, so passing a
        directory there would not create local storage. Values that already
        look like HTTP(S) URLs are still forwarded as ``location`` so callers
        relying on that keep working.
        """
        if location == ":memory:" or location.startswith(("http://", "https://")):
            self._client = QdrantClient(location=location)
        else:
            self._client = QdrantClient(path=location)

    def _create_remote(
        self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https
    ):
        """Create a client connected to a remote Qdrant server."""
        self._client = QdrantClient(
            url=url,
            port=port,
            api_key=api_key,
            timeout=timeout,
            host=host,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            prefix=prefix,
            https=https,
        )

    def _create_collection(
        self,
        collection_name: str,
        flush_interval_sec: int,
        index_params: Optional[dict] = None,
    ):
        """Create ``collection_name`` unless a collection with that name exists.

        An existing collection is reused as-is; its vector size, distance and
        HNSW/optimizer settings are not compared with this store's settings.
        """
        # Built up front so invalid ``index_params`` are rejected even when the
        # collection already exists.
        hnsw_config = HnswConfigDiff(**(index_params or {}))
        vectors_config = VectorParams(
            size=self.dimension, distance=Distance.COSINE, hnsw_config=hnsw_config
        )
        optimizers_config = OptimizersConfigDiff(
            deleted_threshold=0.2,
            vacuum_min_vector_number=1000,
            flush_interval_sec=flush_interval_sec,
        )
        existing_collections = self._client.get_collections().collections
        if any(existing.name == collection_name for existing in existing_collections):
            gptcache_log.warning(
                "The %s collection already exists, and it will be used directly.",
                collection_name,
            )
            return
        self._client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
            optimizers_config=optimizers_config,
        )

    def mul_add(self, datas: List[VectorData]):
        """Upsert ``datas`` into the collection without waiting for indexing.

        Each vector is flattened to 1-D before upload. Because ``wait=False``,
        vectors added here may not yet be visible to an immediately following
        :meth:`search`. An empty batch is a no-op.
        """
        if not datas:
            return
        points = [
            PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas
        ]
        self._client.upsert(
            collection_name=self._collection_name, points=points, wait=False
        )

    def search(self, data: np.ndarray, top_k: int = -1):
        """Return up to ``top_k`` ``(score, id)`` pairs for the query ``data``.

        ``top_k == -1`` falls back to ``self.top_k``. Scores are returned as
        reported by Qdrant; for the cosine collections this class creates that
        is a similarity (higher means closer).
        NOTE(review): confirm callers expect similarity rather than distance.
        """
        if top_k == -1:
            top_k = self.top_k
        hits = self._client.search(
            collection_name=self._collection_name,
            query_vector=data.reshape(-1).tolist(),
            limit=top_k,
        )
        return [(hit.score, hit.id) for hit in hits]

    def delete(self, ids: List[str]):
        """Delete the points whose ids are listed in ``ids``."""
        self._client.delete(collection_name=self._collection_name, points_selector=ids)

    def rebuild(self, ids=None):  # pylint: disable=unused-argument
        """Re-apply the optimizer settings; ``ids`` is ignored.

        Qdrant performs vacuuming/re-indexing itself according to these
        settings, so there is no explicit rebuild step.
        """
        optimizers_config = OptimizersConfigDiff(
            deleted_threshold=0.2, vacuum_min_vector_number=1000
        )
        # ``optimizer_config`` is the keyword accepted by older qdrant-client
        # releases; kept for compatibility with them.
        self._client.update_collection(
            collection_name=self._collection_name, optimizer_config=optimizers_config
        )

    def flush(self):
        """No-op: Qdrant flushes according to the collection's optimizers_config."""

    def close(self):
        """Flush and release the underlying client's resources.

        The client's ``close`` is looked up dynamically because older
        qdrant-client releases do not provide it.
        """
        self.flush()
        close_client = getattr(self._client, "close", None)
        if callable(close_client):
            close_client()