From 545d7230f485d2f199647c627f8b98f5a728c9bc Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 4 Oct 2024 15:24:14 +0200
Subject: [PATCH] fix unwrap errors

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 6f79ba660..984201075 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -148,8 +148,6 @@ pub struct AddedVocabulary {
     /// Contains the mapping from ID to AddedToken for all the added tokens, both special
     /// and classic.
     added_tokens_map_r: Arc<Mutex<HashMap<u32, AddedToken>>>,
-    /// Contains only the classic AddedToken, in the specific order the user gave them.
-    added_tokens: Vec<AddedToken>,
 
     /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
     split_trie: MatchingSet,
@@ -173,7 +171,6 @@ impl AddedVocabulary {
         Self {
             added_tokens_map: Arc::new(Mutex::new(HashMap::new())),
             added_tokens_map_r: Arc::new(Mutex::new(HashMap::new())),
-            added_tokens: vec![],
             split_trie: (trie, vec![]),
             split_normalized_trie: (normalized_trie, vec![]),
             encode_special_tokens: false,
@@ -218,12 +215,14 @@ impl AddedVocabulary {
 
     /// Check if a token is a special token
     pub fn is_special_token(&self, token: &str) -> bool {
-        self.added_tokens_map_r
-            .lock()
-            .unwrap()
-            .get(self.added_tokens_map.lock().unwrap().get(token).unwrap())
-            .unwrap()
-            .special
+        let hash_map = &self.added_tokens_map_r.lock().unwrap();
+        let revert_hash_map = &self.added_tokens_map.lock().unwrap();
+        if let Some(id) = revert_hash_map.get(token) {
+            if let Some(token) = hash_map.get(id) {
+                return token.special;
+            }
+        }
+        false
     }
 
     /// Add some special tokens to the vocabulary
@@ -335,7 +334,7 @@ impl AddedVocabulary {
     ///
     /// We keep two different RegexSet, one that will take care of matching against the
     /// non-normalized string, and one matching against the normalized one.
-    fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
+    fn refresh_added_tokens<N: Normalizer>(&mut self, _model: &impl Model, normalizer: Option<&N>) {
         type TupleTokenId<'a> = (&'a AddedToken, u32);
         let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap().clone();
         let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) =
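
Note (not part of the patch itself): the fix replaces the chained .unwrap() lookups in
is_special_token with if-let lookups, so querying a token that was never added now returns
false instead of panicking. The sketch below is a minimal, self-contained illustration of
that before/after pattern; the Vocab and Tok types are hypothetical stand-ins, not the
crate's real AddedVocabulary/AddedToken API.

// sketch.rs -- illustrative only; `Vocab`/`Tok` are hypothetical stand-ins.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

struct Tok {
    special: bool,
}

struct Vocab {
    /// token content -> id
    map: Arc<Mutex<HashMap<String, u32>>>,
    /// id -> token
    map_r: Arc<Mutex<HashMap<u32, Tok>>>,
}

impl Vocab {
    /// Old pattern: every lookup is unwrapped, so an unknown token panics.
    #[allow(dead_code)]
    fn is_special_panicking(&self, token: &str) -> bool {
        self.map_r
            .lock()
            .unwrap()
            .get(self.map.lock().unwrap().get(token).unwrap())
            .unwrap()
            .special
    }

    /// New pattern (as in the patch): missing entries fall through to `false`.
    fn is_special(&self, token: &str) -> bool {
        let map_r = self.map_r.lock().unwrap();
        let map = self.map.lock().unwrap();
        if let Some(id) = map.get(token) {
            if let Some(tok) = map_r.get(id) {
                return tok.special;
            }
        }
        false
    }
}

fn main() {
    let vocab = Vocab {
        map: Arc::new(Mutex::new(HashMap::from([("<s>".to_string(), 0u32)]))),
        map_r: Arc::new(Mutex::new(HashMap::from([(0u32, Tok { special: true })]))),
    };
    assert!(vocab.is_special("<s>"));
    // With the old chained unwraps this call would panic on the missing key.
    assert!(!vocab.is_special("never_added"));
    println!("ok");
}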