
Tf-idf fix: punctuation removal + lowercasing #339

Merged · 14 commits · Sep 24, 2024
4 changes: 3 additions & 1 deletion fastembed/common/model_management.py
@@ -198,7 +198,9 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) ->
return model_dir

@classmethod
def download_model(cls, model: Dict[str, Any], cache_dir: Path, retries=3, **kwargs) -> Path:
def download_model(
cls, model: Dict[str, Any], cache_dir: Path, retries: object = 3, **kwargs: object
) -> Path:
"""
Downloads a model from HuggingFace Hub or Google Cloud Storage.

16 changes: 14 additions & 2 deletions fastembed/common/utils.py
@@ -3,8 +3,11 @@
from itertools import islice
from pathlib import Path
from typing import Generator, Iterable, Optional, Union

import unicodedata
import sys
import numpy as np
import re
from typing import Set


def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
@@ -37,7 +40,16 @@ def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
cache_path = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
else:
cache_path = Path(cache_dir)

cache_path.mkdir(parents=True, exist_ok=True)

return cache_path


def get_all_punctuation() -> Set[str]:
return set(
chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
)


def remove_non_alphanumeric(text: str) -> str:
return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)
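Together, these two helpers extend punctuation handling from ASCII-only `string.punctuation` to the full Unicode punctuation set. A standalone sketch of their behavior, mirroring the diff above rather than importing from the package:

```python
import re
import sys
import unicodedata

def get_all_punctuation() -> set:
    # Every code point whose Unicode category starts with "P"
    # (Pc, Pd, Ps, Pe, Pi, Pf, Po), e.g. « » „ ¿, not just ASCII marks.
    return {chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")}

def remove_non_alphanumeric(text: str) -> str:
    # Replace anything that is neither a word character nor whitespace with a
    # space; accented letters match \w under Unicode semantics and survive.
    return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)

print("«" in get_all_punctuation())                     # True
print(remove_non_alphanumeric("Öko-Systeme!").split())  # ['Öko', 'Systeme']
```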
24 changes: 16 additions & 8 deletions fastembed/sparse/bm25.py
@@ -1,21 +1,23 @@
import os
import string
from collections import defaultdict
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import mmh3
import numpy as np
from snowballstemmer import stemmer as get_stemmer

from fastembed.common.utils import define_cache_dir, iter_batch
from fastembed.common.utils import (
define_cache_dir,
iter_batch,
get_all_punctuation,
remove_non_alphanumeric,
)
from fastembed.parallel_processor import ParallelWorkerPool, Worker
from fastembed.sparse.sparse_embedding_base import (
SparseEmbedding,
SparseTextEmbeddingBase,
)
from fastembed.sparse.utils.tokenizer import WordTokenizer
from fastembed.sparse.utils.tokenizer import SimpleTokenizer

supported_languages = [
"arabic",
@@ -120,10 +122,11 @@ def __init__(
model_description, self.cache_dir, local_files_only=self._local_files_only
)

self.punctuation = set(string.punctuation)
self.punctuation = set(get_all_punctuation())
self.stopwords = set(self._load_stopwords(model_dir, self.language))

self.stemmer = get_stemmer(language)
self.tokenizer = WordTokenizer
self.tokenizer = SimpleTokenizer # WordTokenizer

@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
@@ -222,7 +225,10 @@ def _stem(self, tokens: List[str]) -> List[str]:
if token.lower() in self.stopwords:
continue

stemmed_token = self.stemmer.stemWord(token)
if len(token) > 40:
continue

stemmed_token = self.stemmer.stemWord(token.lower())

if stemmed_token:
stemmed_tokens.append(stemmed_token)
@@ -234,6 +240,7 @@ def raw_embed(
) -> List[SparseEmbedding]:
embeddings = []
for document in documents:
document = remove_non_alphanumeric(document)
tokens = self.tokenizer.tokenize(document)
stemmed_tokens = self._stem(tokens)
token_id2value = self._term_frequency(stemmed_tokens)
@@ -282,6 +289,7 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[Sp
query = [query]

for text in query:
text = remove_non_alphanumeric(text)
tokens = self.tokenizer.tokenize(text)
stemmed_tokens = self._stem(tokens)
token_ids = np.array(
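The combined effect on BM25 preprocessing: punctuation is replaced with spaces, the tokenizer lowercases, stopwords are filtered case-insensitively, tokens longer than 40 characters are skipped, and stemming now always sees lowercase input. A condensed standalone sketch of that pipeline (a hypothetical `bm25_preprocess` helper, not the actual `Bm25` methods):

```python
import re
from snowballstemmer import stemmer as get_stemmer  # same dependency as bm25.py

def bm25_preprocess(text: str, stopwords: set, language: str = "english") -> list:
    # remove_non_alphanumeric: punctuation -> spaces, keep letters/digits/underscore
    text = re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)
    # SimpleTokenizer.tokenize: lowercase, collapse whitespace, split
    tokens = re.sub(r"\s+", " ", text.lower()).strip().split()
    stem = get_stemmer(language).stemWord
    out = []
    for token in tokens:
        if token in stopwords:  # tokens arrive lowercased, so this check is case-insensitive
            continue
        if len(token) > 40:     # new guard: skip pathologically long tokens
            continue
        stemmed = stem(token)   # stemming now sees lowercase input
        if stemmed:
            out.append(stemmed)
    return out

print(bm25_preprocess("The QUICK-brown fox, Testing sentences!", {"the"}))
# -> ['quick', 'brown', 'fox', 'test', 'sentenc']
```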
11 changes: 9 additions & 2 deletions fastembed/sparse/utils/tokenizer.py
@@ -4,6 +4,14 @@
from typing import List


class SimpleTokenizer:
def tokenize(text: str) -> List[str]:
text = re.sub(r"[^\w]", " ", text.lower())
text = re.sub(r"\s+", " ", text)

return text.strip().split()


class WordTokenizer:
"""The tokenizer is "destructive" such that the regexes applied will munge the
input string to a state beyond re-construction.
@@ -68,8 +76,7 @@ class WordTokenizer:
)
]
CONTRACTIONS3 = [
re.compile(pattern)
for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b")
re.compile(pattern) for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b")
]

@classmethod
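`SimpleTokenizer.tokenize` is declared without `self` or `@staticmethod`; that works here because the class itself is assigned (`self.tokenizer = SimpleTokenizer`) and the method is always called on the class, never on an instance. Exercised standalone (a copy of the class from this diff):

```python
import re
from typing import List

class SimpleTokenizer:
    def tokenize(text: str) -> List[str]:  # class-level call: SimpleTokenizer.tokenize(...)
        text = re.sub(r"[^\w]", " ", text.lower())  # lowercase, non-word chars -> spaces
        text = re.sub(r"\s+", " ", text)            # collapse runs of whitespace
        return text.strip().split()

print(SimpleTokenizer.tokenize("Über den größten Flüssen Österreichs!"))
# -> ['über', 'den', 'größten', 'flüssen', 'österreichs']
```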
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ numpy = [
pillow = "^10.3.0"
snowballstemmer = "^2.2.0"
PyStemmer = "^2.2.0"
mmh3 = "^4.0"
mmh3 = "^4.1.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"
22 changes: 20 additions & 2 deletions tests/test_attention_embeddings.py
@@ -97,8 +97,26 @@ def test_multilanguage(model_name):

model = SparseTextEmbedding(model_name=model_name, language="english")
embeddings = list(model.embed(docs))[:2]
assert embeddings[0].values.shape == (4,)
assert embeddings[0].indices.shape == (4,)
assert embeddings[0].values.shape == (5,)
assert embeddings[0].indices.shape == (5,)

assert embeddings[1].values.shape == (4,)
assert embeddings[1].indices.shape == (4,)


@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
def test_special_characters(model_name):
docs = [
"Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!",
"L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!",
"Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană.",
"Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!",
"Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»",
"Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad",
]

model = SparseTextEmbedding(model_name=model_name, language="english")
embeddings = list(model.embed(docs))
for idx, shape in enumerate([14, 18, 15, 10, 15]):
assert embeddings[idx].values.shape == (shape,)
assert embeddings[idx].indices.shape == (shape,)
2 changes: 1 addition & 1 deletion tests/test_sparse_embeddings.py
@@ -131,5 +131,5 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
result = bm25_instance._stem(tokens)

# Assert
expected = ["Quick", "Brown", "Fox", "Test", "Sentenc"]
expected = ["quick", "brown", "fox", "test", "sentenc"]
assert result == expected, f"Expected {expected}, but got {result}"