Skip to content

Commit

Permalink
Oml zoo (#291)
Browse files Browse the repository at this point in the history
* Support of Qdrant/Unicom-ViT-B-16 and Qdrant/Unicom-ViT-B-32
  • Loading branch information
I8dNLo authored Jul 10, 2024
1 parent d09af55 commit f0ff09c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
4 changes: 2 additions & 2 deletions fastembed/image/image_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(
return

raise ValueError(
f"Model {model_name} is not supported in TextEmbedding."
"Please check the supported models using `TextEmbedding.list_supported_models()`"
f"Model {model_name} is not supported in ImageEmbedding."
"Please check the supported models using `ImageEmbedding.list_supported_models()`"
)

def embed(
Expand Down
33 changes: 29 additions & 4 deletions fastembed/image/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,28 @@
},
"model_file": "model.onnx",
},
{
"model": "Qdrant/Unicom-ViT-B-16",
"dim": 768,
"description": "Unicom Unicom-ViT-B-16 from open-metric-learning",
"size_in_GB": 0.82,
"sources": {
"hf": "Qdrant/Unicom-ViT-B-16",
},
"model_file": "model.onnx",
},
{
"model": "Qdrant/Unicom-ViT-B-32",
"dim": 512,
"description": "Unicom Unicom-ViT-B-32 from open-metric-learning",
"size_in_GB": 0.48,
"sources": {
"hf": "Qdrant/Unicom-ViT-B-32",
},
"model_file": "model.onnx",
},
]


class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[np.ndarray]):
def __init__(
self,
Expand Down Expand Up @@ -122,10 +141,16 @@ def _preprocess_onnx_input(

return onnx_input

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
def _post_process_onnx_output(
self, output: OnnxOutputContext
) -> Iterable[np.ndarray]:
return normalize(output.model_output).astype(np.float32)


class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxImageEmbedding:
return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
def init_embedding(
self, model_name: str, cache_dir: str, **kwargs
) -> OnnxImageEmbedding:
return OnnxImageEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
)
8 changes: 8 additions & 0 deletions tests/test_image_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
"Qdrant/resnet50-onnx": np.array(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01046245, 0.01171397, 0.00705971, 0.0]
),
"Qdrant/Unicom-ViT-B-16": np.array(
[ 0.0170, -0.0361, 0.0125, -0.0428, -0.0232, 0.0232, -0.0602, -0.0333,
0.0155, 0.0497]
),
"Qdrant/Unicom-ViT-B-32": np.array(
[0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402,
-0.0891, -0.0186]
),
}


Expand Down

0 comments on commit f0ff09c

Please sign in to comment.