Skip to content

Commit

Permalink
new: extend embedding size to support image and late interaction models
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Jan 26, 2025
1 parent d6cc925 commit 1cf81ed
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
19 changes: 16 additions & 3 deletions qdrant_client/async_qdrant_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
ImageEmbedding,
SparseTextEmbedding,
SUPPORTED_EMBEDDING_MODELS,
SUPPORTED_SPARSE_EMBEDDING_MODELS,
_LATE_INTERACTION_EMBEDDING_MODELS,
_IMAGE_EMBEDDING_MODELS,
IDF_EMBEDDING_MODELS,
OnnxProvider,
)
Expand Down Expand Up @@ -184,11 +187,17 @@ def _import_fastembed(cls) -> None:
@classmethod
def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]:
cls._import_fastembed()
if model_name not in SUPPORTED_EMBEDDING_MODELS:
if model_name in SUPPORTED_EMBEDDING_MODELS:
return SUPPORTED_EMBEDDING_MODELS[model_name]
if model_name in _LATE_INTERACTION_EMBEDDING_MODELS:
return _LATE_INTERACTION_EMBEDDING_MODELS[model_name]
if model_name in _IMAGE_EMBEDDING_MODELS:
return _IMAGE_EMBEDDING_MODELS[model_name]
if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS:
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}"
"Sparse embeddings do not return fixed embedding size and distance type"
)
return SUPPORTED_EMBEDDING_MODELS[model_name]
raise ValueError(f"Unsupported embedding model: {model_name}")

def _get_or_init_model(
self,
Expand Down Expand Up @@ -424,6 +433,10 @@ def get_embedding_size(self, model_name: Optional[str] = None) -> int:
int: the size of the embeddings produced by the model.
"""
model_name = model_name or self.embedding_model_name
if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS:
raise ValueError(
f"Sparse embeddings do not have a fixed embedding size. Current model: {model_name}"
)
(embeddings_size, _) = self._get_model_params(model_name=model_name)
return embeddings_size

Expand Down
12 changes: 9 additions & 3 deletions qdrant_client/fastembed_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
else {}
)

SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, dict[str, Any]] = (
{model["model"]: model for model in SparseTextEmbedding.list_supported_models()}
if SparseTextEmbedding
else {}
Expand All @@ -48,13 +48,19 @@
)

_LATE_INTERACTION_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
{model["model"]: model for model in LateInteractionTextEmbedding.list_supported_models()}
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in LateInteractionTextEmbedding.list_supported_models()
}
if LateInteractionTextEmbedding
else {}
)

_IMAGE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
{model["model"]: model for model in ImageEmbedding.list_supported_models()}
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in ImageEmbedding.list_supported_models()
}
if ImageEmbedding
else {}
)
Expand Down
26 changes: 21 additions & 5 deletions qdrant_client/qdrant_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
ImageEmbedding,
SparseTextEmbedding,
SUPPORTED_EMBEDDING_MODELS,
SUPPORTED_SPARSE_EMBEDDING_MODELS,
_LATE_INTERACTION_EMBEDDING_MODELS,
_IMAGE_EMBEDDING_MODELS,
IDF_EMBEDDING_MODELS,
OnnxProvider,
)
Expand Down Expand Up @@ -186,12 +189,21 @@ def _import_fastembed(cls) -> None:
def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]:
cls._import_fastembed()

if model_name not in SUPPORTED_EMBEDDING_MODELS:
if model_name in SUPPORTED_EMBEDDING_MODELS:
return SUPPORTED_EMBEDDING_MODELS[model_name]

if model_name in _LATE_INTERACTION_EMBEDDING_MODELS:
return _LATE_INTERACTION_EMBEDDING_MODELS[model_name]

if model_name in _IMAGE_EMBEDDING_MODELS:
return _IMAGE_EMBEDDING_MODELS[model_name]

if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS:
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {SUPPORTED_EMBEDDING_MODELS}"
"Sparse embeddings do not return fixed embedding size and distance type"
)

return SUPPORTED_EMBEDDING_MODELS[model_name]
raise ValueError(f"Unsupported embedding model: {model_name}")

def _get_or_init_model(
self,
Expand Down Expand Up @@ -438,8 +450,8 @@ def _validate_collection_info(self, collection_info: models.CollectionInfo) -> N
), f"{self.sparse_embedding_model_name} requires modifier IDF, current modifier is {modifier}"

def get_embedding_size(
self,
model_name: Optional[str] = None,
self,
model_name: Optional[str] = None,
) -> int:
"""
Get the size of the embeddings produced by the specified model.
Expand All @@ -451,6 +463,10 @@ def get_embedding_size(
int: the size of the embeddings produced by the model.
"""
model_name = model_name or self.embedding_model_name
if model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS:
raise ValueError(
f"Sparse embeddings do not have a fixed embedding size. Current model: {model_name}"
)
embeddings_size, _ = self._get_model_params(model_name=model_name)
return embeddings_size

Expand Down
18 changes: 18 additions & 0 deletions tests/test_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,21 @@ def test_idf_models():

# the only sparse model without IDF is SPLADE, however it's too large for tests, so we don't test how non-idf
# models work


def test_get_embedding_size():
local_client = QdrantClient(":memory:")

if not local_client._FASTEMBED_INSTALLED:
pytest.skip("FastEmbed is not installed, skipping test")

assert local_client.get_embedding_size() == 384

assert local_client.get_embedding_size(model_name="BAAI/bge-base-en-v1.5") == 768

assert local_client.get_embedding_size(model_name="Qdrant/resnet50-onnx") == 2048

assert local_client.get_embedding_size(model_name="colbert-ir/colbertv2.0") == 128

with pytest.raises(ValueError, match="Sparse embeddings do not have a fixed embedding size."):
local_client.get_embedding_size(model_name="Qdrant/bm25")

0 comments on commit 1cf81ed

Please sign in to comment.