From e8933fa5b996c11d9a8b61c7549b23e639c88b8d Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Sat, 5 Oct 2024 17:16:31 +0200
Subject: [PATCH] potential initial solution for the annoying unigram model :)

---
 .../python/tests/bindings/test_tokenizer.py | 22 ++++++++++++-----
 tokenizers/src/models/unigram/model.rs      | 24 ++++++++++++-------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 9a1fd7272..0fb2cbd7d 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -562,15 +562,25 @@ def test_setting_to_none(self):
         tokenizer.pre_tokenizer = None
         assert tokenizer.pre_tokenizer == None
 
-    def test_re_assign_tokens(self):
+    def test_re_assign_tokens_bpe(self):
+        tokenizer = Tokenizer.from_pretrained("gpt2")
+        tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"})
+        assert tokenizer.decode([50256]) == "my_new_token"
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
+        assert tokenizer.encode("my_new_token").ids == [50256]
+        assert tokenizer.encode("<|endoftext|>").ids == [27, 91, 437, 1659, 5239, 91, 29]
+        assert tokenizer.encode("<|endoftext|>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
+        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}
+
+    def test_re_assign_tokens_unigram(self):
         tokenizer = Tokenizer.from_pretrained("t5-base")
         tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})
         assert tokenizer.decode([32099]) == "my_new_token"
-        assert tokenizer.encode("my_new_token").tokens == ["my_new_token", "</s>"]
-        assert tokenizer.encode("my_new_token").ids == [32099, 1]
-        assert tokenizer.encode("<extra_id_0>").ids == [0, 1]
-        assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"]
-        assert "my_new_token" in tokenizer.get_vocab(True).keys()
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
+        assert tokenizer.encode("my_new_token").ids == [32099]
+        assert tokenizer.encode("<extra_id_0>").ids == [27, 91, 437, 1659, 5239, 91, 29]
+        assert tokenizer.encode("<extra_id_0>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
+        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}
 
 
 class TestTokenizerRepr:
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index dba5a0400..4a5371738 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -3,8 +3,11 @@ use super::{
     trainer::UnigramTrainer,
     trie::{Trie, TrieBuilder},
 };
-use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::Cache;
+use crate::{
+    tokenizer::{Model, Result, Token},
+    AddedVocabulary,
+};
 
 use std::collections::HashMap;
 use std::convert::TryInto;
@@ -81,7 +84,7 @@ pub enum UnigramError {
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(vocab, Some(0), false).unwrap()
+        Self::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap()
     }
 }
 
@@ -96,6 +99,7 @@ impl Unigram {
     pub fn from(
         vocab: Vec<(String, f64)>,
         unk_id: Option<usize>,
         byte_fallback: bool,
+        added_tokens: &AddedVocabulary,
     ) -> Result<Self> {
         let n = vocab.len();
         let mut token_to_ids: TokenMap = HashMap::new();
@@ -114,11 +118,13 @@ impl Unigram {
         let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocab.iter().enumerate() {
-            token_to_ids.insert(token.to_string(), id as u32);
-            let bytes: Vec<u8> = token.bytes().collect();
-            builder.push(&bytes);
-            if score < &min_score {
-                min_score = *score;
+            if !added_tokens.is_special_token(token) {
+                token_to_ids.insert(token.to_string(), id as u32);
+                let bytes: Vec<u8> = token.bytes().collect();
+                builder.push(&bytes);
+                if score < &min_score {
+                    min_score = *score;
+                }
             }
         }
         let trie = builder.build();
@@ -480,7 +486,7 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(pieces, Some(0), false).unwrap();
+        let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -505,7 +511,7 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(pieces, Some(0), false).unwrap();
+        let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
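
Note on the approach (not part of the diff): Unigram::from now receives the
AddedVocabulary and skips every vocab entry that is_special_token reports
while the trie and token map are built, so a re-assigned special token such
as <extra_id_0> can no longer be matched as a regular unigram piece during
segmentation. Below is a minimal, self-contained sketch of that guard;
MockAddedVocabulary and build_token_map are stand-ins invented for the
illustration, not types from the crate, and the real loop also pushes each
kept piece into a trie builder, which is omitted here.

// Standalone sketch of the guard introduced above, runnable outside the crate.
// MockAddedVocabulary is a hypothetical stand-in for AddedVocabulary; only
// is_special_token mirrors the method actually used by the patch.
use std::collections::{HashMap, HashSet};

struct MockAddedVocabulary {
    special_tokens: HashSet<String>,
}

impl MockAddedVocabulary {
    fn is_special_token(&self, token: &str) -> bool {
        self.special_tokens.contains(token)
    }
}

// Mirrors the filtering loop in Unigram::from: special tokens never enter the
// token map (nor, in the real code, the trie), so they cannot be matched as
// regular pieces during segmentation.
fn build_token_map(
    vocab: &[(String, f64)],
    added_tokens: &MockAddedVocabulary,
) -> (HashMap<String, u32>, f64) {
    let mut token_to_ids = HashMap::new();
    let mut min_score = f64::INFINITY;
    for (id, (token, score)) in vocab.iter().enumerate() {
        if !added_tokens.is_special_token(token) {
            token_to_ids.insert(token.clone(), id as u32);
            if *score < min_score {
                min_score = *score;
            }
        }
    }
    (token_to_ids, min_score)
}

fn main() {
    let vocab = vec![
        ("<extra_id_0>".to_string(), 0.0),
        ("▁a".to_string(), -1.5),
        ("ab".to_string(), -2.0),
    ];
    let added = MockAddedVocabulary {
        special_tokens: HashSet::from(["<extra_id_0>".to_string()]),
    };
    let (map, min_score) = build_token_map(&vocab, &added);
    assert!(!map.contains_key("<extra_id_0>")); // the special token stays out
    assert_eq!(map.len(), 2);
    println!("kept {} pieces, min score {min_score}", map.len());
}

Keeping specials out of the model's matching structures means they are handled
exclusively by the added-vocabulary layer, which is what lets assign_tokens
re-point them without leaving a stale entry behind in the Unigram model.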