diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py
index 8e3d2dc5..b736a984 100644
--- a/fastembed/common/model_management.py
+++ b/fastembed/common/model_management.py
@@ -198,7 +198,9 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
         return model_dir
 
     @classmethod
-    def download_model(cls, model: Dict[str, Any], cache_dir: Path, retries=3, **kwargs) -> Path:
+    def download_model(
+        cls, model: Dict[str, Any], cache_dir: Path, retries: object = 3, **kwargs: object
+    ) -> Path:
         """
         Downloads a model from HuggingFace Hub or Google Cloud Storage.
 
diff --git a/fastembed/common/utils.py b/fastembed/common/utils.py
index ab6a8580..cd50dfdf 100644
--- a/fastembed/common/utils.py
+++ b/fastembed/common/utils.py
@@ -3,8 +3,11 @@
 from itertools import islice
 from pathlib import Path
 from typing import Generator, Iterable, Optional, Union
-
+import unicodedata
+import sys
 import numpy as np
+import re
+from typing import Set
 
 
 def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
@@ -37,7 +40,16 @@ def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
         cache_path = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
     else:
        cache_path = Path(cache_dir)
-    cache_path.mkdir(parents=True, exist_ok=True)
 
     return cache_path
+
+
+def get_all_punctuation() -> Set[str]:
+    return set(
+        chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
+    )
+
+
+def remove_non_alphanumeric(text: str) -> str:
+    return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)
diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py
index eea38804..cc4fbbb6 100644
--- a/fastembed/sparse/bm25.py
+++ b/fastembed/sparse/bm25.py
@@ -1,21 +1,23 @@
 import os
-import string
 from collections import defaultdict
 from multiprocessing import get_all_start_methods
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
-
 import mmh3
 import numpy as np
 from snowballstemmer import stemmer as get_stemmer
-
-from fastembed.common.utils import define_cache_dir, iter_batch
+from fastembed.common.utils import (
+    define_cache_dir,
+    iter_batch,
+    get_all_punctuation,
+    remove_non_alphanumeric,
+)
 from fastembed.parallel_processor import ParallelWorkerPool, Worker
 from fastembed.sparse.sparse_embedding_base import (
     SparseEmbedding,
     SparseTextEmbeddingBase,
 )
-from fastembed.sparse.utils.tokenizer import WordTokenizer
+from fastembed.sparse.utils.tokenizer import SimpleTokenizer
 
 supported_languages = [
     "arabic",
@@ -100,6 +102,7 @@ def __init__(
         b: float = 0.75,
         avg_len: float = 256.0,
         language: str = "english",
+        token_max_length: int = 40,
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, **kwargs)
@@ -120,10 +123,12 @@
             model_description, self.cache_dir, local_files_only=self._local_files_only
         )
 
-        self.punctuation = set(string.punctuation)
+        self.token_max_length = token_max_length
+        self.punctuation = set(get_all_punctuation())
         self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
+
         self.stemmer = get_stemmer(language)
-        self.tokenizer = WordTokenizer
+        self.tokenizer = SimpleTokenizer
 
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Any]]:
@@ -222,7 +227,10 @@ def _stem(self, tokens: List[str]) -> List[str]:
             if token.lower() in self.stopwords:
                 continue
 
-            stemmed_token = self.stemmer.stemWord(token)
+            if len(token) > self.token_max_length:
+                continue
+
+            stemmed_token = self.stemmer.stemWord(token.lower())
 
             if stemmed_token:
                 stemmed_tokens.append(stemmed_token)
@@ -234,6 +242,7 @@ def raw_embed(
     ) -> List[SparseEmbedding]:
         embeddings = []
         for document in documents:
+            document = remove_non_alphanumeric(document)
            tokens = self.tokenizer.tokenize(document)
             stemmed_tokens = self._stem(tokens)
             token_id2value = self._term_frequency(stemmed_tokens)
@@ -282,6 +291,7 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[SparseEmbedding]:
             query = [query]
 
         for text in query:
+            text = remove_non_alphanumeric(text)
             tokens = self.tokenizer.tokenize(text)
             stemmed_tokens = self._stem(tokens)
             token_ids = np.array(
diff --git a/fastembed/sparse/utils/tokenizer.py b/fastembed/sparse/utils/tokenizer.py
index c2f378bc..88b6c059 100644
--- a/fastembed/sparse/utils/tokenizer.py
+++ b/fastembed/sparse/utils/tokenizer.py
@@ -4,6 +4,14 @@
 from typing import List
 
 
+class SimpleTokenizer:
+    def tokenize(text: str) -> List[str]:
+        text = re.sub(r"[^\w]", " ", text.lower())
+        text = re.sub(r"\s+", " ", text)
+
+        return text.strip().split()
+
+
 class WordTokenizer:
     """The tokenizer is "destructive" such that the regexes applied
     will munge the input string to a state beyond re-construction.
@@ -68,8 +76,7 @@ class WordTokenizer:
         )
     ]
     CONTRACTIONS3 = [
-        re.compile(pattern)
-        for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b")
+        re.compile(pattern) for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b")
     ]
 
     @classmethod
diff --git a/pyproject.toml b/pyproject.toml
index b729c9d8..f4ea13b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ numpy = [
 pillow = "^10.3.0"
 snowballstemmer = "^2.2.0"
 PyStemmer = "^2.2.0"
-mmh3 = "^4.0"
+mmh3 = "^4.1.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.2"
diff --git a/tests/test_attention_embeddings.py b/tests/test_attention_embeddings.py
index f198e7b5..f0107d3d 100644
--- a/tests/test_attention_embeddings.py
+++ b/tests/test_attention_embeddings.py
@@ -111,11 +111,34 @@ def test_multilanguage(model_name):
     model = SparseTextEmbedding(model_name=model_name, language="english")
     embeddings = list(model.embed(docs))[:2]
 
-    assert embeddings[0].values.shape == (4,)
-    assert embeddings[0].indices.shape == (4,)
+    assert embeddings[0].values.shape == (5,)
+    assert embeddings[0].indices.shape == (5,)
     assert embeddings[1].values.shape == (4,)
     assert embeddings[1].indices.shape == (4,)
 
     if is_ci:
         shutil.rmtree(model.model._model_dir)
+
+
+@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
+def test_special_characters(model_name):
+    is_ci = os.getenv("CI")
+
+    docs = [
+        "Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!",
+        "L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!",
+        "Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană.",
+        "Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!",
+        "Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»",
+        "Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad",
+    ]
+
+    model = SparseTextEmbedding(model_name=model_name, language="english")
+    embeddings = list(model.embed(docs))
+    for idx, shape in enumerate([14, 18, 15, 10, 15]):
+        assert embeddings[idx].values.shape == (shape,)
+        assert embeddings[idx].indices.shape == (shape,)
+
+    if is_ci:
+        shutil.rmtree(model.model._model_dir)
diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py
index a6e982a6..1c3149d9 100644
--- a/tests/test_sparse_embeddings.py
+++ b/tests/test_sparse_embeddings.py
@@ -147,5 +147,5 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
     result = bm25_instance._stem(tokens)
 
     # Assert
-    expected = ["Quick", "Brown", "Fox", "Test", "Sentenc"]
+    expected = ["quick", "brown", "fox", "test", "sentenc"]
     assert result == expected, f"Expected {expected}, but got {result}"
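For reference, the sketch below strings together the helpers this patch adds (remove_non_alphanumeric, SimpleTokenizer, and the new stopword/length filtering in Bm25._stem) into a standalone preprocessing demo. The stem_tokens wrapper, the hard-coded stopword set, and the sample sentence are illustrative assumptions only; in the library the stopwords are loaded from the downloaded model directory and this logic lives inside the Bm25 class.

# Illustrative sketch only -- mirrors the preprocessing path introduced in this diff.
import re
from typing import List, Set

from snowballstemmer import stemmer as get_stemmer


def remove_non_alphanumeric(text: str) -> str:
    # Same helper as added to fastembed/common/utils.py: replace punctuation
    # and symbols with spaces, keeping word characters and whitespace.
    return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)


class SimpleTokenizer:
    # Same regexes as the SimpleTokenizer added to fastembed/sparse/utils/tokenizer.py.
    # @staticmethod is used here for the standalone demo; the patched class relies on
    # class-level access instead (self.tokenizer = SimpleTokenizer).
    @staticmethod
    def tokenize(text: str) -> List[str]:
        text = re.sub(r"[^\w]", " ", text.lower())
        text = re.sub(r"\s+", " ", text)
        return text.strip().split()


def stem_tokens(tokens: List[str], stopwords: Set[str], token_max_length: int = 40) -> List[str]:
    # Hypothetical free-function paraphrase of Bm25._stem after this change:
    # case-insensitive stopword check, skip overly long tokens, stem the lowercased token.
    stem = get_stemmer("english")
    stemmed_tokens = []
    for token in tokens:
        if token.lower() in stopwords:
            continue
        if len(token) > token_max_length:
            continue
        stemmed_token = stem.stemWord(token.lower())
        if stemmed_token:
            stemmed_tokens.append(stemmed_token)
    return stemmed_tokens


if __name__ == "__main__":
    doc = "Über den größten Flüssen Österreichs äußern sich Experten häufig!"
    tokens = SimpleTokenizer.tokenize(remove_non_alphanumeric(doc))
    # Stopword set here is a stand-in; Bm25 loads per-language stopwords from the model dir.
    print(stem_tokens(tokens, stopwords={"den", "sich"}))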