Support None to reset pre_tokenizers and normalizers, and index sequences #1590

Merged Aug 7, 2024, with 37 commits.

Changes from all commits (all by ArthurZucker):
4bce6b2  initial commit (Aug 4, 2024)
908a5be  support None (Aug 4, 2024)
e5872f0  fix clippy (Aug 4, 2024)
79af84f  cleanup (Aug 4, 2024)
1d44d2c  clean? (Aug 4, 2024)
c2cccc8  propagate to pre_tokenizer (Aug 4, 2024)
ac15c15  fix test (Aug 4, 2024)
458368c  fix rust tests (Aug 6, 2024)
34b8f89  fix node (Aug 6, 2024)
1b926f3  propagate to decoder and post processor (Aug 6, 2024)
44a63eb  fix calls (Aug 6, 2024)
d30e10f  lint (Aug 6, 2024)
c4b1470  fmt (Aug 6, 2024)
a330807  node be happy I am fixing you (Aug 6, 2024)
891186a  initial commit (Aug 4, 2024)
c1fc9f1  support None (Aug 4, 2024)
7248894  fix clippy (Aug 4, 2024)
84c0685  cleanup (Aug 4, 2024)
ef0a697  clean? (Aug 4, 2024)
1ec38e3  propagate to pre_tokenizer (Aug 4, 2024)
3c143fd  fix test (Aug 4, 2024)
186a55e  fix rust tests (Aug 6, 2024)
c57a556  fix node (Aug 6, 2024)
645cdec  propagate to decoder and post processor (Aug 6, 2024)
624b520  fix calls (Aug 6, 2024)
2ecaed1  lint (Aug 6, 2024)
3b89f7f  fmt (Aug 6, 2024)
fa5fa08  node be happy I am fixing you (Aug 6, 2024)
5cd1f75  Merge branch 'fix-sequences' of github.com:huggingface/tokenizers int… (Aug 7, 2024)
0e512fb  add a small test (Aug 7, 2024)
4e8ee6e  styling (Aug 7, 2024)
af12117  Merge branch 'main' into fix-sequences (Aug 7, 2024)
ca1534b  style merge (Aug 7, 2024)
886be88  fix merge test (Aug 7, 2024)
e772707  fmt (Aug 7, 2024)
2e125bf  nits (Aug 7, 2024)
2de36e3  update tset (Aug 7, 2024)
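Taken together, the diffs below do two things. First, the core setters `with_normalizer`, `with_pre_tokenizer`, `with_post_processor`, and `with_decoder` on `TokenizerImpl` now take an `Option`, so a component can be cleared after it has been set; the Node bindings wrap their arguments in `Some(...)`, and the Python setters accept `None`. Second, the Python `normalizers.Sequence` and `pre_tokenizers.Sequence` classes gain `__getitem__`, so individual items of a sequence can be indexed and mutated in place. Below is a minimal Rust sketch of the new setter call pattern, assuming the standard `Tokenizer` alias whose normalizer slot is `NormalizerWrapper`; the turbofish on `None` is only there to satisfy type inference for `Option<impl Into<N>>`:

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::normalizers::{Lowercase, NormalizerWrapper};
use tokenizers::Tokenizer;

fn main() {
    let mut tokenizer = Tokenizer::new(BPE::default());

    // Setting a component now wraps the value in Some(...).
    tokenizer.with_normalizer(Some(Lowercase));
    assert!(tokenizer.get_normalizer().is_some());

    // Passing None resets the component; a bare None cannot be inferred
    // for Option<impl Into<N>>, hence the explicit type.
    tokenizer.with_normalizer(None::<NormalizerWrapper>);
    assert!(tokenizer.get_normalizer().is_none());
}
```

The existing Rust call sites in the benches and tests further down are updated accordingly, wrapping their arguments in `Some(...)`.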
8 changes: 4 additions & 4 deletions bindings/node/src/tokenizer.rs
@@ -208,7 +208,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_pre_tokenizer((*pre_tokenizer).clone());
.with_pre_tokenizer(Some((*pre_tokenizer).clone()));
}

#[napi]
@@ -217,7 +217,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_decoder((*decoder).clone());
.with_decoder(Some((*decoder).clone()));
}

#[napi]
@@ -231,7 +231,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_post_processor((*post_processor).clone());
.with_post_processor(Some((*post_processor).clone()));
}

#[napi]
@@ -240,7 +240,7 @@ impl Tokenizer {
.tokenizer
.write()
.unwrap()
.with_normalizer((*normalizer).clone());
.with_normalizer(Some((*normalizer).clone()));
}

#[napi]
23 changes: 19 additions & 4 deletions bindings/python/src/normalizers.rs
@@ -1,8 +1,6 @@
use std::sync::{Arc, RwLock};

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{exceptions, prelude::*};
use std::sync::{Arc, RwLock};

use crate::error::ToPyResult;
use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
@@ -354,6 +352,7 @@ impl PyNFKC {
/// A list of Normalizer to be run as a sequence
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")]
pub struct PySequence {}

#[pymethods]
impl PySequence {
#[new]
@@ -380,6 +379,22 @@ impl PySequence {
fn __len__(&self) -> usize {
0
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().normalizer {
PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
.get_as_subtype(py),
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
},
PyNormalizerTypeWrapper::Single(inner) => {
PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
.get_as_subtype(py)
}
}
}
}

/// Lowercase Normalizer
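The `__getitem__` added to `PySequence` above hands back the stored item behind an `Arc::clone` rather than a copy, so what Python receives is a live view: mutating it (as the `test_items` case added in `test_normalizers.py` further down does with `normalizers[0].lowercase = False`) is reflected in the parent sequence. A self-contained sketch of that shared-ownership behavior, with plain strings standing in for the `Arc<RwLock<...>>`-wrapped normalizers:

```rust
use std::sync::{Arc, RwLock};

fn main() {
    // The sequence stores its items as Arc<RwLock<...>>, like the binding does.
    let seq: Vec<Arc<RwLock<String>>> = vec![
        Arc::new(RwLock::new(String::from("lowercase=true"))),
        Arc::new(RwLock::new(String::from("prepend"))),
    ];

    // `sequence[0]` effectively returns Arc::clone of the stored item.
    let item = Arc::clone(&seq[0]);

    // A write through the cloned handle...
    *item.write().unwrap() = String::from("lowercase=false");

    // ...is visible through the original sequence.
    assert_eq!(*seq[0].read().unwrap(), "lowercase=false");
}
```

The same pattern is used for `pre_tokenizers.Sequence` in the next file.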
18 changes: 18 additions & 0 deletions bindings/python/src/pre_tokenizers.rs
@@ -460,6 +460,24 @@ impl PySequence {
fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
PyTuple::new_bound(py, [PyList::empty_bound(py)])
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().pretok {
PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => {
PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
.get_as_subtype(py)
}
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
},
PyPreTokenizerTypeWrapper::Single(inner) => {
PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
.get_as_subtype(py)
}
}
}
}

pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
39 changes: 23 additions & 16 deletions bindings/python/src/tokenizer.rs
@@ -1371,8 +1371,9 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.normalizers.Normalizer`
#[setter]
fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
self.tokenizer.with_normalizer(normalizer.clone());
fn set_normalizer(&mut self, normalizer: Option<PyRef<PyNormalizer>>) {
let normalizer_option = normalizer.map(|norm| norm.clone());
self.tokenizer.with_normalizer(normalizer_option);
}

/// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer
@@ -1387,8 +1388,9 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.normalizers.Normalizer`
#[setter]
fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
self.tokenizer.with_pre_tokenizer(pretok.clone());
fn set_pre_tokenizer(&mut self, pretok: Option<PyRef<PyPreTokenizer>>) {
self.tokenizer
.with_pre_tokenizer(pretok.map(|pre| pre.clone()));
}

/// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer
@@ -1403,8 +1405,9 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.processors.PostProcessor`
#[setter]
fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
self.tokenizer.with_post_processor(processor.clone());
fn set_post_processor(&mut self, processor: Option<PyRef<PyPostProcessor>>) {
self.tokenizer
.with_post_processor(processor.map(|p| p.clone()));
}

/// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer
@@ -1419,8 +1422,8 @@ impl PyTokenizer {

/// Set the :class:`~tokenizers.decoders.Decoder`
#[setter]
fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
self.tokenizer.with_decoder(decoder.clone());
fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
}
}

@@ -1436,10 +1439,12 @@ mod test {
#[test]
fn serialize() {
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
])));
tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
],
))));

let tmp = NamedTempFile::new().unwrap().into_temp_path();
tokenizer.save(&tmp, false).unwrap();
@@ -1450,10 +1455,12 @@
#[test]
fn serde_pyo3() {
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
])));
tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
vec![
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
],
))));

let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");
8 changes: 8 additions & 0 deletions bindings/python/tests/bindings/test_normalizers.py
@@ -67,6 +67,14 @@ def test_can_make_sequences(self):
output = normalizer.normalize_str(" HELLO ")
assert output == "hello"

def test_items(self):
normalizers = Sequence([BertNormalizer(True, True), Prepend()])
assert normalizers[1].__class__ == Prepend
normalizers[0].lowercase = False
assert not normalizers[0].lowercase
with pytest.raises(IndexError):
print(normalizers[2])


class TestLowercase:
def test_instantiate(self):
7 changes: 7 additions & 0 deletions bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -169,6 +169,13 @@ def test_bert_like(self):
("?", (29, 30)),
]

def test_items(self):
pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()])
assert pre_tokenizers[1].__class__ == Punctuation
assert pre_tokenizers[0].__class__ == Metaspace
pre_tokenizers[0].split = False
assert not pre_tokenizers[0].split


class TestDigits:
def test_instantiate(self):
13 changes: 12 additions & 1 deletion bindings/python/tests/bindings/test_tokenizer.py
@@ -6,10 +6,11 @@
from tokenizers import AddedToken, Encoding, Tokenizer
from tokenizers.implementations import BertWordPieceTokenizer
from tokenizers.models import BPE, Model, Unigram
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.pre_tokenizers import ByteLevel, Metaspace
from tokenizers.processors import RobertaProcessing, TemplateProcessing
from tokenizers.normalizers import Strip, Lowercase, Sequence


from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files


@@ -551,6 +552,16 @@ def test_decode_special(self):
assert output == "name is john"
assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True)

def test_setting_to_none(self):
tokenizer = Tokenizer(BPE())
tokenizer.normalizer = Strip()
tokenizer.normalizer = None
assert tokenizer.normalizer == None

tokenizer.pre_tokenizer = Metaspace()
tokenizer.pre_tokenizer = None
assert tokenizer.pre_tokenizer == None


class TestTokenizerRepr:
def test_repr(self):
14 changes: 7 additions & 7 deletions tokenizers/benches/bert_benchmark.rs
@@ -34,13 +34,13 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer {
let sep_id = *wp.get_vocab().get("[SEP]").unwrap();
let cls_id = *wp.get_vocab().get("[CLS]").unwrap();
let mut tokenizer = TokenizerImpl::new(wp);
tokenizer.with_pre_tokenizer(BertPreTokenizer);
tokenizer.with_normalizer(BertNormalizer::default());
tokenizer.with_decoder(decoders::wordpiece::WordPiece::default());
tokenizer.with_post_processor(BertProcessing::new(
tokenizer.with_pre_tokenizer(Some(BertPreTokenizer));
tokenizer.with_normalizer(Some(BertNormalizer::default()));
tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default()));
tokenizer.with_post_processor(Some(BertProcessing::new(
("[SEP]".to_string(), sep_id),
("[CLS]".to_string(), cls_id),
));
)));
tokenizer
}

@@ -81,7 +81,7 @@ fn bench_train(c: &mut Criterion) {
DecoderWrapper,
>;
let mut tokenizer = Tok::new(WordPiece::default());
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("WordPiece Train vocabulary (small)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
@@ -94,7 +94,7 @@
});

let mut tokenizer = Tok::new(WordPiece::default());
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("WordPiece Train vocabulary (big)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
8 changes: 4 additions & 4 deletions tokenizers/benches/bpe_benchmark.rs
@@ -22,8 +22,8 @@ static BATCH_SIZE: usize = 1_000;

fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
let mut tokenizer = Tokenizer::new(bpe);
tokenizer.with_pre_tokenizer(ByteLevel::default());
tokenizer.with_decoder(ByteLevel::default());
tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
tokenizer.with_decoder(Some(ByteLevel::default()));
tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
tokenizer
@@ -74,7 +74,7 @@ fn bench_train(c: &mut Criterion) {
.build()
.into();
let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("BPE Train vocabulary (small)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
@@ -87,7 +87,7 @@
});

let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
tokenizer.with_pre_tokenizer(Whitespace {});
tokenizer.with_pre_tokenizer(Some(Whitespace {}));
c.bench_function("BPE Train vocabulary (big)", |b| {
b.iter_custom(|iters| {
iter_bench_train(
17 changes: 8 additions & 9 deletions tokenizers/src/tokenizer/mod.rs
@@ -550,19 +550,18 @@ where
}

/// Set the normalizer
pub fn with_normalizer(&mut self, normalizer: impl Into<N>) -> &mut Self {
self.normalizer = Some(normalizer.into());
pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self {
self.normalizer = normalizer.map(|norm| norm.into());
self
}

/// Get the normalizer
pub fn get_normalizer(&self) -> Option<&N> {
self.normalizer.as_ref()
}

/// Set the pre tokenizer
pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into<PT>) -> &mut Self {
self.pre_tokenizer = Some(pre_tokenizer.into());
pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self {
self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into());
self
}

@@ -572,8 +571,8 @@
}

/// Set the post processor
pub fn with_post_processor(&mut self, post_processor: impl Into<PP>) -> &mut Self {
self.post_processor = Some(post_processor.into());
pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self {
self.post_processor = post_processor.map(|post_proc| post_proc.into());
self
}

@@ -583,8 +582,8 @@
}

/// Set the decoder
pub fn with_decoder(&mut self, decoder: impl Into<D>) -> &mut Self {
self.decoder = Some(decoder.into());
pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self {
self.decoder = decoder.map(|dec| dec.into());
self
}

18 changes: 10 additions & 8 deletions tokenizers/tests/common/mod.rs
@@ -23,9 +23,11 @@ pub fn get_byte_level_bpe() -> BPE {
pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
tokenizer
.with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space))
.with_decoder(ByteLevel::default())
.with_post_processor(ByteLevel::default().trim_offsets(trim_offsets));
.with_pre_tokenizer(Some(
ByteLevel::default().add_prefix_space(add_prefix_space),
))
.with_decoder(Some(ByteLevel::default()))
.with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets)));

tokenizer
}
@@ -43,13 +45,13 @@ pub fn get_bert() -> Tokenizer {
let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap();
let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap();
tokenizer
.with_normalizer(BertNormalizer::default())
.with_pre_tokenizer(BertPreTokenizer)
.with_decoder(WordPieceDecoder::default())
.with_post_processor(BertProcessing::new(
.with_normalizer(Some(BertNormalizer::default()))
.with_pre_tokenizer(Some(BertPreTokenizer))
.with_decoder(Some(WordPieceDecoder::default()))
.with_post_processor(Some(BertProcessing::new(
(String::from("[SEP]"), sep),
(String::from("[CLS]"), cls),
));
)));

tokenizer
}