From 87decb0d538a1c25c36493e256c3de5fc8fb2036 Mon Sep 17 00:00:00 2001 From: Anush Date: Tue, 30 Jan 2024 21:27:16 +0530 Subject: [PATCH] chore: port to Xenova Jina source (#102) * chore: xenova jina * chore: try recusive model location * chore: updated doc string, blob pattern --- fastembed/embedding.py | 25 ++++++++++++++++--------- fastembed/models.json | 4 ++-- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/fastembed/embedding.py b/fastembed/embedding.py index 52cf4207..620b5d5e 100644 --- a/fastembed/embedding.py +++ b/fastembed/embedding.py @@ -35,6 +35,21 @@ def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: yield b +def locate_model_file(model_dir: Path, file_names: List[str]): + """ + Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used by Optimum and Qdrant + """ + if not model_dir.is_dir(): + raise ValueError(f"Provided model path '{model_dir}' is not a directory.") + + for path in model_dir.rglob("*.onnx"): + for file_name in file_names: + if path.is_file() and path.name == file_name: + return path + + raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}") + + def normalize(input_array, p=2, dim=1, eps=1e-12): # Calculate the Lp norm along the specified dimension norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True) @@ -92,19 +107,11 @@ def __init__( ): self.path = path self.model_name = model_name - model_path = self.path / "model.onnx" - optimized_model_path = self.path / "model_optimized.onnx" + model_path = locate_model_file(self.path, ["model.onnx", "model_optimized.onnx"]) # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers onnx_providers = ["CPUExecutionProvider"] - if not model_path.exists(): - # Rename file model_optimized.onnx to model.onnx if it exists - if optimized_model_path.exists(): - optimized_model_path.rename(model_path) - else: - raise ValueError(f"Could not find model.onnx in {self.path}") - # Hacky support for multilingual model self.exclude_token_type_ids = False if model_name == "intfloat/multilingual-e5-large": diff --git a/fastembed/models.json b/fastembed/models.json index 1c4e7b5d..7aea39ad 100644 --- a/fastembed/models.json +++ b/fastembed/models.json @@ -83,7 +83,7 @@ "description": " English embedding model supporting 8192 sequence length", "size_in_GB": 0.55, "hf_sources": [ - "jinaai/jina-embeddings-v2-base-en" + "xenova/jina-embeddings-v2-base-en" ], "compressed_url_sources": [] }, @@ -93,7 +93,7 @@ "description": " English embedding model supporting 8192 sequence length", "size_in_GB": 0.13, "hf_sources": [ - "jinaai/jina-embeddings-v2-small-en" + "xenova/jina-embeddings-v2-small-en" ], "compressed_url_sources": [] },