Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add saving and loading corpus/stopwords to Tokenizer and add integration to HF Hub via bm25s.hf.TokenizerHF (save/load) #59

Merged
merged 3 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ corpus = [

# Pick your favorite stemmer, and pass
stemmer = None
stopwords = []
stopwords = ["is"]
splitter = lambda x: x.split() # function or regex pattern
# Create a tokenizer
tokenizer = Tokenizer(
Expand All @@ -211,6 +211,19 @@ print("tokens:", corpus_tokens)
print("vocab:", tokenizer.get_vocab_dict())

# note: the vocab dict will either be a dict of `word -> id` if you don't have a stemmer, and a dict of `stemmed word -> stem id` if you do.
# You can save the vocab. It's fine to use the same directory as your index, as long as the filenames don't conflict
tokenizer.save_vocab(save_dir="bm25s_very_big_index")

# loading:
new_tokenizer = Tokenizer(stemmer=stemmer, stopwords=[], splitter=splitter)
new_tokenizer.load_vocab("bm25s_very_big_index")
print("vocab reloaded:", new_tokenizer.get_vocab_dict())

# the same can be done for stopwords
print("stopwords before reload:", new_tokenizer.stopwords)
tokenizer.save_stopwords(save_dir="bm25s_very_big_index")
new_tokenizer.load_stopwords("bm25s_very_big_index")
print("stopwords reloaded:", new_tokenizer.stopwords)
```

You can find advanced examples in [examples/tokenizer_class.py](examples/tokenizer_class.py), including how to:
Expand Down
269 changes: 259 additions & 10 deletions bm25s/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
from typing import Iterable, Union
from . import BM25, __version__
from .tokenization import Tokenizer

try:
from huggingface_hub import HfApi
Expand Down Expand Up @@ -106,6 +107,32 @@
retriever = BM25HF.load_from_hub("{username}/{repo_name}", token=token)
```

## Tokenizer

If you have saved a `Tokenizer` object with the index using the following approach:

```python
from bm25s.hf import TokenizerHF

token = "your_hugging_face_token"
tokenizer = TokenizerHF(corpus=corpus, stopwords="english")
tokenizer.save_to_hub("{username}/{repo_name}", token=token)

# and stopwords too
tokenizer.save_stopwords_to_hub("{username}/{repo_name}", token=token)
```

Then, you can load the tokenizer using the following code:

```python
from bm25s.hf import TokenizerHF

tokenizer = TokenizerHF(corpus=corpus, stopwords=[])
tokenizer.load_vocab_from_hub("{username}/{repo_name}", token=token)
tokenizer.load_stopwords_from_hub("{username}/{repo_name}", token=token)
```


## Stats

This dataset was created using the following data:
Expand Down Expand Up @@ -133,15 +160,15 @@
To cite `bm25s`, please use the following bibtex:

```
@misc{lu_2024_bm25s,
title={BM25S: Orders of magnitude faster lexical search via eager sparse scoring},
author={Xing Han Lù},
year={2024},
eprint={2407.03618},
archivePrefix={arXiv},
primaryClass={cs.IR},
url={https://arxiv.org/abs/2407.03618},
}
@misc{{lu_2024_bm25s,
title={{BM25S: Orders of magnitude faster lexical search via eager sparse scoring}},
author={{Xing Han Lù}},
year={{2024}},
eprint={{2407.03618}},
archivePrefix={{arXiv}},
primaryClass={{cs.IR}},
url={{https://arxiv.org/abs/2407.03618}},
}}
```

"""
Expand Down Expand Up @@ -216,6 +243,228 @@ def can_save_locally(local_save_dir, overwrite_local: bool) -> bool:
return True


class TokenizerHF(Tokenizer):
    """
    `Tokenizer` subclass with Hugging Face Hub integration.

    Adds methods to push the tokenizer's vocabulary and stopwords to a Hub
    repository and to pull them back, using `huggingface_hub.HfApi`.
    """

    def _push_dir_to_hub(
        self,
        save_fn,
        repo_id: str,
        token: str = None,
        local_dir: str = None,
        commit_message: str = "Update tokenizer",
        overwrite_local: bool = False,
        private=True,
        **kwargs,
    ):
        """
        Shared implementation for the `save_*_to_hub` methods.

        Creates (or reuses) the target repo, calls `save_fn(save_dir)` to
        materialize the files locally, uploads the directory, and cleans up
        any temporary directory that was created.

        Parameters
        ----------
        save_fn: callable
            A function taking a directory path and writing the files to push
            (e.g. `self.save_vocab` or `self.save_stopwords`).

        Other parameters are documented on the public `save_*_to_hub` methods.

        Returns
        -------
        RepoUrl
            The URL object of the repository the files were pushed to.
        """
        api = HfApi(token=token)
        repo_url = api.create_repo(
            repo_id=repo_id,
            token=api.token,
            private=private,
            repo_type="model",
            exist_ok=True,
        )
        repo_id = repo_url.repo_id

        saving_locally = can_save_locally(local_dir, overwrite_local)
        if saving_locally:
            os.makedirs(local_dir, exist_ok=True)
            save_dir = local_dir
        else:
            # local_dir is unusable (non-empty and overwrite not allowed, or
            # not given) -> fall back to a temporary directory
            save_dir = tempfile.mkdtemp()

        save_fn(save_dir)
        # push content of the save directory to the repo
        api.upload_folder(
            repo_id=repo_id,
            commit_message=commit_message,
            token=api.token,
            folder_path=save_dir,
            repo_type=repo_url.repo_type,
            **kwargs,
        )
        # delete the temporary directory if one was created
        if not saving_locally:
            shutil.rmtree(save_dir)

        return repo_url

    def _snapshot_from_hub(
        self,
        repo_id: str,
        revision=None,
        token=None,
        local_dir=None,
    ) -> str:
        """
        Shared implementation for the `load_*_from_hub` methods.

        Verifies the repo exists, downloads a snapshot of it, and returns
        the local snapshot directory path.

        Raises
        ------
        ValueError
            If the repository cannot be found or downloaded.
        """
        api = HfApi(token=token)
        # check if the model exists
        repo_url = api.repo_info(repo_id)
        if repo_url is None:
            raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")

        snapshot = api.snapshot_download(
            repo_id=repo_id, revision=revision, token=token, local_dir=local_dir
        )
        if snapshot is None:
            raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.")

        return snapshot

    def save_vocab_to_hub(
        self,
        repo_id: str,
        token: str = None,
        local_dir: str = None,
        commit_message: str = "Update tokenizer",
        overwrite_local: bool = False,
        private=True,
        **kwargs,
    ):
        """
        This function saves the tokenizer's vocab to the Hugging Face Hub.

        Parameters
        ----------
        repo_id: str
            The unique identifier of the repository to save the model to.
            The `repo_id` should be in the form of "username/repo_name".

        token: str
            The Hugging Face API token to use.

        local_dir: str
            The directory to save the model to before pushing to the Hub.
            If it is not empty and `overwrite_local` is False, it will fall
            back to saving to a temporary directory.

        commit_message: str
            The commit message to use when saving the model.

        overwrite_local: bool
            Whether to overwrite the existing local directory if it exists.

        private: bool
            Whether to create the repository as private if it does not exist.

        kwargs: dict
            Additional keyword arguments to pass to `HfApi.upload_folder` call.
        """
        return self._push_dir_to_hub(
            self.save_vocab,
            repo_id=repo_id,
            token=token,
            local_dir=local_dir,
            commit_message=commit_message,
            overwrite_local=overwrite_local,
            private=private,
            **kwargs,
        )

    def load_vocab_from_hub(
        self,
        repo_id: str,
        revision=None,
        token=None,
        local_dir=None,
    ):
        """
        This function loads the tokenizer's vocab from the Hugging Face Hub.

        Parameters
        ----------
        repo_id: str
            The unique identifier of the repository to load the model from.
            The `repo_id` should be in the form of "username/repo_name".

        revision: str
            The revision of the model to load.

        token: str
            The Hugging Face API token to use.

        local_dir: str
            The local dir where the model will be stored after downloading.
        """
        # NOTE: this is an instance method (it mutates this tokenizer's vocab),
        # so the receiver is `self`, matching `load_stopwords_from_hub`.
        snapshot = self._snapshot_from_hub(
            repo_id, revision=revision, token=token, local_dir=local_dir
        )
        return self.load_vocab(save_dir=snapshot)

    def save_stopwords_to_hub(
        self,
        repo_id: str,
        token: str = None,
        local_dir: str = None,
        commit_message: str = "Update stopwords",
        overwrite_local: bool = False,
        private=True,
        **kwargs,
    ):
        """
        This function saves the tokenizer's stopwords to the Hugging Face Hub.

        Parameters
        ----------
        repo_id: str
            The unique identifier of the repository to save the model to.
            The `repo_id` should be in the form of "username/repo_name".

        token: str
            The Hugging Face API token to use.

        local_dir: str
            The directory to save the model to before pushing to the Hub.
            If it is not empty and `overwrite_local` is False, it will fall
            back to saving to a temporary directory.

        commit_message: str
            The commit message to use when saving the model.

        overwrite_local: bool
            Whether to overwrite the existing local directory if it exists.

        private: bool
            Whether to create the repository as private if it does not exist.

        kwargs: dict
            Additional keyword arguments to pass to `HfApi.upload_folder` call.
        """
        return self._push_dir_to_hub(
            self.save_stopwords,
            repo_id=repo_id,
            token=token,
            local_dir=local_dir,
            commit_message=commit_message,
            overwrite_local=overwrite_local,
            private=private,
            **kwargs,
        )

    def load_stopwords_from_hub(
        self,
        repo_id: str,
        revision=None,
        token=None,
        local_dir=None,
    ):
        """
        This function loads the tokenizer's stopwords from the Hugging Face Hub.

        Parameters
        ----------
        repo_id: str
            The unique identifier of the repository to load the model from.
            The `repo_id` should be in the form of "username/repo_name".

        revision: str
            The revision of the model to load.

        token: str
            The Hugging Face API token to use.

        local_dir: str
            The local dir where the model will be stored after downloading.
        """
        snapshot = self._snapshot_from_hub(
            repo_id, revision=revision, token=token, local_dir=local_dir
        )
        return self.load_stopwords(save_dir=snapshot)


class BM25HF(BM25):
def save_to_hub(
self,
Expand All @@ -238,7 +487,7 @@ def save_to_hub(

repo_id: str
The name of the repository to save the model to.
It should be username/repo_name.
the `repo_id` should be in the form of "username/repo_name".

token: str
The Hugging Face API token to use.
Expand Down
Loading
Loading