diff --git a/qdrant_client/async_qdrant_fastembed.py b/qdrant_client/async_qdrant_fastembed.py index 923b1af6..1f92ee9b 100644 --- a/qdrant_client/async_qdrant_fastembed.py +++ b/qdrant_client/async_qdrant_fastembed.py @@ -32,6 +32,9 @@ ImageEmbedding, SparseTextEmbedding, SUPPORTED_EMBEDDING_MODELS, + SUPPORTED_SPARSE_EMBEDDING_MODELS, + _LATE_INTERACTION_EMBEDDING_MODELS, + _IMAGE_EMBEDDING_MODELS, IDF_EMBEDDING_MODELS, OnnxProvider, ) @@ -184,11 +187,17 @@ def _import_fastembed(cls) -> None: @classmethod def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]: cls._import_fastembed() - if model_name not in SUPPORTED_EMBEDDING_MODELS: + if model_name in SUPPORTED_EMBEDDING_MODELS: + return SUPPORTED_EMBEDDING_MODELS[model_name] + if model_name in _LATE_INTERACTION_EMBEDDING_MODELS: + return _LATE_INTERACTION_EMBEDDING_MODELS[model_name] + if model_name in _IMAGE_EMBEDDING_MODELS: + return _IMAGE_EMBEDDING_MODELS[model_name] + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}" + "Sparse embeddings do not return fixed embedding size and distance type" ) - return SUPPORTED_EMBEDDING_MODELS[model_name] + raise ValueError(f"Unsupported embedding model: {model_name}") def _get_or_init_model( self, @@ -424,6 +433,10 @@ def get_embedding_size(self, model_name: Optional[str] = None) -> int: int: the size of the embeddings produced by the model. """ model_name = model_name or self.embedding_model_name + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: + raise ValueError( + f"Sparse embeddings do not have a fixed embedding size. Current model: {model_name}" + ) (embeddings_size, _) = self._get_model_params(model_name=model_name) return embeddings_size diff --git a/qdrant_client/fastembed_common.py b/qdrant_client/fastembed_common.py index 72e7e727..595b907f 100644 --- a/qdrant_client/fastembed_common.py +++ b/qdrant_client/fastembed_common.py @@ -31,7 +31,7 @@ else {} ) -SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( +SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, dict[str, Any]] = ( {model["model"]: model for model in SparseTextEmbedding.list_supported_models()} if SparseTextEmbedding else {} @@ -48,13 +48,19 @@ ) _LATE_INTERACTION_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in LateInteractionTextEmbedding.list_supported_models()} + { + model["model"]: (model["dim"], models.Distance.COSINE) + for model in LateInteractionTextEmbedding.list_supported_models() + } if LateInteractionTextEmbedding else {} ) _IMAGE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = ( - {model["model"]: model for model in ImageEmbedding.list_supported_models()} + { + model["model"]: (model["dim"], models.Distance.COSINE) + for model in ImageEmbedding.list_supported_models() + } if ImageEmbedding else {} ) diff --git a/qdrant_client/qdrant_fastembed.py b/qdrant_client/qdrant_fastembed.py index a2dc373e..fc56f3e7 100644 --- a/qdrant_client/qdrant_fastembed.py +++ b/qdrant_client/qdrant_fastembed.py @@ -23,6 +23,9 @@ ImageEmbedding, SparseTextEmbedding, SUPPORTED_EMBEDDING_MODELS, + SUPPORTED_SPARSE_EMBEDDING_MODELS, + _LATE_INTERACTION_EMBEDDING_MODELS, + _IMAGE_EMBEDDING_MODELS, IDF_EMBEDDING_MODELS, OnnxProvider, ) @@ -186,12 +189,21 @@ def _import_fastembed(cls) -> None: def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]: cls._import_fastembed() - if model_name not in SUPPORTED_EMBEDDING_MODELS: + if model_name in SUPPORTED_EMBEDDING_MODELS: + return SUPPORTED_EMBEDDING_MODELS[model_name] + + if model_name in _LATE_INTERACTION_EMBEDDING_MODELS: + return _LATE_INTERACTION_EMBEDDING_MODELS[model_name] + + if model_name in _IMAGE_EMBEDDING_MODELS: + return _IMAGE_EMBEDDING_MODELS[model_name] + + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: raise ValueError( - f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}" + "Sparse embeddings do not return fixed embedding size and distance type" ) - return SUPPORTED_EMBEDDING_MODELS[model_name] + raise ValueError(f"Unsupported embedding model: {model_name}") def _get_or_init_model( self, @@ -438,8 +450,8 @@ def _validate_collection_info(self, collection_info: models.CollectionInfo) -> N ), f"{self.sparse_embedding_model_name} requires modifier IDF, current modifier is {modifier}" def get_embedding_size( - self, - model_name: Optional[str] = None, + self, + model_name: Optional[str] = None, ) -> int: """ Get the size of the embeddings produced by the specified model. @@ -451,6 +463,10 @@ def get_embedding_size( int: the size of the embeddings produced by the model. """ model_name = model_name or self.embedding_model_name + if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS: + raise ValueError( + f"Sparse embeddings do not have a fixed embedding size. Current model: {model_name}" + ) embeddings_size, _ = self._get_model_params(model_name=model_name) return embeddings_size diff --git a/tests/test_fastembed.py b/tests/test_fastembed.py index 9c02e257..d089dd9b 100644 --- a/tests/test_fastembed.py +++ b/tests/test_fastembed.py @@ -201,3 +201,21 @@ def test_idf_models(): # the only sparse model without IDF is SPLADE, however it's too large for tests, so we don't test how non-idf # models work + + +def test_get_embedding_size(): + local_client = QdrantClient(":memory:") + + if not local_client._FASTEMBED_INSTALLED: + pytest.skip("FastEmbed is not installed, skipping test") + + assert local_client.get_embedding_size() == 384 + + assert local_client.get_embedding_size(model_name="BAAI/bge-base-en-v1.5") == 768 + + assert local_client.get_embedding_size(model_name="Qdrant/resnet50-onnx") == 2048 + + assert local_client.get_embedding_size(model_name="colbert-ir/colbertv2.0") == 128 + + with pytest.raises(ValueError, match="Sparse embeddings do not have a fixed embedding size."): + local_client.get_embedding_size(model_name="Qdrant/bm25")