diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs
index 100cfd75f..4acbcac83 100644
--- a/bindings/node/src/tokenizer.rs
+++ b/bindings/node/src/tokenizer.rs
@@ -208,7 +208,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_pre_tokenizer((*pre_tokenizer).clone());
+      .with_pre_tokenizer(Some((*pre_tokenizer).clone()));
   }
 
   #[napi]
@@ -217,7 +217,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_decoder((*decoder).clone());
+      .with_decoder(Some((*decoder).clone()));
   }
 
   #[napi]
@@ -231,7 +231,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_post_processor((*post_processor).clone());
+      .with_post_processor(Some((*post_processor).clone()));
   }
 
   #[napi]
@@ -240,7 +240,7 @@ impl Tokenizer {
       .tokenizer
       .write()
       .unwrap()
-      .with_normalizer((*normalizer).clone());
+      .with_normalizer(Some((*normalizer).clone()));
   }
 
   #[napi]
diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs
index 51c1e8bfe..ba143c3f8 100644
--- a/bindings/python/src/normalizers.rs
+++ b/bindings/python/src/normalizers.rs
@@ -1,8 +1,6 @@
-use std::sync::{Arc, RwLock};
-
-use pyo3::exceptions;
-use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::{exceptions, prelude::*};
+use std::sync::{Arc, RwLock};
 
 use crate::error::ToPyResult;
 use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
@@ -354,6 +352,7 @@ impl PyNFKC {
 /// A list of Normalizer to be run as a sequence
 #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")]
 pub struct PySequence {}
+
 #[pymethods]
 impl PySequence {
     #[new]
@@ -380,6 +379,22 @@ impl PySequence {
     fn __len__(&self) -> usize {
         0
     }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().normalizer {
+            PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
+                    .get_as_subtype(py),
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyNormalizerTypeWrapper::Single(inner) => {
+                PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }
 
 /// Lowercase Normalizer
diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 4b97319d3..02556e59c 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -460,6 +460,24 @@ impl PySequence {
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
         PyTuple::new_bound(py, [PyList::empty_bound(py)])
     }
+
+    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
+        match &self_.as_ref().pretok {
+            PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
+                Some(item) => {
+                    PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
+                        .get_as_subtype(py)
+                }
+                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
+                    "Index not found",
+                )),
+            },
+            PyPreTokenizerTypeWrapper::Single(inner) => {
+                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
+                    .get_as_subtype(py)
+            }
+        }
+    }
 }
 
 pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> {
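
Note: the `__getitem__` implementations added above make `normalizers.Sequence` and `pre_tokenizers.Sequence` indexable from Python, returning each component as its concrete subtype and raising `IndexError` for an out-of-range index. A minimal sketch of the resulting behaviour, mirroring the `test_items` tests added further down (the components chosen here are only illustrative):

    from tokenizers.normalizers import Sequence, NFKC, Lowercase

    seq = Sequence([NFKC(), Lowercase()])
    assert seq[0].__class__ == NFKC       # items come back as their concrete subtype
    assert seq[1].__class__ == Lowercase
    # seq[2] would raise IndexError ("Index not found")
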
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index f41bf335f..c967f74ff 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1371,8 +1371,9 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
-        self.tokenizer.with_normalizer(normalizer.clone());
+    fn set_normalizer(&mut self, normalizer: Option<PyRef<PyNormalizer>>) {
+        let normalizer_option = normalizer.map(|norm| norm.clone());
+        self.tokenizer.with_normalizer(normalizer_option);
     }
 
     /// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer
@@ -1387,8 +1388,9 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.normalizers.Normalizer`
     #[setter]
-    fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
-        self.tokenizer.with_pre_tokenizer(pretok.clone());
+    fn set_pre_tokenizer(&mut self, pretok: Option<PyRef<PyPreTokenizer>>) {
+        self.tokenizer
+            .with_pre_tokenizer(pretok.map(|pre| pre.clone()));
     }
 
     /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer
@@ -1403,8 +1405,9 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.processors.PostProcessor`
     #[setter]
-    fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
-        self.tokenizer.with_post_processor(processor.clone());
+    fn set_post_processor(&mut self, processor: Option<PyRef<PyPostProcessor>>) {
+        self.tokenizer
+            .with_post_processor(processor.map(|p| p.clone()));
     }
 
     /// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer
@@ -1419,8 +1422,8 @@ impl PyTokenizer {
 
     /// Set the :class:`~tokenizers.decoders.Decoder`
     #[setter]
-    fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
-        self.tokenizer.with_decoder(decoder.clone());
+    fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
+        self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
     }
 }
 
@@ -1436,10 +1439,12 @@ mod test {
     #[test]
    fn serialize() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
-            Arc::new(RwLock::new(NFKC.into())),
-            Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
+                Arc::new(RwLock::new(NFKC.into())),
+                Arc::new(RwLock::new(Lowercase.into())),
+            ],
+        ))));
 
         let tmp = NamedTempFile::new().unwrap().into_temp_path();
         tokenizer.save(&tmp, false).unwrap();
@@ -1450,10 +1455,12 @@ mod test {
     #[test]
     fn serde_pyo3() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
-        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
-            Arc::new(RwLock::new(NFKC.into())),
-            Arc::new(RwLock::new(Lowercase.into())),
-        ])));
+        tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(
+            vec![
+                Arc::new(RwLock::new(NFKC.into())),
+                Arc::new(RwLock::new(Lowercase.into())),
+            ],
+        ))));
 
         let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap();
         assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))");
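
Note: with the setter changes above, the `normalizer`, `pre_tokenizer`, `post_processor` and `decoder` properties of the Python `Tokenizer` now accept `None`, which detaches the corresponding component instead of being rejected. A minimal sketch of the intended usage (mirroring `test_setting_to_none` further down):

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.normalizers import Strip

    tokenizer = Tokenizer(BPE())
    tokenizer.normalizer = Strip()   # attach a normalizer as before
    tokenizer.normalizer = None      # now accepted: detaches the normalizer
    assert tokenizer.normalizer is None
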
diff --git a/bindings/python/tests/bindings/test_normalizers.py b/bindings/python/tests/bindings/test_normalizers.py
index 3fafd60d1..b67845188 100644
--- a/bindings/python/tests/bindings/test_normalizers.py
+++ b/bindings/python/tests/bindings/test_normalizers.py
@@ -67,6 +67,14 @@ def test_can_make_sequences(self):
         output = normalizer.normalize_str(" HELLO ")
         assert output == "hello"
 
+    def test_items(self):
+        normalizers = Sequence([BertNormalizer(True, True), Prepend()])
+        assert normalizers[1].__class__ == Prepend
+        normalizers[0].lowercase = False
+        assert not normalizers[0].lowercase
+        with pytest.raises(IndexError):
+            print(normalizers[2])
+
 
 class TestLowercase:
     def test_instantiate(self):
diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py
index fda9adb2a..80086f42e 100644
--- a/bindings/python/tests/bindings/test_pre_tokenizers.py
+++ b/bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -169,6 +169,13 @@ def test_bert_like(self):
             ("?", (29, 30)),
         ]
 
+    def test_items(self):
+        pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()])
+        assert pre_tokenizers[1].__class__ == Punctuation
+        assert pre_tokenizers[0].__class__ == Metaspace
+        pre_tokenizers[0].split = False
+        assert not pre_tokenizers[0].split
+
 
 class TestDigits:
     def test_instantiate(self):
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 3851f0764..2118709a0 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -6,10 +6,11 @@
 from tokenizers import AddedToken, Encoding, Tokenizer
 from tokenizers.implementations import BertWordPieceTokenizer
 from tokenizers.models import BPE, Model, Unigram
-from tokenizers.pre_tokenizers import ByteLevel
+from tokenizers.pre_tokenizers import ByteLevel, Metaspace
 from tokenizers.processors import RobertaProcessing, TemplateProcessing
 from tokenizers.normalizers import Strip, Lowercase, Sequence
 
+
 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
 
 
@@ -551,6 +552,16 @@ def test_decode_special(self):
         assert output == "name is john"
         assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True)
 
+    def test_setting_to_none(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.normalizer = Strip()
+        tokenizer.normalizer = None
+        assert tokenizer.normalizer == None
+
+        tokenizer.pre_tokenizer = Metaspace()
+        tokenizer.pre_tokenizer = None
+        assert tokenizer.pre_tokenizer == None
+
 
 class TestTokenizerRepr:
     def test_repr(self):
diff --git a/tokenizers/benches/bert_benchmark.rs b/tokenizers/benches/bert_benchmark.rs
index bc84fe8c9..cfdab9070 100644
--- a/tokenizers/benches/bert_benchmark.rs
+++ b/tokenizers/benches/bert_benchmark.rs
@@ -34,13 +34,13 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer {
     let sep_id = *wp.get_vocab().get("[SEP]").unwrap();
     let cls_id = *wp.get_vocab().get("[CLS]").unwrap();
     let mut tokenizer = TokenizerImpl::new(wp);
-    tokenizer.with_pre_tokenizer(BertPreTokenizer);
-    tokenizer.with_normalizer(BertNormalizer::default());
-    tokenizer.with_decoder(decoders::wordpiece::WordPiece::default());
-    tokenizer.with_post_processor(BertProcessing::new(
+    tokenizer.with_pre_tokenizer(Some(BertPreTokenizer));
+    tokenizer.with_normalizer(Some(BertNormalizer::default()));
+    tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default()));
+    tokenizer.with_post_processor(Some(BertProcessing::new(
         ("[SEP]".to_string(), sep_id),
         ("[CLS]".to_string(), cls_id),
-    ));
+    )));
     tokenizer
 }
 
@@ -81,7 +81,7 @@ fn bench_train(c: &mut Criterion) {
         DecoderWrapper,
     >;
     let mut tokenizer = Tok::new(WordPiece::default());
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
    c.bench_function("WordPiece Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -94,7 +94,7 @@ fn bench_train(c: &mut Criterion) {
     });
 
     let mut tokenizer = Tok::new(WordPiece::default());
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("WordPiece Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
diff --git a/tokenizers/benches/bpe_benchmark.rs b/tokenizers/benches/bpe_benchmark.rs
index dd65d233e..f0097bf82 100644
--- a/tokenizers/benches/bpe_benchmark.rs
+++ b/tokenizers/benches/bpe_benchmark.rs
@@ -22,8 +22,8 @@ static BATCH_SIZE: usize = 1_000;
 
 fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(bpe);
-    tokenizer.with_pre_tokenizer(ByteLevel::default());
-    tokenizer.with_decoder(ByteLevel::default());
+    tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
+    tokenizer.with_decoder(Some(ByteLevel::default()));
     tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]);
     tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]);
     tokenizer
@@ -74,7 +74,7 @@ fn bench_train(c: &mut Criterion) {
         .build()
         .into();
     let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("BPE Train vocabulary (small)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
@@ -87,7 +87,7 @@ fn bench_train(c: &mut Criterion) {
     });
 
     let mut tokenizer = Tokenizer::new(BPE::default()).into_inner();
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     c.bench_function("BPE Train vocabulary (big)", |b| {
         b.iter_custom(|iters| {
             iter_bench_train(
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 766ee1cd9..1c2ad6e0b 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -550,19 +550,18 @@ where
     }
 
     /// Set the normalizer
-    pub fn with_normalizer(&mut self, normalizer: impl Into<N>) -> &mut Self {
-        self.normalizer = Some(normalizer.into());
+    pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self {
+        self.normalizer = normalizer.map(|norm| norm.into());
         self
     }
-
     /// Get the normalizer
     pub fn get_normalizer(&self) -> Option<&N> {
         self.normalizer.as_ref()
     }
 
     /// Set the pre tokenizer
-    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into<PT>) -> &mut Self {
-        self.pre_tokenizer = Some(pre_tokenizer.into());
+    pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self {
+        self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into());
         self
     }
 
@@ -572,8 +571,8 @@ where
     }
 
     /// Set the post processor
-    pub fn with_post_processor(&mut self, post_processor: impl Into<PP>) -> &mut Self {
-        self.post_processor = Some(post_processor.into());
+    pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self {
+        self.post_processor = post_processor.map(|post_proc| post_proc.into());
         self
     }
 
@@ -583,8 +582,8 @@ where
     }
 
     /// Set the decoder
-    pub fn with_decoder(&mut self, decoder: impl Into<D>) -> &mut Self {
-        self.decoder = Some(decoder.into());
+    pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self {
+        self.decoder = decoder.map(|dec| dec.into());
         self
     }
 
diff --git a/tokenizers/tests/common/mod.rs b/tokenizers/tests/common/mod.rs
index ea70aba20..26129699b 100644
--- a/tokenizers/tests/common/mod.rs
+++ b/tokenizers/tests/common/mod.rs
@@ -23,9 +23,11 @@ pub fn get_byte_level_bpe() -> BPE {
 
 pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(get_byte_level_bpe());
     tokenizer
-        .with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space))
-        .with_decoder(ByteLevel::default())
-        .with_post_processor(ByteLevel::default().trim_offsets(trim_offsets));
+        .with_pre_tokenizer(Some(
+            ByteLevel::default().add_prefix_space(add_prefix_space),
+        ))
+        .with_decoder(Some(ByteLevel::default()))
+        .with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets)));
     tokenizer
 }
 
@@ -43,13 +45,13 @@ pub fn get_bert() -> Tokenizer {
     let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap();
     let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap();
     tokenizer
-        .with_normalizer(BertNormalizer::default())
-        .with_pre_tokenizer(BertPreTokenizer)
-        .with_decoder(WordPieceDecoder::default())
-        .with_post_processor(BertProcessing::new(
+        .with_normalizer(Some(BertNormalizer::default()))
+        .with_pre_tokenizer(Some(BertPreTokenizer))
+        .with_decoder(Some(WordPieceDecoder::default()))
+        .with_post_processor(Some(BertProcessing::new(
             (String::from("[SEP]"), sep),
             (String::from("[CLS]"), cls),
-        ));
+        )));
     tokenizer
 }
 
diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs
index 0a9fbaec0..c0c471a93 100644
--- a/tokenizers/tests/documentation.rs
+++ b/tokenizers/tests/documentation.rs
@@ -93,7 +93,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
     // START quicktour_init_pretok
     use tokenizers::pre_tokenizers::whitespace::Whitespace;
 
-    tokenizer.with_pre_tokenizer(Whitespace {});
+    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     // END quicktour_init_pretok
 
     // START quicktour_train
@@ -157,7 +157,7 @@ fn quicktour() -> tokenizers::Result<()> {
         ("[CLS]", tokenizer.token_to_id("[CLS]").unwrap()),
         ("[SEP]", tokenizer.token_to_id("[SEP]").unwrap()),
     ];
-    tokenizer.with_post_processor(
+    tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -165,7 +165,7 @@ fn quicktour() -> tokenizers::Result<()> {
             .unwrap()
             .special_tokens(special_tokens)
             .build()?,
-    );
+    ));
     // END quicktour_init_template_processing
     // START quicktour_print_special_tokens
     let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?;
@@ -261,7 +261,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // END pipeline_test_normalizer
     assert_eq!(normalized.get(), "Hello how are u?");
     // START pipeline_replace_normalizer
-    tokenizer.with_normalizer(normalizer);
+    tokenizer.with_normalizer(Some(normalizer));
     // END pipeline_replace_normalizer
     // START pipeline_setup_pre_tokenizer
     use tokenizers::pre_tokenizers::whitespace::Whitespace;
@@ -325,12 +325,12 @@ fn pipeline() -> tokenizers::Result<()> {
         ]
     );
     // START pipeline_replace_pre_tokenizer
-    tokenizer.with_pre_tokenizer(pre_tokenizer);
+    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
     // END pipeline_replace_pre_tokenizer
 
     // START pipeline_setup_processor
     use tokenizers::processors::template::TemplateProcessing;
-    tokenizer.with_post_processor(
+    tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -339,7 +339,7 @@ fn pipeline() -> tokenizers::Result<()> {
             .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)])
             .build()
             .unwrap(),
-    );
+    ));
     // END pipeline_setup_processor
     // START pipeline_test_decoding
     let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?;
@@ -375,21 +375,21 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::normalizers::utils::Sequence as NormalizerSequence;
     use tokenizers::normalizers::{strip::StripAccents, unicode::NFD, utils::Lowercase};
 
-    bert_tokenizer.with_normalizer(NormalizerSequence::new(vec![
+    bert_tokenizer.with_normalizer(Some(NormalizerSequence::new(vec![
         NFD.into(),
         Lowercase.into(),
         StripAccents.into(),
-    ]));
+    ])));
     // END bert_setup_normalizer
     // START bert_setup_pre_tokenizer
     use tokenizers::pre_tokenizers::whitespace::Whitespace;
 
-    bert_tokenizer.with_pre_tokenizer(Whitespace {});
+    bert_tokenizer.with_pre_tokenizer(Some(Whitespace {}));
     // END bert_setup_pre_tokenizer
     // START bert_setup_processor
     use tokenizers::processors::template::TemplateProcessing;
 
-    bert_tokenizer.with_post_processor(
+    bert_tokenizer.with_post_processor(Some(
         TemplateProcessing::builder()
             .try_single("[CLS] $A [SEP]")
             .unwrap()
@@ -398,7 +398,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
             .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)])
             .build()
             .unwrap(),
-    );
+    ));
     // END bert_setup_processor
     // START bert_train_tokenizer
     use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper};
@@ -450,7 +450,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
 
     // START bert_proper_decoding
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
-    bert_tokenizer.with_decoder(WordPieceDecoder::default());
+    bert_tokenizer.with_decoder(Some(WordPieceDecoder::default()));
     let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs
index 4d51d4281..dc0c95a57 100644
--- a/tokenizers/tests/serialization.rs
+++ b/tokenizers/tests/serialization.rs
@@ -203,7 +203,7 @@ fn models() {
 fn tokenizer() {
     let wordpiece = WordPiece::default();
     let mut tokenizer = Tokenizer::new(wordpiece);
-    tokenizer.with_normalizer(NFC);
+    tokenizer.with_normalizer(Some(NFC));
     let ser = serde_json::to_string(&tokenizer).unwrap();
     let _: Tokenizer = serde_json::from_str(&ser).unwrap();
     let unwrapped_nfc_tok: TokenizerImpl<