Fix to avoid overflow and get rid of model_max_length (#319)
* Fix to avoid overflow and get rid of model_max_length
* Fixes for max_length vs model_max_length logic
  Jupyter warning disabled

* Support of jwodder/versioningit#48

* Update fastembed/common/preprocessor_utils.py
---------

Co-authored-by: George <[email protected]>
I8dNLo and joein authored Aug 14, 2024
1 parent 49762a6 commit 62607c2
Showing 5 changed files with 31 additions and 44 deletions.
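For context on the first two bullets of the commit message: Hugging Face tokenizer configs often carry both a model_max_length and a max_length, and model_max_length is sometimes a huge sentinel value rather than a real limit, which is presumably the overflow this commit avoids. Below is a minimal sketch of the max-context resolution that the new load_tokenizer applies (it mirrors the preprocessor_utils.py hunk further down); the helper name resolve_max_context and the config values are illustrative, not part of the codebase.

def resolve_max_context(tokenizer_config: dict) -> int:
    # Require at least one of the two length keys, as the new code asserts.
    assert (
        "model_max_length" in tokenizer_config or "max_length" in tokenizer_config
    ), "Models without model_max_length or max_length are not supported."
    if "model_max_length" not in tokenizer_config:
        return tokenizer_config["max_length"]
    if "max_length" not in tokenizer_config:
        return tokenizer_config["model_max_length"]
    # When both are present, the smaller one wins, so a sentinel-sized
    # model_max_length no longer reaches tokenizer.enable_truncation().
    return min(tokenizer_config["model_max_length"], tokenizer_config["max_length"])


# Hypothetical configs:
print(resolve_max_context({"model_max_length": 1000000000000000019884624838656, "max_length": 512}))  # 512
print(resolve_max_context({"max_length": 8192}))  # 8192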
4 changes: 3 additions & 1 deletion fastembed/common/model_management.py
@@ -148,7 +148,9 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
             # Open the tar.gz file
             with tarfile.open(targz_path, "r:gz") as tar:
                 # Extract all files into the cache directory
-                tar.extractall(path=cache_dir)
+                tar.extractall(
+                    path=cache_dir,
+                )
         except tarfile.TarError as e:
             # If any error occurs while opening or extracting the tar.gz file,
             # delete the cache directory (if it was created in this function)
16 changes: 11 additions & 5 deletions fastembed/common/preprocessor_utils.py
@@ -1,7 +1,6 @@
import json
from pathlib import Path
from typing import Tuple

from tokenizers import AddedToken, Tokenizer

from fastembed.image.transform.operators import Compose
@@ -18,7 +17,7 @@ def load_special_tokens(model_dir: Path) -> dict:
     return tokens_map
 
 
-def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tuple[Tokenizer, dict]:
+def load_tokenizer(model_dir: Path) -> Tuple[Tokenizer, dict]:
     config_path = model_dir / "config.json"
     if not config_path.exists():
         raise ValueError(f"Could not find config.json in {model_dir}")
@@ -36,13 +35,20 @@ def load_tokenizer(model_dir: Path, max_length: int = 512) -> Tuple[Tokenizer, dict]:
 
     with open(str(tokenizer_config_path)) as tokenizer_config_file:
         tokenizer_config = json.load(tokenizer_config_file)
+    assert (
+        "model_max_length" in tokenizer_config or "max_length" in tokenizer_config
+    ), "Models without model_max_length or max_length are not supported."
+    if "model_max_length" not in tokenizer_config:
+        max_context = tokenizer_config["max_length"]
+    elif "max_length" not in tokenizer_config:
+        max_context = tokenizer_config["model_max_length"]
+    else:
+        max_context = min(tokenizer_config["model_max_length"], tokenizer_config["max_length"])
 
     tokens_map = load_special_tokens(model_dir)
 
     tokenizer = Tokenizer.from_file(str(tokenizer_path))
-    tokenizer.enable_truncation(
-        max_length=min(tokenizer_config["model_max_length"], max_length)
-    )
+    tokenizer.enable_truncation(max_length=max_context)
     tokenizer.enable_padding(
         pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"]
     )
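With the preprocessor_utils.py change above, callers no longer pass a max_length argument: the truncation limit is derived entirely from the model's own config files. A hedged usage sketch, assuming a local model directory containing config.json, tokenizer.json, tokenizer_config.json and the special-tokens map that load_special_tokens reads; the directory path below is hypothetical.

from pathlib import Path

from fastembed.common.preprocessor_utils import load_tokenizer

# Before this commit: tokenizer, tokens_map = load_tokenizer(model_dir, max_length=512)
# After this commit: the limit comes from max_length / model_max_length in tokenizer_config.json.
model_dir = Path("local_cache/models--BAAI--bge-small-en-v1.5/snapshot")
tokenizer, special_tokens_map = load_tokenizer(model_dir)
print(tokenizer.encode("hello world").ids)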
8 changes: 2 additions & 6 deletions fastembed/text/onnx_embedding.py
@@ -244,9 +244,7 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
-    def _post_process_onnx_output(
-        self, output: OnnxOutputContext
-    ) -> Iterable[np.ndarray]:
+    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
         embeddings = output.model_output
         return normalize(embeddings[:, 0]).astype(np.float32)

@@ -258,6 +256,4 @@ def init_embedding(
         cache_dir: str,
         **kwargs,
     ) -> OnnxTextEmbedding:
-        return OnnxTextEmbedding(
-            model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs
-        )
+        return OnnxTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
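The onnx_embedding.py edits above are formatting-only, but for readers skimming the diff, here is a minimal sketch of what _post_process_onnx_output computes: it takes the first (CLS) token of every sequence from the raw ONNX output and L2-normalizes it. The standalone normalize helper below illustrates the expected behaviour and is not fastembed's actual utility.

import numpy as np


def normalize(vectors: np.ndarray) -> np.ndarray:
    # L2-normalize each row vector.
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.maximum(norms, 1e-12)


# Fake ONNX output: batch of 2 sequences, 4 tokens each, hidden size 6.
model_output = np.random.rand(2, 4, 6).astype(np.float32)

# Equivalent of: normalize(embeddings[:, 0]).astype(np.float32)
embeddings = normalize(model_output[:, 0]).astype(np.float32)
print(embeddings.shape)                      # (2, 6)
print(np.linalg.norm(embeddings, axis=1))    # each row has unit norm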
4 changes: 4 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,4 @@
+import os
+
+# disable DeprecationWarning https://github.com/jupyter/jupyter_core/issues/398
+os.environ["JUPYTER_PLATFORM_DIRS"] = "1"
43 changes: 11 additions & 32 deletions tests/test_text_onnx_embeddings.py
@@ -32,54 +32,34 @@
     "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": np.array(
         [0.0094, 0.0184, 0.0328, 0.0072, -0.0351]
     ),
-    "intfloat/multilingual-e5-large": np.array(
-        [0.0098, 0.0045, 0.0066, -0.0354, 0.0070]
-    ),
+    "intfloat/multilingual-e5-large": np.array([0.0098, 0.0045, 0.0066, -0.0354, 0.0070]),
     "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": np.array(
         [-0.01341097, 0.0416553, -0.00480805, 0.02844842, 0.0505299]
     ),
-    "jinaai/jina-embeddings-v2-small-en": np.array(
-        [-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]
-    ),
-    "jinaai/jina-embeddings-v2-base-en": np.array(
-        [-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]
-    ),
-    "jinaai/jina-embeddings-v2-base-de": np.array(
-        [-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]
-    ),
-    "jinaai/jina-embeddings-v2-base-code": np.array(
-        [0.0145, -0.0164, 0.0136, -0.0170, 0.0734]
-    ),
-    "nomic-ai/nomic-embed-text-v1": np.array(
-        [0.3708 , 0.2031, -0.3406, -0.2114, -0.3230]
-    ),
+    "jinaai/jina-embeddings-v2-small-en": np.array([-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]),
+    "jinaai/jina-embeddings-v2-base-en": np.array([-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]),
+    "jinaai/jina-embeddings-v2-base-de": np.array([-0.0085, 0.0417, 0.0342, 0.0309, -0.0149]),
+    "jinaai/jina-embeddings-v2-base-code": np.array([0.0145, -0.0164, 0.0136, -0.0170, 0.0734]),
+    "nomic-ai/nomic-embed-text-v1": np.array([0.3708, 0.2031, -0.3406, -0.2114, -0.3230]),
     "nomic-ai/nomic-embed-text-v1.5": np.array(
         [-0.15407836, -0.03053198, -3.9138033, 0.1910364, 0.13224715]
     ),
     "nomic-ai/nomic-embed-text-v1.5-Q": np.array(
-        [-0.12525563, 0.38030425, -3.961622 , 0.04176439, -0.0758301]
+        [-0.12525563, 0.38030425, -3.961622, 0.04176439, -0.0758301]
     ),
     "thenlper/gte-large": np.array(
         [-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]
     ),
     "mixedbread-ai/mxbai-embed-large-v1": np.array(
         [0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]
     ),
-    "snowflake/snowflake-arctic-embed-xs": np.array(
-        [0.0092, 0.0619, 0.0196, 0.009, -0.0114]
-    ),
-    "snowflake/snowflake-arctic-embed-s": np.array(
-        [-0.0416, -0.0867, 0.0209, 0.0554, -0.0272]
-    ),
-    "snowflake/snowflake-arctic-embed-m": np.array(
-        [-0.0329, 0.0364, 0.0481, 0.0016, 0.0328]
-    ),
+    "snowflake/snowflake-arctic-embed-xs": np.array([0.0092, 0.0619, 0.0196, 0.009, -0.0114]),
+    "snowflake/snowflake-arctic-embed-s": np.array([-0.0416, -0.0867, 0.0209, 0.0554, -0.0272]),
+    "snowflake/snowflake-arctic-embed-m": np.array([-0.0329, 0.0364, 0.0481, 0.0016, 0.0328]),
     "snowflake/snowflake-arctic-embed-m-long": np.array(
         [0.0080, -0.0266, -0.0335, 0.0282, 0.0143]
     ),
-    "snowflake/snowflake-arctic-embed-l": np.array(
-        [0.0189, -0.0673, 0.0183, 0.0124, 0.0146]
-    ),
+    "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
     "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
 }

@@ -94,7 +74,6 @@ def test_embedding():
         dim = model_desc["dim"]
 
         model = TextEmbedding(model_name=model_desc["model"])
-
         docs = ["hello world", "flag embedding"]
         embeddings = list(model.embed(docs))
         embeddings = np.stack(embeddings, axis=0)
