diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index 9ff4f74ff..62325f90b 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -53,6 +53,7 @@ def __init__( index.delete(delete_all=True, namespace=namespace) self._metadata_key = "meta" + self._scalar_id_field = "id" self._scalar_label_field = "label" @classmethod diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index a51632bc6..10bbabd88 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -3,6 +3,7 @@ import logging import time from contextlib import contextmanager +from vectordb_bench.backend.filter import Filter, FilterType from ..api import VectorDB, DBCaseConfig from qdrant_client.http.models import ( @@ -10,7 +11,7 @@ VectorParams, PayloadSchemaType, Batch, - Filter, + QdrantFilter, FieldCondition, Range, ) @@ -22,6 +23,11 @@ class QdrantCloud(VectorDB): + supported_filter_types: list[FilterType] = [ + FilterType.NonFilter, + FilterType.Int, + FilterType.Label, + ] def __init__( self, dim: int, @@ -29,6 +35,7 @@ def __init__( db_case_config: DBCaseConfig, collection_name: str = "QdrantCloudCollection", drop_old: bool = False, + with_scalar_labels: bool = False, **kwargs, ): """Initialize wrapper around the QdrantCloud vector database.""" @@ -40,11 +47,14 @@ def __init__( self._vector_field = "vector" tmp_client = QdrantClient(**self.db_config) + self.with_scalar_labels = with_scalar_labels if drop_old: log.info(f"QdrantCloud client drop_old collection: {self.collection_name}") tmp_client.delete_collection(self.collection_name) self._create_collection(dim, tmp_client) tmp_client = None + self._scalar_id_field = "id" + self._scalar_label_field = "label" @contextmanager def init(self) -> None: @@ -105,6 +115,7 @@ def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], + labels_data: list[str] = None, **kwargs, ) -> (int, Exception): """Insert embeddings into Milvus. should call self.init() first""" @@ -138,10 +149,10 @@ def search_embedding( Should call self.init() first. """ assert self.qdrant_client is not None - - f = None + condition = self.condition + f = self.condition if filters: - f = Filter( + f = QdrantFilter( must=[FieldCondition( key = self._primary_field, range = Range( @@ -160,3 +171,27 @@ def search_embedding( ret = [result.id for result in res[0]] return ret + + def prepare_filter(self, filter: Filter): + if filter.type == FilterType.NonFilter: + self.condition = None + elif filter.type == FilterType.Int: + self.condition = QdrantFilter( + must=[ + FieldCondition( + key=self._scalar_id_field, + range=Range(gte=filter.int_value), + ), + ] + ) + elif filter.type == FilterType.Label: + self.condition = QdrantFilter( + must=[ + FieldCondition( + key=self._scalar_label_field, + match={"value": filter.label_value}, + ), + ] + ) + else: + raise ValueError(f"Not support Filter for Pinecone - {filter}")