Skip to content

Commit

Permalink
Added changes to load/save corpora with non-ASCII characters with unit…
Browse files Browse the repository at this point in the history
… test case (#93)
  • Loading branch information
IssacXid authored Dec 23, 2024
1 parent c4fef24 commit ce8f886
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
8 changes: 4 additions & 4 deletions bm25s/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def save(
corpus = corpus if corpus is not None else self.corpus

if corpus is not None:
with open(save_dir / corpus_name, "w") as f:
with open(save_dir / corpus_name, "wt", encoding='utf-8') as f:
# if it's not an iterable, we skip
if not isinstance(corpus, Iterable):
logging.warning(
Expand All @@ -918,7 +918,7 @@ def save(
continue

try:
doc_str = json_functions.dumps(doc)
doc_str = json_functions.dumps(doc, ensure_ascii=False)
except Exception as e:
logging.warning(f"Error saving document at index {i}: {e}")
else:
Expand Down Expand Up @@ -1060,7 +1060,7 @@ def load(
# Load the vocab dictionary
if load_vocab:
vocab_path = save_dir / vocab_name
with open(vocab_path, "r",encoding='utf-8') as f:
with open(vocab_path, "r", encoding='utf-8') as f:
vocab_dict: dict = json_functions.loads(f.read())
else:
vocab_dict = None
Expand Down Expand Up @@ -1091,7 +1091,7 @@ def load(
corpus = utils.corpus.JsonlCorpus(corpus_file)
else:
corpus = []
with open(corpus_file, "r") as f:
with open(corpus_file, "r", encoding='utf-8') as f:
for line in f:
doc = json_functions.loads(line)
corpus.append(doc)
Expand Down
33 changes: 16 additions & 17 deletions tests/core/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,23 @@ class TestBM25SNonASCIILoadingSaving(unittest.TestCase):
def setUpClass(cls):
# check that import orjson fails
import bm25s

cls.text =["Thanks for your great work!"] # this works fine
cls.text = ['שלום חברים'] # this crashes!

# create a vocabulary
tokens = [ t.split() for t in cls.text ]
unique_tokens = set([item for sublist in tokens for item in sublist])
vocab_token2id = {token: i for i, token in enumerate(unique_tokens)}

# create a tokenized corpus
token_ids = [ [vocab_token2id[token] for token in text_tokens if token in vocab_token2id] for text_tokens in tokens ]
corpus_tokens = bm25s.tokenization.Tokenized(ids=token_ids, vocab=vocab_token2id)

# create a retriever
cls.retriever = bm25s.BM25()
cls.corpus = [
"a cat is a feline and likes to purr",
"a dog is the human's best friend and loves to play",
"a bird is a beautiful animal that can fly",
"a fish is a creature that lives in water and swims",
"שלום חברים, איך אתם היום?",
"El café está muy caliente",
"今天的天气真好!",
"Как дела?",
"Türkçe öğreniyorum.",
'שלום חברים'
]
corpus_tokens = bm25s.tokenize(cls.corpus, stopwords="en")
cls.retriever = bm25s.BM25(corpus=cls.corpus)
cls.retriever.index(corpus_tokens)
cls.tmpdirname = tempfile.mkdtemp()


def setUp(self):
# verify that orjson is properly installed
Expand All @@ -166,7 +165,7 @@ def setUp(self):

def test_a_save_and_load(self):
# both of these fail: UnicodeEncodeError: 'charmap' codec can't encode characters in position 2-6: character maps to <undefined>
self.retriever.save(self.tmpdirname, corpus=self.text)
self.retriever.save(self.tmpdirname, corpus=self.corpus)
self.retriever.load(self.tmpdirname, load_corpus=True)

@classmethod
Expand Down

0 comments on commit ce8f886

Please sign in to comment.