diff --git a/bm25s/__init__.py b/bm25s/__init__.py index f163183..792fd5d 100644 --- a/bm25s/__init__.py +++ b/bm25s/__init__.py @@ -899,7 +899,7 @@ def save( corpus = corpus if corpus is not None else self.corpus if corpus is not None: - with open(save_dir / corpus_name, "w") as f: + with open(save_dir / corpus_name, "wt", encoding='utf-8') as f: # if it's not an iterable, we skip if not isinstance(corpus, Iterable): logging.warning( @@ -918,7 +918,7 @@ def save( continue try: - doc_str = json_functions.dumps(doc) + doc_str = json_functions.dumps(doc, ensure_ascii=False) except Exception as e: logging.warning(f"Error saving document at index {i}: {e}") else: @@ -1060,7 +1060,7 @@ def load( # Load the vocab dictionary if load_vocab: vocab_path = save_dir / vocab_name - with open(vocab_path, "r",encoding='utf-8') as f: + with open(vocab_path, "r", encoding='utf-8') as f: vocab_dict: dict = json_functions.loads(f.read()) else: vocab_dict = None @@ -1091,7 +1091,7 @@ def load( corpus = utils.corpus.JsonlCorpus(corpus_file) else: corpus = [] - with open(corpus_file, "r") as f: + with open(corpus_file, "r", encoding='utf-8') as f: for line in f: doc = json_functions.loads(line) corpus.append(doc) diff --git a/tests/core/test_save_load.py b/tests/core/test_save_load.py index 120af39..dae37e9 100644 --- a/tests/core/test_save_load.py +++ b/tests/core/test_save_load.py @@ -138,24 +138,23 @@ class TestBM25SNonASCIILoadingSaving(unittest.TestCase): def setUpClass(cls): # check that import orjson fails import bm25s - - cls.text =["Thanks for your great work!"] # this works fine - cls.text = ['שלום חברים'] # this crashes! - - # create a vocabulary - tokens = [ t.split() for t in cls.text ] - unique_tokens = set([item for sublist in tokens for item in sublist]) - vocab_token2id = {token: i for i, token in enumerate(unique_tokens)} - - # create a tokenized corpus - token_ids = [ [vocab_token2id[token] for token in text_tokens if token in vocab_token2id] for text_tokens in tokens ] - corpus_tokens = bm25s.tokenization.Tokenized(ids=token_ids, vocab=vocab_token2id) - - # create a retriever - cls.retriever = bm25s.BM25() + cls.corpus = [ + "a cat is a feline and likes to purr", + "a dog is the human's best friend and loves to play", + "a bird is a beautiful animal that can fly", + "a fish is a creature that lives in water and swims", + "שלום חברים, איך אתם היום?", + "El café está muy caliente", + "今天的天气真好!", + "Как дела?", + "Türkçe öğreniyorum.", + 'שלום חברים' + ] + corpus_tokens = bm25s.tokenize(cls.corpus, stopwords="en") + cls.retriever = bm25s.BM25(corpus=cls.corpus) cls.retriever.index(corpus_tokens) cls.tmpdirname = tempfile.mkdtemp() - + def setUp(self): # verify that orjson is properly installed @@ -166,7 +165,7 @@ def setUp(self): def test_a_save_and_load(self): # both of these fail: UnicodeEncodeError: 'charmap' codec can't encode characters in position 2-6: character maps to - self.retriever.save(self.tmpdirname, corpus=self.text) + self.retriever.save(self.tmpdirname, corpus=self.corpus) self.retriever.load(self.tmpdirname, load_corpus=True) @classmethod