From 4bce6b2856ab425b1c7a4a850d2622da9053b0c0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 4 Aug 2024 14:01:14 +0200 Subject: [PATCH 01/14] initial commit --- bindings/python/src/normalizers.rs | 85 ++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 864947e39..945335569 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -1,8 +1,6 @@ -use std::sync::{Arc, RwLock}; - -use pyo3::exceptions; -use pyo3::prelude::*; use pyo3::types::*; +use pyo3::{exceptions, prelude::*}; +use std::sync::{Arc, RwLock}; use crate::error::ToPyResult; use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; @@ -344,6 +342,7 @@ impl PyNFKC { /// A list of Normalizer to be run as a sequence #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Sequence")] pub struct PySequence {} + #[pymethods] impl PySequence { #[new] @@ -370,6 +369,84 @@ impl PySequence { fn __len__(&self) -> usize { 0 } + // fn _get_item<'p>( + // self_: PyRef<'_, Self>, + // py: Python<'p>, + // index: usize, + // ) -> PyResult> { + // let normalizer = &self_.as_ref().normalizer; + + // match normalizer { + // PyNormalizerTypeWrapper::Sequence(inner) => { + // if let Some(item) = inner.get(index) { + // // Clone the Arc to gain access + // let item_clone = Arc::clone(&item); + + // // Attempt to acquire the write lock + // let mut write_guard = item_clone.write().map_err(|_| { + // PyErr::new::( + // "Failed to acquire write lock", + // ) + // })?; + + // // Convert the PyNormalizerWrapper to a Python object + // let py_normalizer = PyNormalizerWrapper::clone(&*write_guard); + // Ok(py_normalizer.into()) // Assuming to_py is implemented + // } else { + // Err(PyErr::new::( + // "Index out of bounds", + // )) + // } + // } + // PyNormalizerTypeWrapper::Single(inner) => { + // // Attempt to acquire the write lock + // let mut write_guard = inner.write().map_err(|_| { + // PyErr::new::( + // "Failed to acquire write lock", + // ) + // })?; + + // // Convert the PyNormalizerWrapper to a Python object + // let py_normalizer = PyNormalizerWrapper::clone(&*write_guard); + // Ok(py_normalizer.to_py(py)) // Assuming to_py is implemented + // } + // } + // } + + fn __getitem__<'p>( + self_: PyRef<'_, Self>, + py: Python<'p>, + index: usize, + ) -> PyResult> { + match &self_.as_ref().normalizer { + PyNormalizerTypeWrapper::Sequence(inner) => { + if let Some(item) = inner.get(index) { + match Arc::clone(&item).read() { + Ok(read_guard) => { + let py_normalizer: PyNormalizerTypeWrapper = read_guard.clone().into(); + let pynorm = PyNormalizer::new(py_normalizer).get_as_subtype(py); + Ok(pynorm.unwrap()) + } + Err(_) => Err(PyErr::new::( + "Failed to acquire read lock", + )), + } + } else { + Err(PyErr::new::( + "Failed to acquire read lock", + )) + } + } + PyNormalizerTypeWrapper::Single(inner) => match inner.read() { + Ok(read_guard) => Ok(PyNormalizer::new(read_guard.clone().into()) + .get_as_subtype(py) + .unwrap()), + Err(_) => Err(PyErr::new::( + "Failed to acquire read lock", + )), + }, + } + } } /// Lowercase Normalizer From 908a5be268203584c9f47a4e101ec095a634adf1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 4 Aug 2024 18:55:40 +0200 Subject: [PATCH 02/14] support None --- bindings/python/src/tokenizer.rs | 11 +++++++---- tokenizers/src/tokenizer/mod.rs | 9 ++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git 
a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 1c6bc9cc1..9311ce6cf 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -8,6 +8,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::*; use tk::models::bpe::BPE; +use tk::normalizer; use tk::tokenizer::{ Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, TruncationDirection, TruncationParams, TruncationStrategy, @@ -1359,8 +1360,9 @@ impl PyTokenizer { /// Set the :class:`~tokenizers.normalizers.Normalizer` #[setter] - fn set_normalizer(&mut self, normalizer: PyRef) { - self.tokenizer.with_normalizer(normalizer.clone()); + fn set_normalizer(&mut self, normalizer: Option>) { + let normalizer_option = normalizer.map(|norm| norm.clone()); + self.tokenizer.with_normalizer(normalizer_option); } /// The `optional` :class:`~tokenizers.pre_tokenizers.PreTokenizer` in use by the Tokenizer @@ -1375,8 +1377,9 @@ impl PyTokenizer { /// Set the :class:`~tokenizers.normalizers.Normalizer` #[setter] - fn set_pre_tokenizer(&mut self, pretok: PyRef) { - self.tokenizer.with_pre_tokenizer(pretok.clone()); + fn set_pre_tokenizer(&mut self, pretok: Option>) { + let pretok = pretok.map(|pre| pre.clone()); + self.tokenizer.with_pre_tokenizer(pretok); } /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 766ee1cd9..69ca9a3bf 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -550,19 +550,18 @@ where } /// Set the normalizer - pub fn with_normalizer(&mut self, normalizer: impl Into) -> &mut Self { - self.normalizer = Some(normalizer.into()); + pub fn with_normalizer(&mut self, normalizer: Option>) -> &mut Self { + self.normalizer = normalizer.map(|norm| norm.into()); self } - /// Get the normalizer pub fn get_normalizer(&self) -> Option<&N> { self.normalizer.as_ref() } /// Set the pre tokenizer - pub fn with_pre_tokenizer(&mut self, pre_tokenizer: impl Into) -> &mut Self { - self.pre_tokenizer = Some(pre_tokenizer.into()); + pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option>) -> &mut Self { + self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into()); self } From e5872f02a37780a00f31df6280bb193cff556a81 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 4 Aug 2024 19:04:39 +0200 Subject: [PATCH 03/14] fix clippy --- bindings/python/src/normalizers.rs | 57 +++--------------------------- 1 file changed, 5 insertions(+), 52 deletions(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 945335569..784e5fb86 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -369,71 +369,24 @@ impl PySequence { fn __len__(&self) -> usize { 0 } - // fn _get_item<'p>( - // self_: PyRef<'_, Self>, - // py: Python<'p>, - // index: usize, - // ) -> PyResult> { - // let normalizer = &self_.as_ref().normalizer; - - // match normalizer { - // PyNormalizerTypeWrapper::Sequence(inner) => { - // if let Some(item) = inner.get(index) { - // // Clone the Arc to gain access - // let item_clone = Arc::clone(&item); - - // // Attempt to acquire the write lock - // let mut write_guard = item_clone.write().map_err(|_| { - // PyErr::new::( - // "Failed to acquire write lock", - // ) - // })?; - - // // Convert the PyNormalizerWrapper to a Python object - // let py_normalizer = PyNormalizerWrapper::clone(&*write_guard); - // Ok(py_normalizer.into()) // 
Assuming to_py is implemented - // } else { - // Err(PyErr::new::( - // "Index out of bounds", - // )) - // } - // } - // PyNormalizerTypeWrapper::Single(inner) => { - // // Attempt to acquire the write lock - // let mut write_guard = inner.write().map_err(|_| { - // PyErr::new::( - // "Failed to acquire write lock", - // ) - // })?; - - // // Convert the PyNormalizerWrapper to a Python object - // let py_normalizer = PyNormalizerWrapper::clone(&*write_guard); - // Ok(py_normalizer.to_py(py)) // Assuming to_py is implemented - // } - // } - // } - - fn __getitem__<'p>( - self_: PyRef<'_, Self>, - py: Python<'p>, - index: usize, - ) -> PyResult> { + + fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { match &self_.as_ref().normalizer { PyNormalizerTypeWrapper::Sequence(inner) => { if let Some(item) = inner.get(index) { - match Arc::clone(&item).read() { + match Arc::clone(item).read() { Ok(read_guard) => { let py_normalizer: PyNormalizerTypeWrapper = read_guard.clone().into(); let pynorm = PyNormalizer::new(py_normalizer).get_as_subtype(py); Ok(pynorm.unwrap()) } Err(_) => Err(PyErr::new::( - "Failed to acquire read lock", + "Index not found", )), } } else { Err(PyErr::new::( - "Failed to acquire read lock", + "Index not found", )) } } From 79af84f6156c78f6ffc4152a0bc24eb83c6087c0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 4 Aug 2024 19:20:25 +0200 Subject: [PATCH 04/14] cleanup --- bindings/python/src/normalizers.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 784e5fb86..262144fdd 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -374,16 +374,8 @@ impl PySequence { match &self_.as_ref().normalizer { PyNormalizerTypeWrapper::Sequence(inner) => { if let Some(item) = inner.get(index) { - match Arc::clone(item).read() { - Ok(read_guard) => { - let py_normalizer: PyNormalizerTypeWrapper = read_guard.clone().into(); - let pynorm = PyNormalizer::new(py_normalizer).get_as_subtype(py); - Ok(pynorm.unwrap()) - } - Err(_) => Err(PyErr::new::( - "Index not found", - )), - } + PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item))) + .get_as_subtype(py) } else { Err(PyErr::new::( "Index not found", From 1d44d2c9b9ae84cb18197d4b32a53934b0d0b0ec Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 4 Aug 2024 19:26:59 +0200 Subject: [PATCH 05/14] clean? 
--- bindings/python/src/normalizers.rs | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 262144fdd..6f866515d 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -372,24 +372,17 @@ impl PySequence { fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { match &self_.as_ref().normalizer { - PyNormalizerTypeWrapper::Sequence(inner) => { - if let Some(item) = inner.get(index) { - PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item))) - .get_as_subtype(py) - } else { - Err(PyErr::new::( - "Index not found", - )) - } - } - PyNormalizerTypeWrapper::Single(inner) => match inner.read() { - Ok(read_guard) => Ok(PyNormalizer::new(read_guard.clone().into()) - .get_as_subtype(py) - .unwrap()), - Err(_) => Err(PyErr::new::( - "Failed to acquire read lock", + PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) { + Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item))) + .get_as_subtype(py), + _ => Err(PyErr::new::( + "Index not found", )), }, + PyNormalizerTypeWrapper::Single(inner) => { + PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner))) + .get_as_subtype(py) + } } } } From c2cccc8df24fe8b0961c555a13009935495a1b85 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 4 Aug 2024 19:31:05 +0200 Subject: [PATCH 06/14] propagate to pre_tokenizer --- bindings/python/src/pre_tokenizers.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a2bd9b39c..a30b4ca82 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -450,6 +450,24 @@ impl PySequence { fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> { PyTuple::new_bound(py, [PyList::empty_bound(py)]) } + + fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { + match &self_.as_ref().pretok { + PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) { + Some(item) => { + PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item))) + .get_as_subtype(py) + } + _ => Err(PyErr::new::( + "Index not found", + )), + }, + PyPreTokenizerTypeWrapper::Single(inner) => { + PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner))) + .get_as_subtype(py) + } + } + } } pub(crate) fn from_string(string: String) -> Result { From ac15c156a2713318e423a1c2d889e4d83ffa626e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 4 Aug 2024 19:33:11 +0200 Subject: [PATCH 07/14] fix test --- bindings/python/src/tokenizer.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 9311ce6cf..dec0cb1df 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1427,10 +1427,12 @@ mod test { #[test] fn serialize() { let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default())); - tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![ - Arc::new(RwLock::new(NFKC.into())), - Arc::new(RwLock::new(Lowercase.into())), - ]))); + tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence( + vec![ + Arc::new(RwLock::new(NFKC.into())), + Arc::new(RwLock::new(Lowercase.into())), + ], + )))); let tmp = 
NamedTempFile::new().unwrap().into_temp_path(); tokenizer.save(&tmp, false).unwrap(); From 458368cbcd7037e86f5829d43f2a17ecb6774377 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 6 Aug 2024 16:16:31 +0200 Subject: [PATCH 08/14] fix rust tests --- bindings/python/src/tokenizer.rs | 1 - tokenizers/benches/bert_benchmark.rs | 8 ++++---- tokenizers/benches/bpe_benchmark.rs | 6 +++--- tokenizers/tests/documentation.rs | 12 ++++++------ tokenizers/tests/serialization.rs | 2 +- 5 files changed, 14 insertions(+), 15 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index dec0cb1df..10f040418 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -8,7 +8,6 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::*; use tk::models::bpe::BPE; -use tk::normalizer; use tk::tokenizer::{ Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, TruncationDirection, TruncationParams, TruncationStrategy, diff --git a/tokenizers/benches/bert_benchmark.rs b/tokenizers/benches/bert_benchmark.rs index bc84fe8c9..38037aae4 100644 --- a/tokenizers/benches/bert_benchmark.rs +++ b/tokenizers/benches/bert_benchmark.rs @@ -34,8 +34,8 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer { let sep_id = *wp.get_vocab().get("[SEP]").unwrap(); let cls_id = *wp.get_vocab().get("[CLS]").unwrap(); let mut tokenizer = TokenizerImpl::new(wp); - tokenizer.with_pre_tokenizer(BertPreTokenizer); - tokenizer.with_normalizer(BertNormalizer::default()); + tokenizer.with_pre_tokenizer(Some(BertPreTokenizer)); + tokenizer.with_normalizer(Some(BertNormalizer::default())); tokenizer.with_decoder(decoders::wordpiece::WordPiece::default()); tokenizer.with_post_processor(BertProcessing::new( ("[SEP]".to_string(), sep_id), @@ -81,7 +81,7 @@ fn bench_train(c: &mut Criterion) { DecoderWrapper, >; let mut tokenizer = Tok::new(WordPiece::default()); - tokenizer.with_pre_tokenizer(Whitespace {}); + tokenizer.with_pre_tokenizer(Some(Whitespace {})); c.bench_function("WordPiece Train vocabulary (small)", |b| { b.iter_custom(|iters| { iter_bench_train( @@ -94,7 +94,7 @@ fn bench_train(c: &mut Criterion) { }); let mut tokenizer = Tok::new(WordPiece::default()); - tokenizer.with_pre_tokenizer(Whitespace {}); + tokenizer.with_pre_tokenizer(Some(Whitespace {})); c.bench_function("WordPiece Train vocabulary (big)", |b| { b.iter_custom(|iters| { iter_bench_train( diff --git a/tokenizers/benches/bpe_benchmark.rs b/tokenizers/benches/bpe_benchmark.rs index dd65d233e..6ffb8e343 100644 --- a/tokenizers/benches/bpe_benchmark.rs +++ b/tokenizers/benches/bpe_benchmark.rs @@ -22,7 +22,7 @@ static BATCH_SIZE: usize = 1_000; fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer { let mut tokenizer = Tokenizer::new(bpe); - tokenizer.with_pre_tokenizer(ByteLevel::default()); + tokenizer.with_pre_tokenizer(Some(ByteLevel::default())); tokenizer.with_decoder(ByteLevel::default()); tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]); tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]); @@ -74,7 +74,7 @@ fn bench_train(c: &mut Criterion) { .build() .into(); let mut tokenizer = Tokenizer::new(BPE::default()).into_inner(); - tokenizer.with_pre_tokenizer(Whitespace {}); + tokenizer.with_pre_tokenizer(Some(Whitespace {})); c.bench_function("BPE Train vocabulary (small)", |b| { b.iter_custom(|iters| { iter_bench_train( @@ -87,7 +87,7 @@ fn bench_train(c: &mut Criterion) { }); let mut tokenizer = 
Tokenizer::new(BPE::default()).into_inner(); - tokenizer.with_pre_tokenizer(Whitespace {}); + tokenizer.with_pre_tokenizer(Some(Whitespace {})); c.bench_function("BPE Train vocabulary (big)", |b| { b.iter_custom(|iters| { iter_bench_train( diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index 0a9fbaec0..8c0f61116 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -93,7 +93,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> { // START quicktour_init_pretok use tokenizers::pre_tokenizers::whitespace::Whitespace; - tokenizer.with_pre_tokenizer(Whitespace {}); + tokenizer.with_pre_tokenizer(Some(Whitespace {})); // END quicktour_init_pretok // START quicktour_train @@ -261,7 +261,7 @@ fn pipeline() -> tokenizers::Result<()> { // END pipeline_test_normalizer assert_eq!(normalized.get(), "Hello how are u?"); // START pipeline_replace_normalizer - tokenizer.with_normalizer(normalizer); + tokenizer.with_normalizer(Some(normalizer)); // END pipeline_replace_normalizer // START pipeline_setup_pre_tokenizer use tokenizers::pre_tokenizers::whitespace::Whitespace; @@ -325,7 +325,7 @@ fn pipeline() -> tokenizers::Result<()> { ] ); // START pipeline_replace_pre_tokenizer - tokenizer.with_pre_tokenizer(pre_tokenizer); + tokenizer.with_pre_tokenizer(Some(pre_tokenizer)); // END pipeline_replace_pre_tokenizer // START pipeline_setup_processor use tokenizers::processors::template::TemplateProcessing; @@ -375,16 +375,16 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { use tokenizers::normalizers::utils::Sequence as NormalizerSequence; use tokenizers::normalizers::{strip::StripAccents, unicode::NFD, utils::Lowercase}; - bert_tokenizer.with_normalizer(NormalizerSequence::new(vec![ + bert_tokenizer.with_normalizer(Some(NormalizerSequence::new(vec![ NFD.into(), Lowercase.into(), StripAccents.into(), - ])); + ]))); // END bert_setup_normalizer // START bert_setup_pre_tokenizer use tokenizers::pre_tokenizers::whitespace::Whitespace; - bert_tokenizer.with_pre_tokenizer(Whitespace {}); + bert_tokenizer.with_pre_tokenizer(Some(Whitespace {})); // END bert_setup_pre_tokenizer // START bert_setup_processor use tokenizers::processors::template::TemplateProcessing; diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs index 4d51d4281..dc0c95a57 100644 --- a/tokenizers/tests/serialization.rs +++ b/tokenizers/tests/serialization.rs @@ -203,7 +203,7 @@ fn models() { fn tokenizer() { let wordpiece = WordPiece::default(); let mut tokenizer = Tokenizer::new(wordpiece); - tokenizer.with_normalizer(NFC); + tokenizer.with_normalizer(Some(NFC)); let ser = serde_json::to_string(&tokenizer).unwrap(); let _: Tokenizer = serde_json::from_str(&ser).unwrap(); let unwrapped_nfc_tok: TokenizerImpl< From 34b8f896188479b795c8e7c429b49de619732651 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 6 Aug 2024 16:17:53 +0200 Subject: [PATCH 09/14] fix node --- bindings/node/src/tokenizer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs index 100cfd75f..b47c4c8a8 100644 --- a/bindings/node/src/tokenizer.rs +++ b/bindings/node/src/tokenizer.rs @@ -208,7 +208,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_pre_tokenizer((*pre_tokenizer).clone()); + .with_pre_tokenizer(Some(*pre_tokenizer).clone()); } #[napi] @@ -240,7 +240,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_normalizer((*normalizer).clone()); + 
.with_normalizer(Some(*normalizer).clone()); } #[napi] From 1b926f348c4138de8a00a610d3e5ad5d99482987 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 6 Aug 2024 16:23:23 +0200 Subject: [PATCH 10/14] propagate to decoder and post processor --- bindings/python/src/tokenizer.rs | 11 +++++------ tokenizers/src/tokenizer/mod.rs | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 10f040418..9fd596022 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1377,8 +1377,7 @@ impl PyTokenizer { /// Set the :class:`~tokenizers.normalizers.Normalizer` #[setter] fn set_pre_tokenizer(&mut self, pretok: Option>) { - let pretok = pretok.map(|pre| pre.clone()); - self.tokenizer.with_pre_tokenizer(pretok); + self.tokenizer.with_pre_tokenizer(pretok.map(|pre| pre.clone())); } /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer @@ -1393,8 +1392,8 @@ impl PyTokenizer { /// Set the :class:`~tokenizers.processors.PostProcessor` #[setter] - fn set_post_processor(&mut self, processor: PyRef) { - self.tokenizer.with_post_processor(processor.clone()); + fn set_post_processor(&mut self, processor: Option>) { + self.tokenizer.with_post_processor(processor.map(|p| p.clone())); } /// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer @@ -1409,8 +1408,8 @@ impl PyTokenizer { /// Set the :class:`~tokenizers.decoders.Decoder` #[setter] - fn set_decoder(&mut self, decoder: PyRef) { - self.tokenizer.with_decoder(decoder.clone()); + fn set_decoder(&mut self, decoder: Option>) { + self.tokenizer.with_decoder(decoder.map(|d| d.clone())); } } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 69ca9a3bf..a3e2edafa 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -571,8 +571,8 @@ where } /// Set the post processor - pub fn with_post_processor(&mut self, post_processor: impl Into) -> &mut Self { - self.post_processor = Some(post_processor.into()); + pub fn with_post_processor(&mut self, post_processor: Option>) -> &mut Self { + self.post_processor = post_processor.map(|post_proc| post_proc.into()); self } @@ -582,8 +582,8 @@ where } /// Set the decoder - pub fn with_decoder(&mut self, decoder: impl Into) -> &mut Self { - self.decoder = Some(decoder.into()); + pub fn with_decoder(&mut self, decoder: Option>) -> &mut Self { + self.decoder = decoder.map(|dec|dec.into()); self } From 44a63ebb91ef69fb5df9f9707587b3831035caa7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 6 Aug 2024 16:31:34 +0200 Subject: [PATCH 11/14] fix calls --- bindings/node/src/tokenizer.rs | 4 ++-- tokenizers/benches/bert_benchmark.rs | 6 +++--- tokenizers/benches/bpe_benchmark.rs | 2 +- tokenizers/tests/common/mod.rs | 16 ++++++++-------- tokenizers/tests/documentation.rs | 14 +++++++------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs index b47c4c8a8..a34791c76 100644 --- a/bindings/node/src/tokenizer.rs +++ b/bindings/node/src/tokenizer.rs @@ -217,7 +217,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_decoder((*decoder).clone()); + .with_decoder(Some(*decoder).clone()); } #[napi] @@ -231,7 +231,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_post_processor((*post_processor).clone()); + .with_post_processor(Some(*post_processor).clone()); } #[napi] diff --git 
a/tokenizers/benches/bert_benchmark.rs b/tokenizers/benches/bert_benchmark.rs index 38037aae4..cfdab9070 100644 --- a/tokenizers/benches/bert_benchmark.rs +++ b/tokenizers/benches/bert_benchmark.rs @@ -36,11 +36,11 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer { let mut tokenizer = TokenizerImpl::new(wp); tokenizer.with_pre_tokenizer(Some(BertPreTokenizer)); tokenizer.with_normalizer(Some(BertNormalizer::default())); - tokenizer.with_decoder(decoders::wordpiece::WordPiece::default()); - tokenizer.with_post_processor(BertProcessing::new( + tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default())); + tokenizer.with_post_processor(Some(BertProcessing::new( ("[SEP]".to_string(), sep_id), ("[CLS]".to_string(), cls_id), - )); + ))); tokenizer } diff --git a/tokenizers/benches/bpe_benchmark.rs b/tokenizers/benches/bpe_benchmark.rs index 6ffb8e343..f0097bf82 100644 --- a/tokenizers/benches/bpe_benchmark.rs +++ b/tokenizers/benches/bpe_benchmark.rs @@ -23,7 +23,7 @@ static BATCH_SIZE: usize = 1_000; fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer { let mut tokenizer = Tokenizer::new(bpe); tokenizer.with_pre_tokenizer(Some(ByteLevel::default())); - tokenizer.with_decoder(ByteLevel::default()); + tokenizer.with_decoder(Some(ByteLevel::default())); tokenizer.add_tokens(&[AddedToken::from("ing", false).single_word(false)]); tokenizer.add_special_tokens(&[AddedToken::from("[ENT]", true).single_word(true)]); tokenizer diff --git a/tokenizers/tests/common/mod.rs b/tokenizers/tests/common/mod.rs index ea70aba20..444cfb269 100644 --- a/tokenizers/tests/common/mod.rs +++ b/tokenizers/tests/common/mod.rs @@ -23,9 +23,9 @@ pub fn get_byte_level_bpe() -> BPE { pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer { let mut tokenizer = Tokenizer::new(get_byte_level_bpe()); tokenizer - .with_pre_tokenizer(ByteLevel::default().add_prefix_space(add_prefix_space)) - .with_decoder(ByteLevel::default()) - .with_post_processor(ByteLevel::default().trim_offsets(trim_offsets)); + .with_pre_tokenizer(Some(ByteLevel::default().add_prefix_space(add_prefix_space))) + .with_decoder(Some(ByteLevel::default())) + .with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets))); tokenizer } @@ -43,13 +43,13 @@ pub fn get_bert() -> Tokenizer { let sep = tokenizer.get_model().token_to_id("[SEP]").unwrap(); let cls = tokenizer.get_model().token_to_id("[CLS]").unwrap(); tokenizer - .with_normalizer(BertNormalizer::default()) - .with_pre_tokenizer(BertPreTokenizer) - .with_decoder(WordPieceDecoder::default()) - .with_post_processor(BertProcessing::new( + .with_normalizer(Some(BertNormalizer::default())) + .with_pre_tokenizer(Some(BertPreTokenizer)) + .with_decoder(Some(WordPieceDecoder::default())) + .with_post_processor(Some(BertProcessing::new( (String::from("[SEP]"), sep), (String::from("[CLS]"), cls), - )); + ))); tokenizer } diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index 8c0f61116..283d53bd5 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -158,13 +158,13 @@ fn quicktour() -> tokenizers::Result<()> { ("[SEP]", tokenizer.token_to_id("[SEP]").unwrap()), ]; tokenizer.with_post_processor( - TemplateProcessing::builder() + Some(TemplateProcessing::builder() .try_single("[CLS] $A [SEP]") .unwrap() .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") .unwrap() .special_tokens(special_tokens) - .build()?, + .build()?), ); // END quicktour_init_template_processing // START quicktour_print_special_tokens 
@@ -331,14 +331,14 @@ fn pipeline() -> tokenizers::Result<()> { use tokenizers::processors::template::TemplateProcessing; tokenizer.with_post_processor( - TemplateProcessing::builder() + Some(TemplateProcessing::builder() .try_single("[CLS] $A [SEP]") .unwrap() .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") .unwrap() .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)]) .build() - .unwrap(), + .unwrap()), ); // END pipeline_setup_processor // START pipeline_test_decoding @@ -390,14 +390,14 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { use tokenizers::processors::template::TemplateProcessing; bert_tokenizer.with_post_processor( - TemplateProcessing::builder() + Some(TemplateProcessing::builder() .try_single("[CLS] $A [SEP]") .unwrap() .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") .unwrap() .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)]) .build() - .unwrap(), + .unwrap()), ); // END bert_setup_processor // START bert_train_tokenizer @@ -450,7 +450,7 @@ fn pipeline_bert() -> tokenizers::Result<()> { // START bert_proper_decoding use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder; - bert_tokenizer.with_decoder(WordPieceDecoder::default()); + bert_tokenizer.with_decoder(Some(WordPieceDecoder::default())); let decoded = bert_tokenizer.decode(output.get_ids(), true)?; // "welcome to the tokenizers library." // END bert_proper_decoding From d30e10fd2fa4c9c3c5d3e2f243658f45d2967615 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 6 Aug 2024 16:32:13 +0200 Subject: [PATCH 12/14] lint --- tokenizers/src/tokenizer/mod.rs | 2 +- tokenizers/tests/common/mod.rs | 4 +++- tokenizers/tests/documentation.rs | 24 ++++++++++++------------ 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index a3e2edafa..1c2ad6e0b 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -583,7 +583,7 @@ where /// Set the decoder pub fn with_decoder(&mut self, decoder: Option>) -> &mut Self { - self.decoder = decoder.map(|dec|dec.into()); + self.decoder = decoder.map(|dec| dec.into()); self } diff --git a/tokenizers/tests/common/mod.rs b/tokenizers/tests/common/mod.rs index 444cfb269..26129699b 100644 --- a/tokenizers/tests/common/mod.rs +++ b/tokenizers/tests/common/mod.rs @@ -23,7 +23,9 @@ pub fn get_byte_level_bpe() -> BPE { pub fn get_byte_level(add_prefix_space: bool, trim_offsets: bool) -> Tokenizer { let mut tokenizer = Tokenizer::new(get_byte_level_bpe()); tokenizer - .with_pre_tokenizer(Some(ByteLevel::default().add_prefix_space(add_prefix_space))) + .with_pre_tokenizer(Some( + ByteLevel::default().add_prefix_space(add_prefix_space), + )) .with_decoder(Some(ByteLevel::default())) .with_post_processor(Some(ByteLevel::default().trim_offsets(trim_offsets))); diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index 283d53bd5..c0c471a93 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -157,15 +157,15 @@ fn quicktour() -> tokenizers::Result<()> { ("[CLS]", tokenizer.token_to_id("[CLS]").unwrap()), ("[SEP]", tokenizer.token_to_id("[SEP]").unwrap()), ]; - tokenizer.with_post_processor( - Some(TemplateProcessing::builder() + tokenizer.with_post_processor(Some( + TemplateProcessing::builder() .try_single("[CLS] $A [SEP]") .unwrap() .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") .unwrap() .special_tokens(special_tokens) - .build()?), - ); + .build()?, + )); // END quicktour_init_template_processing // START quicktour_print_special_tokens 
let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?; @@ -330,16 +330,16 @@ fn pipeline() -> tokenizers::Result<()> { // START pipeline_setup_processor use tokenizers::processors::template::TemplateProcessing; - tokenizer.with_post_processor( - Some(TemplateProcessing::builder() + tokenizer.with_post_processor(Some( + TemplateProcessing::builder() .try_single("[CLS] $A [SEP]") .unwrap() .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") .unwrap() .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)]) .build() - .unwrap()), - ); + .unwrap(), + )); // END pipeline_setup_processor // START pipeline_test_decoding let output = tokenizer.encode("Hello, y'all! How are you 😁 ?", true)?; @@ -389,16 +389,16 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { // START bert_setup_processor use tokenizers::processors::template::TemplateProcessing; - bert_tokenizer.with_post_processor( - Some(TemplateProcessing::builder() + bert_tokenizer.with_post_processor(Some( + TemplateProcessing::builder() .try_single("[CLS] $A [SEP]") .unwrap() .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1") .unwrap() .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)]) .build() - .unwrap()), - ); + .unwrap(), + )); // END bert_setup_processor // START bert_train_tokenizer use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper}; From c4b1470634ffe0c80078df3a14e6ecf5ed77f6b6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 6 Aug 2024 16:54:41 +0200 Subject: [PATCH 13/14] fmt --- bindings/node/src/tokenizer.rs | 6 +++--- bindings/python/src/tokenizer.rs | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs index a34791c76..e5342084b 100644 --- a/bindings/node/src/tokenizer.rs +++ b/bindings/node/src/tokenizer.rs @@ -208,7 +208,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_pre_tokenizer(Some(*pre_tokenizer).clone()); + .with_pre_tokenizer(Some(*pre_tokenizer.clone())); } #[napi] @@ -231,7 +231,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_post_processor(Some(*post_processor).clone()); + .with_post_processor(Some(*post_processor.clone())); } #[napi] @@ -240,7 +240,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_normalizer(Some(*normalizer).clone()); + .with_normalizer(Some(*normalizer.clone())); } #[napi] diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 9fd596022..8b3e30617 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1377,7 +1377,8 @@ impl PyTokenizer { /// Set the :class:`~tokenizers.normalizers.Normalizer` #[setter] fn set_pre_tokenizer(&mut self, pretok: Option>) { - self.tokenizer.with_pre_tokenizer(pretok.map(|pre| pre.clone())); + self.tokenizer + .with_pre_tokenizer(pretok.map(|pre| pre.clone())); } /// The `optional` :class:`~tokenizers.processors.PostProcessor` in use by the Tokenizer @@ -1393,7 +1394,8 @@ impl PyTokenizer { /// Set the :class:`~tokenizers.processors.PostProcessor` #[setter] fn set_post_processor(&mut self, processor: Option>) { - self.tokenizer.with_post_processor(processor.map(|p| p.clone())); + self.tokenizer + .with_post_processor(processor.map(|p| p.clone())); } /// The `optional` :class:`~tokenizers.decoders.Decoder` in use by the Tokenizer From a3308076a955982d8ec485b2babe98b693094940 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 6 Aug 2024 17:24:01 +0200 Subject: [PATCH 14/14] node be happy I am fixing you --- bindings/node/src/tokenizer.rs | 8 ++++---- 1 file changed, 4 
insertions(+), 4 deletions(-) diff --git a/bindings/node/src/tokenizer.rs b/bindings/node/src/tokenizer.rs index e5342084b..4acbcac83 100644 --- a/bindings/node/src/tokenizer.rs +++ b/bindings/node/src/tokenizer.rs @@ -208,7 +208,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_pre_tokenizer(Some(*pre_tokenizer.clone())); + .with_pre_tokenizer(Some((*pre_tokenizer).clone())); } #[napi] @@ -217,7 +217,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_decoder(Some(*decoder).clone()); + .with_decoder(Some((*decoder).clone())); } #[napi] @@ -231,7 +231,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_post_processor(Some(*post_processor.clone())); + .with_post_processor(Some((*post_processor).clone())); } #[napi] @@ -240,7 +240,7 @@ impl Tokenizer { .tokenizer .write() .unwrap() - .with_normalizer(Some(*normalizer.clone())); + .with_normalizer(Some((*normalizer).clone())); } #[napi]
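
Taken together, these patches change the TokenizerImpl setters (with_normalizer, with_pre_tokenizer, with_post_processor, with_decoder) from taking a bare value to taking an Option, thread that through the Python and Node bindings so a component can be assigned None (for example, tokenizer.normalizer = None from Python), and add __getitem__ to the Python Sequence normalizer and pre-tokenizer wrappers so an element of a sequence can be retrieved as its concrete subtype via get_as_subtype.

The snippet below is not part of the series; it is a minimal sketch of how the updated Rust setters would be called once the patches apply, using only types that already appear in the diffs above (Tokenizer, BPE, NFC, Whitespace). The exact setter signature in the upstream code is assumed to be Option<impl Into<...>>, matching the calls shown in the benchmarks and tests.

    use tokenizers::models::bpe::BPE;
    use tokenizers::normalizers::unicode::NFC;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;
    use tokenizers::Tokenizer;

    fn main() {
        let mut tokenizer = Tokenizer::new(BPE::default());

        // Components are now wrapped in Some(...) instead of being passed bare.
        tokenizer.with_normalizer(Some(NFC));
        tokenizer.with_pre_tokenizer(Some(Whitespace {}));

        // They can also be cleared by passing None. Because the setter takes an
        // Option of a generic Into<...> type, a bare None cannot be inferred, so
        // name a concrete type that converts into the wrapper.
        tokenizer.with_normalizer(None::<NFC>);
        assert!(tokenizer.get_normalizer().is_none());
    }

The Option-based signatures are what let the binding setters forward a user-supplied None directly instead of always requiring a value, which the previous bare-argument signatures could not express.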