potential initial solution for the annoying unigram model :)
ArthurZucker committed Oct 5, 2024
1 parent ee7ce80 commit e8933fa
Showing 2 changed files with 31 additions and 15 deletions.
bindings/python/tests/bindings/test_tokenizer.py (16 additions, 6 deletions)

@@ -562,15 +562,25 @@ def test_setting_to_none(self):
        tokenizer.pre_tokenizer = None
        assert tokenizer.pre_tokenizer == None

-    def test_re_assign_tokens(self):
+    def test_re_assign_tokens_bpe(self):
        tokenizer = Tokenizer.from_pretrained("gpt2")
        tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"})
        assert tokenizer.decode([50256]) == "my_new_token"
        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
        assert tokenizer.encode("my_new_token").ids == [50256]
        assert tokenizer.encode("<|endoftext|>").ids == [27, 91, 437, 1659, 5239, 91, 29]
        assert tokenizer.encode("<|endoftext|>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}

+    def test_re_assign_tokens_unigram(self):
+        tokenizer = Tokenizer.from_pretrained("t5-base")
+        tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})
+        assert tokenizer.decode([32099]) == "my_new_token"
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token", "</s>"]
+        assert tokenizer.encode("my_new_token").ids == [32099, 1]
+        assert tokenizer.encode("<extra_id_0>").ids == [0, 1]
+        assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"]
+        assert "my_new_token" in tokenizer.get_vocab(True).keys()
-        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
-        assert tokenizer.encode("my_new_token").ids == [32099]
-        assert tokenizer.encode("<extra_id_0>").ids == [27, 91, 437, 1659, 5239, 91, 29]
-        assert tokenizer.encode("<extra_id_0>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
+        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}


class TestTokenizerRepr:
tokenizers/src/models/unigram/model.rs (15 additions, 9 deletions)

@@ -3,8 +3,11 @@ use super::{
    trainer::UnigramTrainer,
    trie::{Trie, TrieBuilder},
};
-use crate::tokenizer::{Model, Result, Token};
use crate::utils::cache::Cache;
+use crate::{
+    tokenizer::{Model, Result, Token},
+    AddedVocabulary,
+};

use std::collections::HashMap;
use std::convert::TryInto;
@@ -81,7 +84,7 @@ pub enum UnigramError {
impl Default for Unigram {
    fn default() -> Self {
        let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(vocab, Some(0), false).unwrap()
+        Self::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap()
    }
}

@@ -96,6 +99,7 @@ impl Unigram {
        vocab: Vec<(String, f64)>,
        unk_id: Option<usize>,
        byte_fallback: bool,
+        added_tokens: &AddedVocabulary,
    ) -> Result<Self> {
        let n = vocab.len();
        let mut token_to_ids: TokenMap = HashMap::new();
@@ -114,11 +118,13 @@

        let mut min_score = f64::INFINITY;
        for (id, (token, score)) in vocab.iter().enumerate() {
-            token_to_ids.insert(token.to_string(), id as u32);
-            let bytes: Vec<u8> = token.bytes().collect();
-            builder.push(&bytes);
-            if score < &min_score {
-                min_score = *score;
+            if !added_tokens.is_special_token(token) {
+                token_to_ids.insert(token.to_string(), id as u32);
+                let bytes: Vec<u8> = token.bytes().collect();
+                builder.push(&bytes);
+                if score < &min_score {
+                    min_score = *score;
+                }
            }
        }
        let trie = builder.build();
@@ -480,7 +486,7 @@ mod tests {
    #[test]
    fn test_populate_nodes_unk() {
        let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(pieces, Some(0), false).unwrap();
+        let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();

        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
        model.populate_nodes(&mut lattice);
@@ -505,7 +511,7 @@
("ab".to_string(), 0.3),
("bc".to_string(), 0.4),
];
let model = Unigram::from(pieces, Some(0), false).unwrap();
let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();

let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
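What the Rust side of this commit does: Unigram::from now takes the AddedVocabulary as a fourth argument, and any vocab entry that is_special_token reports as special is kept out of token_to_ids and the trie, so the unigram lattice can no longer match a re-assigned special token as an ordinary piece. Below is a minimal sketch of the new call shape; it assumes the usual public re-exports (tokenizers::models::unigram::Unigram, tokenizers::AddedVocabulary, the Model trait, tokenizers::Result), which are not themselves part of this diff.

use tokenizers::models::unigram::Unigram;
use tokenizers::{AddedVocabulary, Model};

fn main() -> tokenizers::Result<()> {
    // Same (piece, score) vocab style as the tests above.
    let vocab = vec![
        ("<unk>".to_string(), 0.0),
        ("a".to_string(), 0.1),
        ("b".to_string(), 0.2),
        ("ab".to_string(), 0.3),
    ];
    // The new fourth argument: an empty AddedVocabulary filters nothing out,
    // so this reproduces the old three-argument behavior.
    let model = Unigram::from(vocab, Some(0), false, &AddedVocabulary::default())?;
    // With a populated AddedVocabulary, pieces flagged as special would be
    // skipped when the trie is built and could never be produced here.
    let tokens = model.tokenize("ab")?;
    println!("{:?}", tokens.iter().map(|t| &t.value).collect::<Vec<_>>());
    Ok(())
}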
