diff --git a/fastembed/image/image_embedding.py b/fastembed/image/image_embedding.py index 29df2e10..6fa03702 100644 --- a/fastembed/image/image_embedding.py +++ b/fastembed/image/image_embedding.py @@ -65,8 +65,8 @@ def __init__( return raise ValueError( - f"Model {model_name} is not supported in TextEmbedding." - "Please check the supported models using `TextEmbedding.list_supported_models()`" + f"Model {model_name} is not supported in ImageEmbedding." + "Please check the supported models using `ImageEmbedding.list_supported_models()`" ) def embed( diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py index 13070e4e..c39e9ca1 100644 --- a/fastembed/image/onnx_embedding.py +++ b/fastembed/image/onnx_embedding.py @@ -29,9 +29,28 @@ }, "model_file": "model.onnx", }, + { + "model": "Qdrant/Unicom-ViT-B-16", + "dim": 768, + "description": "Unicom Unicom-ViT-B-16 from open-metric-learning", + "size_in_GB": 0.82, + "sources": { + "hf": "Qdrant/Unicom-ViT-B-16", + }, + "model_file": "model.onnx", + }, + { + "model": "Qdrant/Unicom-ViT-B-32", + "dim": 512, + "description": "Unicom Unicom-ViT-B-32 from open-metric-learning", + "size_in_GB": 0.48, + "sources": { + "hf": "Qdrant/Unicom-ViT-B-32", + }, + "model_file": "model.onnx", + }, ] - class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[np.ndarray]): def __init__( self, @@ -122,10 +141,16 @@ def _preprocess_onnx_input( return onnx_input - def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]: + def _post_process_onnx_output( + self, output: OnnxOutputContext + ) -> Iterable[np.ndarray]: return normalize(output.model_output).astype(np.float32) class OnnxImageEmbeddingWorker(ImageEmbeddingWorker): - def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxImageEmbedding: - return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs) + def init_embedding( + self, model_name: str, cache_dir: str, **kwargs + ) -> OnnxImageEmbedding: + return OnnxImageEmbedding( + model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs + ) diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 196c3713..bc6a4d00 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -11,6 +11,14 @@ "Qdrant/resnet50-onnx": np.array( [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01046245, 0.01171397, 0.00705971, 0.0] ), + "Qdrant/Unicom-ViT-B-16": np.array( + [ 0.0170, -0.0361, 0.0125, -0.0428, -0.0232, 0.0232, -0.0602, -0.0333, + 0.0155, 0.0497] + ), + "Qdrant/Unicom-ViT-B-32": np.array( + [0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, + -0.0891, -0.0186] + ), }