From fc0f0656f07870a56c3352aa1a3b19ac14be26e7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 08:08:14 +0200 Subject: [PATCH 01/16] allow to assign a new token --- tokenizers/src/tokenizer/added_vocabulary.rs | 21 ++++++++++++++++++++ tokenizers/src/tokenizer/mod.rs | 10 ++++++++++ 2 files changed, 31 insertions(+) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index a0c2f4542..32fc76a26 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -255,6 +255,27 @@ impl AddedVocabulary { self.add_tokens(tokens, model, normalizer) } + /// Re assigns a token's content to a new content. This helps users how want to + /// use reserved tokens (which usually are in the original vocab, and in the added vocab) + pub fn assign_token( + &mut self, + old_token_content: &[AddedToken], + new_token_content: &[AddedToken], + model: &impl Model, + normalizer: Option<&N>, + ) { + for (old, new) in old_token_content.iter().zip(new_token_content.iter()) { + if let Some(id) = self.token_to_id(old.content.as_str(), model) { + self.added_tokens_map_r + .entry(id) + .and_modify(|t| *t = new.clone()); + self.refresh_added_tokens(model, normalizer); + } else { + error!("Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", old.content.clone()) + } + } + } + /// Add some tokens to the vocabulary pub fn add_tokens( &mut self, diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 49bc539a2..1d2ee3995 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -960,6 +960,16 @@ where self.added_vocabulary .add_tokens(tokens, &self.model, self.normalizer.as_ref()) } + + /// Assign a new token + pub fn assign_token(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) { + self.added_vocabulary.assign_token( + old_tokens, + new_tokens, + &self.model, + self.normalizer.as_ref(), + ) + } } impl TokenizerImpl From 97e8818ecf43b418cbd97d4ac08762a154dfe2d1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 08:52:04 +0200 Subject: [PATCH 02/16] add python bindongs as well --- .../python/py_src/tokenizers/__init__.pyi | 17 ++++++ bindings/python/src/tokenizer.rs | 58 +++++++++++++++++++ tokenizers/src/tokenizer/added_vocabulary.rs | 2 +- tokenizers/src/tokenizer/mod.rs | 4 +- 4 files changed, 78 insertions(+), 3 deletions(-) diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi index 0ad96fc8a..a480923ef 100644 --- a/bindings/python/py_src/tokenizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/__init__.pyi @@ -725,6 +725,23 @@ class Tokenizer: """ pass + def assing_tokens(self, old_tokens, new_tokens): + """ + Add the given tokens to the vocabulary + + The given tokens are added only if they don't already exist in the vocabulary. + Each token then gets a new attributed id. + + Args: + tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`): + The list of tokens we want to add to the vocabulary. Each token can be either a + string or an instance of :class:`~tokenizers.AddedToken` for more customization. 
+ + Returns: + :obj:`int`: The number of tokens that were created in the vocabulary + """ + pass + def decode(self, ids, skip_special_tokens=True): """ Decode the given list of ids back to a string diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 24a68c6bb..34852c129 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1237,6 +1237,64 @@ impl PyTokenizer { Ok(self.tokenizer.add_tokens(&tokens)) } + /// Add the given tokens to the vocabulary + /// + /// The given tokens are added only if they don't already exist in the vocabulary. + /// Each token then gets a new attributed id. + /// + /// Args: + /// tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`): + /// The list of tokens we want to add to the vocabulary. Each token can be either a + /// string or an instance of :class:`~tokenizers.AddedToken` for more customization. + /// + /// Returns: + /// :obj:`int`: The number of tokens that were created in the vocabulary + #[pyo3(text_signature = "(self, old_tokens, new_tokens)")] + fn assing_tokens( + &mut self, + old_tokens: &Bound<'_, PyList>, + new_tokens: &Bound<'_, PyList>, + ) -> PyResult<()> { + use pyo3::exceptions::PyTypeError; + if old_tokens.len() != new_tokens.len() { + return Err(PyTypeError::new_err( + "old_tokens and new_tokens must have the same length", + )); + } + + let mut processed_old_tokens = Vec::with_capacity(old_tokens.len()); + let mut processed_new_tokens = Vec::with_capacity(new_tokens.len()); + for (old, new) in old_tokens.iter().zip(new_tokens.iter()) { + let old_token = if let Ok(content) = old.extract::<&str>() { + PyAddedToken::from(content.to_string(), Some(false)).get_token() + } else if let Ok(token) = old.extract::>() { + token.get_token() + } else { + return Err(PyTypeError::new_err( + "old_tokens must be a List[Union[str, AddedToken]]", + )); + }; + + let new_token = if let Ok(content) = new.extract::<&str>() { + let mut updated_token = old_token.clone(); + updated_token.content = content.to_string(); + updated_token + } else if let Ok(token) = new.extract::>() { + token.get_token() + } else { + return Err(PyTypeError::new_err( + "new_tokens must be a List[Union[str, AddedToken]]", + )); + }; + + processed_old_tokens.push(old_token); + processed_new_tokens.push(new_token); + } + + Ok(self + .tokenizer + .assign_tokens(&processed_old_tokens, &processed_new_tokens)) + } /// Add the given special tokens to the Tokenizer. /// /// If these tokens are already part of the vocabulary, it just let the Tokenizer know about diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 32fc76a26..379a3075b 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -257,7 +257,7 @@ impl AddedVocabulary { /// Re assigns a token's content to a new content. 
This helps users how want to /// use reserved tokens (which usually are in the original vocab, and in the added vocab) - pub fn assign_token( + pub fn assign_tokens( &mut self, old_token_content: &[AddedToken], new_token_content: &[AddedToken], diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 1d2ee3995..c3cc6ede9 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -962,8 +962,8 @@ where } /// Assign a new token - pub fn assign_token(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) { - self.added_vocabulary.assign_token( + pub fn assign_tokens(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) { + self.added_vocabulary.assign_tokens( old_tokens, new_tokens, &self.model, From ddab9013382767ec70975e1ef525482c734e90cb Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 09:49:42 +0200 Subject: [PATCH 03/16] current update --- bindings/python/src/tokenizer.rs | 2 +- tokenizers/src/tokenizer/added_vocabulary.rs | 135 +++++-------------- 2 files changed, 36 insertions(+), 101 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 34852c129..a2b191673 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1250,7 +1250,7 @@ impl PyTokenizer { /// Returns: /// :obj:`int`: The number of tokens that were created in the vocabulary #[pyo3(text_signature = "(self, old_tokens, new_tokens)")] - fn assing_tokens( + fn assign_tokens( &mut self, old_tokens: &Bound<'_, PyList>, new_tokens: &Bound<'_, PyList>, diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 379a3075b..dd8f61952 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -4,8 +4,10 @@ use super::{ use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; -use std::collections::{HashMap, HashSet}; - +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, Mutex}, +}; /// Represent a token added by the user on top of the existing Model vocabulary. /// AddedToken can be configured to specify the behavior they should have in various situations /// like: @@ -142,19 +144,12 @@ fn space_rightmost_at_start(sentence: &str) -> usize { pub struct AddedVocabulary { /// Contains the mapping from String (token content) to ID. This map contains both special /// tokens and classic added tokens that were added to the this vocabulary. - added_tokens_map: HashMap<String, u32>, + added_tokens_map: Arc<Mutex<HashMap<String, u32>>>, /// Contains the mapping from ID to AddedToken for all the added tokens, both special /// and classic. - added_tokens_map_r: HashMap<u32, AddedToken>, - + added_tokens_map_r: Arc<Mutex<HashMap<u32, AddedToken>>>, /// Contains only the classic AddedToken, in the specific order the user gave them. added_tokens: Vec<AddedToken>, - /// Contains only the special AddedToken, in the specific order the user gave them. - special_tokens: Vec<AddedToken>, - - /// A Set, containing all the special token for easy access while decoding. This let's
- special_tokens_set: HashSet, /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens split_trie: MatchingSet, @@ -176,11 +171,9 @@ impl AddedVocabulary { .build::<_, &&[u8]>([]) .expect("The normalized trie should build correctly"); Self { - added_tokens_map: HashMap::new(), - added_tokens_map_r: HashMap::new(), + added_tokens_map: Arc::new(Mutex::new(HashMap::new())), + added_tokens_map_r: Arc::new(Mutex::new(HashMap::new())), added_tokens: vec![], - special_tokens: vec![], - special_tokens_set: HashSet::new(), split_trie: (trie, vec![]), split_normalized_trie: (normalized_trie, vec![]), encode_special_tokens: false, @@ -189,30 +182,29 @@ impl AddedVocabulary { /// Size of the additional vocabulary #[allow(dead_code)] // Suppress the "method is never used" warning pub fn len(&self) -> usize { - self.added_tokens_map.len() + self.added_tokens_map.lock().unwrap().len() } /// Whether or not this vocabulary is empty pub fn is_empty(&self) -> bool { - self.added_tokens_map.is_empty() + self.added_tokens_map.lock().unwrap().is_empty() } /// Get the additional vocabulary - pub fn get_vocab(&self) -> &HashMap { - &self.added_tokens_map + pub fn get_vocab(&self) -> HashMap { + self.added_tokens_map.lock().unwrap().clone() } /// Get the additional vocabulary with the AddedTokens - pub fn get_added_tokens_decoder(&self) -> &HashMap { - &self.added_tokens_map_r + pub fn get_added_tokens_decoder(&self) -> HashMap { + self.added_tokens_map_r.lock().unwrap().clone() } /// Get the id matching one of our token if it exists pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option { - self.added_tokens_map - .get(token) - .copied() - .or_else(|| model.token_to_id(token)) + let added_tokens_map = self.added_tokens_map.lock().unwrap(); + let id = added_tokens_map.get(token).copied(); + id.or_else(|| model.token_to_id(token)) } /// Get the token matching the given id if it exists @@ -220,15 +212,6 @@ impl AddedVocabulary { since = "0.19.0", note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead" )] - pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option { - self.added_tokens_map_r - .get(&id) - .map(|t| t.content.clone()) - .or_else(|| model.id_to_token(id)) - } - - pub fn simple_id_to_token(&self, id: u32) -> Option { - self.added_tokens_map_r.get(&id).map(|t| t.content.clone()) } // @@ -253,6 +236,11 @@ impl AddedVocabulary { normalizer: Option<&N>, ) -> usize { self.add_tokens(tokens, model, normalizer) + /// Get the token matching the given id if it exists + pub fn simple_id_to_token(&self, id: &u32) -> Option { + let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap(); + let token = added_tokens_map_r.get(id).map(|t| t.content.clone()); + token } /// Re assigns a token's content to a new content. This helps users how want to @@ -276,69 +264,16 @@ impl AddedVocabulary { } } - /// Add some tokens to the vocabulary - pub fn add_tokens( - &mut self, - tokens: &[AddedToken], - model: &impl Model, - normalizer: Option<&N>, - ) -> usize { - // Handle special tokens (if any) - for token in tokens { - if token.special - && !token.content.is_empty() - && !self.special_tokens_set.contains(&token.content) - { - self.special_tokens.push(token.to_owned()); - self.special_tokens_set.insert(token.content.clone()); - } - } - - // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too. 
- let mut ignored = 0; - for token in tokens { - if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token) - { - ignored += 1; - continue; - } - // If a token is already part of the vocabulary, we mark it as added - let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) { - new_id - } else { - self.added_tokens_map.values().cloned().max().map_or( - model.get_vocab_size() as u32, - |max| { - if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 { - max + 1 - } else { - model.get_vocab_size() as u32 - } - }, - ) - }; - // Make sure we modify the previous entry - self.added_tokens_map - .entry(token.content.clone()) - .and_modify(|old_id| *old_id = new_id) - .or_insert_with(|| new_id); - // Update the current revert operation - self.added_tokens_map_r - .entry(new_id) - .and_modify(|t| *t = token.clone()) - .or_insert_with(|| token.clone()); - // Make sure to remove previous entry (if the token gets a new id) - - // Finally add the token to the classic set if special - if !self.special_tokens_set.contains(&token.content) { - self.added_tokens.push(token.clone()); - } - } + /// Add a token to the added vocabulary + pub fn add_token(&mut self, token: &AddedToken) { + let mut added_tokens_map = self.added_tokens_map.lock().unwrap(); + let mut added_tokens_map_r = self.added_tokens_map_r.lock().unwrap(); - self.refresh_added_tokens(model, normalizer); + let id = added_tokens_map.len() as u32; + added_tokens_map.insert(token.content.clone(), id); + added_tokens_map_r.insert(id, token.clone()); - // Return the number of added tokens - tokens.len() - ignored + self.refresh_added_tokens(); } /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary. @@ -348,9 +283,8 @@ impl AddedVocabulary { fn refresh_added_tokens(&mut self, model: &impl Model, normalizer: Option<&N>) { type TupleTokenId<'a> = (&'a AddedToken, u32); let (normalized, non_normalized): (Vec, Vec) = self - .special_tokens + .added_tokens .iter() - .chain(self.added_tokens.iter()) .map(|token| { ( token, @@ -402,10 +336,9 @@ impl AddedVocabulary { let mut stop = mat.end(); let aho_id = mat.pattern(); let id = split_re.1[aho_id]; - let added_token = &self.added_tokens_map_r.get(&id).unwrap(); + let added_token = self.added_tokens_map_r.lock().unwrap().get(&id).unwrap(); - if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content) - { + if self.encode_special_tokens && added_token.special { continue; } @@ -543,6 +476,8 @@ impl Serialize for AddedVocabulary { { let mut added_tokens = self .added_tokens_map_r + .lock() + .unwrap() .iter() .map(|(id, token)| AddedTokenWithId { id: *id, From b359bde47a49b95b3262c1e8e9a646aec2c77824 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 09:58:06 +0200 Subject: [PATCH 04/16] nit --- tokenizers/src/tokenizer/added_vocabulary.rs | 29 ++++++-------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index dd8f61952..4d52ce77c 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -207,13 +207,6 @@ impl AddedVocabulary { id.or_else(|| model.token_to_id(token)) } - /// Get the token matching the given id if it exists - #[deprecated( - since = "0.19.0", - note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead" - )] - } - // pub fn 
set_encode_special_tokens(&mut self, value: bool) { self.encode_special_tokens = value; @@ -225,7 +218,12 @@ impl AddedVocabulary { /// Check if a token is a special token pub fn is_special_token(&self, token: &str) -> bool { - self.special_tokens_set.contains(token) + self.added_tokens_map_r + .lock() + .unwrap() + .get(self.added_tokens_map.lock().unwrap().get(token).unwrap()) + .unwrap() + .special } /// Add some special tokens to the vocabulary @@ -236,6 +234,7 @@ impl AddedVocabulary { normalizer: Option<&N>, ) -> usize { self.add_tokens(tokens, model, normalizer) + } /// Get the token matching the given id if it exists pub fn simple_id_to_token(&self, id: &u32) -> Option { let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap(); @@ -255,6 +254,8 @@ impl AddedVocabulary { for (old, new) in old_token_content.iter().zip(new_token_content.iter()) { if let Some(id) = self.token_to_id(old.content.as_str(), model) { self.added_tokens_map_r + .lock() + .unwrap() .entry(id) .and_modify(|t| *t = new.clone()); self.refresh_added_tokens(model, normalizer); @@ -264,18 +265,6 @@ impl AddedVocabulary { } } - /// Add a token to the added vocabulary - pub fn add_token(&mut self, token: &AddedToken) { - let mut added_tokens_map = self.added_tokens_map.lock().unwrap(); - let mut added_tokens_map_r = self.added_tokens_map_r.lock().unwrap(); - - let id = added_tokens_map.len() as u32; - added_tokens_map.insert(token.content.clone(), id); - added_tokens_map_r.insert(id, token.clone()); - - self.refresh_added_tokens(); - } - /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary. /// /// We keep two different RegexSet, one that will take care of matching against the From 4794ed516fb19fa36cc7533179b1e7c2e043c671 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 10:07:27 +0200 Subject: [PATCH 05/16] fix --- tokenizers/src/tokenizer/added_vocabulary.rs | 69 +++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 4d52ce77c..9274b69b6 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -235,6 +235,71 @@ impl AddedVocabulary { ) -> usize { self.add_tokens(tokens, model, normalizer) } + + /// Add some tokens to the vocabulary + pub fn add_tokens( + &mut self, + tokens: &[AddedToken], + model: &impl Model, + normalizer: Option<&N>, + ) -> usize { + // Handle special tokens (if any) + + // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too. 
+ let mut ignored = 0; + for token in tokens { + if token.content.is_empty() + || self + .added_tokens_map_r + .lock() + .unwrap() + .values() + .any(|val| val == token) + { + ignored += 1; + continue; + } + // If a token is already part of the vocabulary, we mark it as added + let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) { + new_id + } else { + self.added_tokens_map + .lock() + .unwrap() + .values() + .cloned() + .max() + .map_or(model.get_vocab_size() as u32, |max| { + if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 { + max + 1 + } else { + model.get_vocab_size() as u32 + } + }) + }; + // Make sure we modify the previous entry + self.added_tokens_map + .lock() + .unwrap() + .entry(token.content.clone()) + .and_modify(|old_id| *old_id = new_id) + .or_insert_with(|| new_id); + // Update the current revert operation + self.added_tokens_map_r + .lock() + .unwrap() + .entry(new_id) + .and_modify(|t| *t = token.clone()) + .or_insert_with(|| token.clone()); + // Make sure to remove previous entry (if the token gets a new id) + } + + self.refresh_added_tokens(model, normalizer); + + // Return the number of added tokens + tokens.len() - ignored + } + /// Get the token matching the given id if it exists pub fn simple_id_to_token(&self, id: &u32) -> Option { let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap(); @@ -325,8 +390,8 @@ impl AddedVocabulary { let mut stop = mat.end(); let aho_id = mat.pattern(); let id = split_re.1[aho_id]; - let added_token = self.added_tokens_map_r.lock().unwrap().get(&id).unwrap(); - + let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap(); + let added_token = added_tokens_map_r.get(&id).unwrap(); if self.encode_special_tokens && added_token.special { continue; } From 4190db7dddfd16133780accb503171fb8c6e7447 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 10:12:29 +0200 Subject: [PATCH 06/16] pass compilation --- tokenizers/src/tokenizer/added_vocabulary.rs | 6 +++--- tokenizers/src/tokenizer/mod.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 9274b69b6..144f2c44e 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -5,7 +5,7 @@ use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind}; use regex::Regex; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, sync::{Arc, Mutex}, }; /// Represent a token added by the user on top of the existing Model vocabulary. 
@@ -301,9 +301,9 @@ impl AddedVocabulary { } /// Get the token matching the given id if it exists - pub fn simple_id_to_token(&self, id: &u32) -> Option { + pub fn simple_id_to_token(&self, id: u32) -> Option { let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap(); - let token = added_tokens_map_r.get(id).map(|t| t.content.clone()); + let token = added_tokens_map_r.get(&id).map(|t| t.content.clone()); token } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index c3cc6ede9..aee256a42 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -667,7 +667,7 @@ where if !added_vocab.is_empty() { final_vocab.reserve(added_vocab.len()); for (token, id) in added_vocab { - final_vocab.insert(token.clone(), *id); + final_vocab.insert(token.clone(), id); } } } From 2d4b3735e44ac242956adc73bd5ed0c69338f407 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 10:38:40 +0200 Subject: [PATCH 07/16] fix everything --- bindings/python/src/tokenizer.rs | 1 - tokenizers/src/tokenizer/added_vocabulary.rs | 20 ++++++++------------ tokenizers/src/tokenizer/mod.rs | 2 -- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index a2b191673..9b0b82dcf 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1290,7 +1290,6 @@ impl PyTokenizer { processed_old_tokens.push(old_token); processed_new_tokens.push(new_token); } - Ok(self .tokenizer .assign_tokens(&processed_old_tokens, &processed_new_tokens)) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 144f2c44e..84f4927a7 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -322,7 +322,7 @@ impl AddedVocabulary { .lock() .unwrap() .entry(id) - .and_modify(|t| *t = new.clone()); + .and_modify(|t| t.content = new.content.clone()); self.refresh_added_tokens(model, normalizer); } else { error!("Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", old.content.clone()) @@ -336,17 +336,12 @@ impl AddedVocabulary { /// non-normalized string, and one matching against the normalized one. 
fn refresh_added_tokens(&mut self, model: &impl Model, normalizer: Option<&N>) { type TupleTokenId<'a> = (&'a AddedToken, u32); - let (normalized, non_normalized): (Vec, Vec) = self - .added_tokens - .iter() - .map(|token| { - ( - token, - self.token_to_id(&token.content, model) - .expect("Missing additional token"), - ) - }) - .partition(|(token, _)| token.normalized); + let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap().clone(); + let (normalized, non_normalized): (Vec, Vec) = + added_tokens_map_r + .iter() + .map(|(id, token)| (token, *id)) + .partition(|(token, _)| token.normalized); let (tokens, ids): (Vec<&AddedToken>, Vec) = non_normalized.into_iter().unzip(); let trie = AhoCorasickBuilder::new() @@ -363,6 +358,7 @@ impl AddedVocabulary { if let Some(n) = normalizer { n.normalize(&mut content).unwrap(); } + println!("{:?}", token); content }) .collect(); diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index aee256a42..c6433dc43 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -541,9 +541,7 @@ where model, post_processor: None, decoder: None, - added_vocabulary: AddedVocabulary::new(), - truncation: None, padding: None, } From 6d48e58219cd414a99888bfbf8e28630234654ea Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 12 Jul 2024 10:47:38 +0200 Subject: [PATCH 08/16] remove print --- tokenizers/src/tokenizer/added_vocabulary.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 84f4927a7..d3ca1a484 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -358,7 +358,6 @@ impl AddedVocabulary { if let Some(n) = normalizer { n.normalize(&mut content).unwrap(); } - println!("{:?}", token); content }) .collect(); From b5640a65cf59cf6c4ac2458dd01fc695cb0c7504 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 4 Oct 2024 14:46:42 +0200 Subject: [PATCH 09/16] simplify the logic --- bindings/python/src/tokenizer.rs | 24 +++++--------------- tokenizers/src/tokenizer/added_vocabulary.rs | 15 ++++++------ tokenizers/src/tokenizer/mod.rs | 5 ++-- 3 files changed, 16 insertions(+), 28 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 9b0b82dcf..499cbd770 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1250,21 +1250,11 @@ impl PyTokenizer { /// Returns: /// :obj:`int`: The number of tokens that were created in the vocabulary #[pyo3(text_signature = "(self, old_tokens, new_tokens)")] - fn assign_tokens( - &mut self, - old_tokens: &Bound<'_, PyList>, - new_tokens: &Bound<'_, PyList>, - ) -> PyResult<()> { + fn assign_tokens(&mut self, old_to_new_map: &Bound<'_, PyDict>) -> PyResult<()> { use pyo3::exceptions::PyTypeError; - if old_tokens.len() != new_tokens.len() { - return Err(PyTypeError::new_err( - "old_tokens and new_tokens must have the same length", - )); - } - let mut processed_old_tokens = Vec::with_capacity(old_tokens.len()); - let mut processed_new_tokens = Vec::with_capacity(new_tokens.len()); - for (old, new) in old_tokens.iter().zip(new_tokens.iter()) { + let mut processed_old_tokens = HashMap::with_capacity(old_to_new_map.len()); + for (old, new) in old_to_new_map.iter() { let old_token = if let Ok(content) = old.extract::<&str>() { PyAddedToken::from(content.to_string(), Some(false)).get_token() } else if let Ok(token) = old.extract::>() { @@ -1287,12 
+1277,10 @@ impl PyTokenizer { )); }; - processed_old_tokens.push(old_token); - processed_new_tokens.push(new_token); + processed_old_tokens.insert(old_token, new_token); } - Ok(self - .tokenizer - .assign_tokens(&processed_old_tokens, &processed_new_tokens)) + self.tokenizer.assign_tokens(&processed_old_tokens); + Ok(()) } /// Add the given special tokens to the Tokenizer. /// diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index d3ca1a484..6f79ba660 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -311,25 +311,26 @@ impl AddedVocabulary { /// use reserved tokens (which usually are in the original vocab, and in the added vocab) pub fn assign_tokens( &mut self, - old_token_content: &[AddedToken], - new_token_content: &[AddedToken], + token_map: &HashMap, // HashMap of old token to new token model: &impl Model, normalizer: Option<&N>, ) { - for (old, new) in old_token_content.iter().zip(new_token_content.iter()) { - if let Some(id) = self.token_to_id(old.content.as_str(), model) { + for (old_token, new_token) in token_map.iter() { + if let Some(id) = self.token_to_id(old_token.content.as_str(), model) { self.added_tokens_map_r .lock() .unwrap() .entry(id) - .and_modify(|t| t.content = new.content.clone()); + .and_modify(|t| *t = new_token.clone()); // Replace entire entry with new_token self.refresh_added_tokens(model, normalizer); } else { - error!("Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", old.content.clone()) + error!( + "Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", + old_token.content.clone() + ) } } } - /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary. 
/// /// We keep two different RegexSet, one that will take care of matching against the diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index c6433dc43..c24654fc8 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -960,10 +960,9 @@ where } /// Assign a new token - pub fn assign_tokens(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) { + pub fn assign_tokens(&mut self, old_to_new_map: &HashMap<AddedToken, AddedToken>) { self.added_vocabulary.assign_tokens( - old_tokens, - new_tokens, + old_to_new_map, // HashMap of old token to new token &self.model, self.normalizer.as_ref(), ) From ed34ffd3342dd1c6b1226948297529dc2d6d2a8c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 4 Oct 2024 15:00:35 +0200 Subject: [PATCH 10/16] add a small test --- bindings/python/py_src/tokenizers/__init__.pyi | 2 +- bindings/python/tests/bindings/test_tokenizer.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi index a480923ef..bc2cb0cab 100644 --- a/bindings/python/py_src/tokenizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/__init__.pyi @@ -725,7 +725,7 @@ class Tokenizer: """ pass - def assing_tokens(self, old_tokens, new_tokens): + def assign_tokens(self, old_tokens, new_tokens): """ Add the given tokens to the vocabulary diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 2118709a0..370fda087 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -562,6 +562,13 @@ def test_setting_to_none(self): tokenizer.pre_tokenizer = None assert tokenizer.pre_tokenizer == None + def test_re_assign_tokens(self): + tokenizer = Tokenizer.from_pretrained("t5-base") + tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"}) + assert tokenizer.decode([32099]) == "my_new_token" + assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"] + assert "my_new_token" in tokenizer.get_vocab(True).keys() + class TestTokenizerRepr: def test_repr(self): From 545d7230f485d2f199647c627f8b98f5a728c9bc Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 4 Oct 2024 15:24:14 +0200 Subject: [PATCH 11/16] fix unwrap errors --- tokenizers/src/tokenizer/added_vocabulary.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 6f79ba660..984201075 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -148,8 +148,6 @@ pub struct AddedVocabulary { /// Contains the mapping from ID to AddedToken for all the added tokens, both special /// and classic. added_tokens_map_r: Arc<Mutex<HashMap<u32, AddedToken>>>, - /// Contains only the classic AddedToken, in the specific order the user gave them.
- added_tokens: Vec<AddedToken>, /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens split_trie: MatchingSet, @@ -173,7 +171,6 @@ impl AddedVocabulary { Self { added_tokens_map: Arc::new(Mutex::new(HashMap::new())), added_tokens_map_r: Arc::new(Mutex::new(HashMap::new())), - added_tokens: vec![], split_trie: (trie, vec![]), split_normalized_trie: (normalized_trie, vec![]), encode_special_tokens: false, @@ -218,12 +215,14 @@ impl AddedVocabulary { /// Check if a token is a special token pub fn is_special_token(&self, token: &str) -> bool { - self.added_tokens_map_r - .lock() - .unwrap() - .get(self.added_tokens_map.lock().unwrap().get(token).unwrap()) - .unwrap() - .special + let hash_map = &self.added_tokens_map_r.lock().unwrap(); + let revert_hash_map = &self.added_tokens_map.lock().unwrap(); + if let Some(id) = revert_hash_map.get(token) { + if let Some(token) = hash_map.get(id) { + return token.special; + } + } + false } /// Add some special tokens to the vocabulary @@ -335,7 +334,7 @@ impl AddedVocabulary { /// /// We keep two different RegexSet, one that will take care of matching against the /// non-normalized string, and one matching against the normalized one. - fn refresh_added_tokens(&mut self, model: &impl Model, normalizer: Option<&N>) { + fn refresh_added_tokens(&mut self, _model: &impl Model, normalizer: Option<&N>) { type TupleTokenId<'a> = (&'a AddedToken, u32); let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap().clone(); let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = From ee7ce80e0b65fd53c186ab4b36e1eba38383d5a4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 4 Oct 2024 15:55:43 +0200 Subject: [PATCH 12/16] forgot to remove from added tokens map! --- bindings/python/tests/bindings/test_tokenizer.py | 3 +++ tokenizers/src/tokenizer/added_vocabulary.rs | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 370fda087..9a1fd7272 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -566,6 +566,9 @@ def test_re_assign_tokens(self): tokenizer = Tokenizer.from_pretrained("t5-base") tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"}) assert tokenizer.decode([32099]) == "my_new_token" + assert tokenizer.encode("my_new_token").tokens == ["my_new_token", "</s>"] + assert tokenizer.encode("my_new_token").ids == [32099, 1] + assert tokenizer.encode("").ids == [0, 1] assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"] assert "my_new_token" in tokenizer.get_vocab(True).keys() diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 984201075..e22249048 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -321,6 +321,10 @@ impl AddedVocabulary { .unwrap() .entry(id) .and_modify(|t| *t = new_token.clone()); // Replace entire entry with new_token + self.added_tokens_map + .lock() + .unwrap() + .remove(old_token.content.as_str()); self.refresh_added_tokens(model, normalizer); } else { error!( From e8933fa5b996c11d9a8b61c7549b23e639c88b8d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 5 Oct 2024 17:16:31 +0200 Subject: [PATCH 13/16] potential initial solution for the annoying unigram model :) --- .../python/tests/bindings/test_tokenizer.py | 22 ++++++++++++----- tokenizers/src/models/unigram/model.rs | 24 ++++++++++++------- 2
files changed, 31 insertions(+), 15 deletions(-) diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 9a1fd7272..0fb2cbd7d 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -562,15 +562,25 @@ def test_setting_to_none(self): tokenizer.pre_tokenizer = None assert tokenizer.pre_tokenizer == None - def test_re_assign_tokens(self): + def test_re_assign_tokens_bpe(self): + tokenizer = Tokenizer.from_pretrained("gpt2") + tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"}) + assert tokenizer.decode([50256]) == "my_new_token" + assert tokenizer.encode("my_new_token").tokens == ["my_new_token"] + assert tokenizer.encode("my_new_token").ids == [50256] + assert tokenizer.encode("<|endoftext|>").ids == [27, 91, 437, 1659, 5239, 91, 29] + assert tokenizer.encode("<|endoftext|>").tokens == ["<", "|", "end", "of", "text", "|", ">"] + assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()} + + def test_re_assign_tokens_unigram(self): tokenizer = Tokenizer.from_pretrained("t5-base") tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"}) assert tokenizer.decode([32099]) == "my_new_token" - assert tokenizer.encode("my_new_token").tokens == ["my_new_token", "</s>"] - assert tokenizer.encode("my_new_token").ids == [32099, 1] - assert tokenizer.encode("").ids == [0, 1] - assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"] - assert "my_new_token" in tokenizer.get_vocab(True).keys() + assert tokenizer.encode("my_new_token").tokens == ["my_new_token"] + assert tokenizer.encode("my_new_token").ids == [32099] + assert tokenizer.encode("<extra_id_0>").ids == [27, 91, 437, 1659, 5239, 91, 29] + assert tokenizer.encode("<extra_id_0>").tokens == ["<", "|", "end", "of", "text", "|", ">"] + assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()} class TestTokenizerRepr: diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index dba5a0400..4a5371738 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -3,8 +3,11 @@ use super::{ trainer::UnigramTrainer, trie::{Trie, TrieBuilder}, }; -use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; +use crate::{ + tokenizer::{Model, Result, Token}, + AddedVocabulary, +}; use std::collections::HashMap; use std::convert::TryInto; @@ -81,7 +84,7 @@ pub enum UnigramError { impl Default for Unigram { fn default() -> Self { let vocab = vec![("<unk>".to_string(), 0.0)]; - Self::from(vocab, Some(0), false).unwrap() + Self::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap() } } @@ -96,6 +99,7 @@ impl Unigram { vocab: Vec<(String, f64)>, unk_id: Option<usize>, byte_fallback: bool, + added_tokens: &AddedVocabulary, ) -> Result<Self> { let n = vocab.len(); let mut token_to_ids: TokenMap = HashMap::new(); @@ -114,11 +118,13 @@ impl Unigram { let mut min_score = f64::INFINITY; for (id, (token, score)) in vocab.iter().enumerate() { - token_to_ids.insert(token.to_string(), id as u32); - let bytes: Vec<u8> = token.bytes().collect(); - builder.push(&bytes); - if score < &min_score { - min_score = *score; + if !added_tokens.is_special_token(token) { + token_to_ids.insert(token.to_string(), id as u32); + let bytes: Vec<u8> = token.bytes().collect(); + builder.push(&bytes); + if score < &min_score { + min_score = *score; + } } } let trie = builder.build(); @@ -480,7 +486,7 @@ mod tests { #[test]
fn test_populate_nodes_unk() { let pieces = vec![("<unk>".to_string(), 0.0)]; - let model = Unigram::from(pieces, Some(0), false).unwrap(); + let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap(); let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id); model.populate_nodes(&mut lattice); @@ -505,7 +511,7 @@ mod tests { ("ab".to_string(), 0.3), ("bc".to_string(), 0.4), ]; - let model = Unigram::from(pieces, Some(0), false).unwrap(); + let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap(); let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id); model.populate_nodes(&mut lattice); From 0475c057dd8f5d5141b15f47ad16358ddc5a9f6b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 5 Oct 2024 17:17:52 +0200 Subject: [PATCH 14/16] fix added vocab tests --- tokenizers/src/tokenizer/added_vocabulary.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index e22249048..f91a7b82f 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -724,15 +724,15 @@ mod tests { assert_eq!(vocab.len(), 3); // New token was added assert!(vocab.is_special_token("test")); assert_eq!( - *vocab.get_added_tokens_decoder(), + vocab.get_added_tokens_decoder(), HashMap::from([ (0, AddedToken::from("test", true)), (2, AddedToken::from("added_token_1", true)), (3, AddedToken::from("added_token_2", true)), ]) ); - assert!(vocab.added_tokens_map.contains_key("test")); - assert!(vocab.added_tokens_map_r.contains_key(&0)); + assert!(vocab.added_tokens_map.lock().unwrap().contains_key("test")); + assert!(vocab.added_tokens_map_r.lock().unwrap().contains_key(&0)); vocab.add_tokens( &[ From 167ecdebfb217316a156cf2500e44aa354d30d53 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 5 Oct 2024 17:56:06 +0200 Subject: [PATCH 15/16] small fixed --- tokenizers/src/models/unigram/serialization.rs | 8 +++++--- tokenizers/src/models/unigram/trainer.rs | 10 +++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs index a6e56b735..1ad95002e 100644 --- a/tokenizers/src/models/unigram/serialization.rs +++ b/tokenizers/src/models/unigram/serialization.rs @@ -78,12 +78,14 @@ impl<'de> Visitor<'de> for UnigramVisitor { #[cfg(test)] mod test { + use crate::AddedVocabulary; + use super::*; #[test] fn test_serialization() { let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)]; - let model = Unigram::from(vocab, Some(0), false).unwrap(); + let model = Unigram::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap(); let data = serde_json::to_string(&model).unwrap(); let reconstructed = serde_json::from_str(&data).unwrap(); @@ -94,7 +96,7 @@ mod test { #[test] fn test_serialization_unk_id_not_zero() { let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)]; - let model = Unigram::from(vocab, Some(1), false).unwrap(); + let model = Unigram::from(vocab, Some(1), false, &AddedVocabulary::default()).unwrap(); let data = serde_json::to_string(&model).unwrap(); let reconstructed = serde_json::from_str(&data).unwrap(); @@ -105,7 +107,7 @@ mod test { #[test] fn test_serialization_no_unk_id() { let vocab = vec![("a".to_string(), -0.5)]; - let model = Unigram::from(vocab, None, false).unwrap(); + let model = Unigram::from(vocab, None, false,
&AddedVocabulary::default()).unwrap(); let data = serde_json::to_string(&model).unwrap(); let reconstructed = serde_json::from_str(&data).unwrap(); diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index 5d178e77b..b3e816a59 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -2,6 +2,7 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram}; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::parallelism::*; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use crate::AddedVocabulary; use log::debug; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; @@ -182,6 +183,7 @@ impl UnigramTrainer { special_tokens.into_iter().chain(pieces).collect(), unk_id, model.byte_fallback(), + &AddedVocabulary::default(), ) } @@ -567,7 +569,8 @@ impl UnigramTrainer { if required_chars.len() as u32 > self.vocab_size { return Err(Box::new(UnigramTrainerError::VocabularyTooSmall)); } - let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?; + let mut new_model = + Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?; loop { // Sub-EM iteration. for _iter in 0..self.n_sub_iterations { @@ -576,7 +579,8 @@ impl UnigramTrainer { // Executes M step. pieces = self.run_m_step(&pieces, &expected); - new_model = Unigram::from(pieces.clone(), Some(0), false)?; + new_model = + Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?; // Useful comment for checking compatibility with spm debug!( @@ -600,7 +604,7 @@ impl UnigramTrainer { // Prunes pieces. pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences); - new_model = Unigram::from(pieces.clone(), Some(0), false)?; + new_model = Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?; } self.finalize_progress(&progress, expected_updates); From 81d83361d0bfc466616d65f3eff91d723cc48630 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 5 Oct 2024 17:58:22 +0200 Subject: [PATCH 16/16] fix the unigram::from calls --- tokenizers/src/models/unigram/model.rs | 9 ++++++--- tokenizers/src/models/unigram/serialization.rs | 10 ++++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 4a5371738..c604b11c6 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -548,7 +548,8 @@ mod tests { ("abcd".to_string(), 10.0), ]; - let model = Unigram::from(sentencepieces, Some(0), false).unwrap(); + let model = + Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap(); let result = model.encode("abcd").unwrap(); assert_eq!(result, vec!["abcd"]); } @@ -570,7 +571,8 @@ mod tests { ("qr".to_string(), -0.5), ]; - let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap(); + let mut model = + Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap(); for is_optimized in &[true, false] { model.set_optimized(*is_optimized); @@ -617,7 +619,8 @@ mod tests { ("<0xC3>".to_string(), -0.01), ("<0xA9>".to_string(), -0.03), ]; - let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap(); + let unigram = + Unigram::from(sentencepieces, Some(0), true, &AddedVocabulary::default()).unwrap(); let tokens: Vec = unigram.tokenize("é").unwrap(); assert_eq!( tokens, diff --git a/tokenizers/src/models/unigram/serialization.rs 
b/tokenizers/src/models/unigram/serialization.rs index 1ad95002e..f0ff30694 100644 --- a/tokenizers/src/models/unigram/serialization.rs +++ b/tokenizers/src/models/unigram/serialization.rs @@ -1,3 +1,5 @@ +use crate::AddedVocabulary; + use super::model::Unigram; use serde::{ de::{Error, MapAccess, Visitor}, @@ -69,8 +71,12 @@ impl<'de> Visitor<'de> for UnigramVisitor { } } match (vocab, unk_id, byte_fallback) { - (Some(vocab), unk_id, byte_fallback) => Ok(Unigram::from(vocab, unk_id, byte_fallback) - .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?), + (Some(vocab), unk_id, byte_fallback) => { + Ok( + Unigram::from(vocab, unk_id, byte_fallback, &AddedVocabulary::default()) + .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?, + ) + } (None, _, _) => Err(Error::custom("Missing vocab")), } }
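
The series lands a single user-facing entry point, `Tokenizer.assign_tokens`, which takes a mapping from existing (typically reserved) token contents to their replacement contents while keeping the original ids. Below is a minimal Python usage sketch mirroring the tests added in PATCH 10 and PATCH 13; it assumes a `tokenizers` build that includes these patches, and the exact ids shown are specific to the `t5-base` vocabulary.

    from tokenizers import Tokenizer

    # Load a pretrained tokenizer whose vocabulary contains reserved sentinel tokens.
    tokenizer = Tokenizer.from_pretrained("t5-base")

    # Re-assign the reserved sentinel to a more meaningful content.
    # Keys are the old token contents, values are the new contents; the id is preserved.
    tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})

    # The reserved id now decodes to the new content, and the new content
    # encodes back to the same id.
    assert tokenizer.decode([32099]) == "my_new_token"
    assert 32099 in tokenizer.encode("my_new_token").ids

    # The old content is no longer matched as a single added token;
    # it falls back to the underlying model's tokenization.
    print(tokenizer.encode("<extra_id_0>").tokens)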