[Feature] support Assign token to update the content of a token #1570

Draft
wants to merge 16 commits into base: main
17 changes: 17 additions & 0 deletions bindings/python/py_src/tokenizers/__init__.pyi
@@ -725,6 +725,23 @@ class Tokenizer:
"""
pass

def assign_tokens(self, old_to_new_map):
"""
Assign new content to existing tokens.

Each key of :obj:`old_to_new_map` must already be part of the vocabulary (typically an
added or special token). Its content is replaced by the mapped value while its id is
kept, and the old content is afterwards tokenized like regular text.

Args:
old_to_new_map (:obj:`Dict` with :obj:`str` or :class:`~tokenizers.AddedToken` keys and values):
The mapping from the tokens to update to their new content. When the new value is a
plain string, the flags of the old token are kept; when it is an
:class:`~tokenizers.AddedToken`, its own flags are used.
"""
pass

def decode(self, ids, skip_special_tokens=True):
"""
Decode the given list of ids back to a string
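For context, a minimal usage sketch of the new API, mirroring the GPT-2 test added below (the checkpoint name and the id 50256 come from that test; everything else is regular `tokenizers` usage):

```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("gpt2")

# Rename the EOS token in place: the id (50256) is kept, only the content changes.
tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"})

print(tokenizer.decode([50256]))             # "my_new_token"
print(tokenizer.encode("my_new_token").ids)  # [50256]
# The old content is no longer treated as a special token and is split like regular text.
print(tokenizer.encode("<|endoftext|>").tokens)
```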
45 changes: 45 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -1237,6 +1237,51 @@ impl PyTokenizer {
Ok(self.tokenizer.add_tokens(&tokens))
}

/// Assign new content to existing tokens.
///
/// Each key of :obj:`old_to_new_map` must already be part of the vocabulary (typically an
/// added or special token). Its content is replaced by the mapped value while its id is
/// kept, and the old content is afterwards tokenized like regular text.
///
/// Args:
/// old_to_new_map (:obj:`Dict` with :obj:`str` or :class:`~tokenizers.AddedToken` keys and values):
/// The mapping from the tokens to update to their new content. When the new value is a
/// plain string, the flags of the old token are kept; when it is an
/// :class:`~tokenizers.AddedToken`, its own flags are used.
#[pyo3(text_signature = "(self, old_to_new_map)")]
fn assign_tokens(&mut self, old_to_new_map: &Bound<'_, PyDict>) -> PyResult<()> {
use pyo3::exceptions::PyTypeError;

let mut processed_old_tokens = HashMap::with_capacity(old_to_new_map.len());
for (old, new) in old_to_new_map.iter() {
let old_token = if let Ok(content) = old.extract::<&str>() {
PyAddedToken::from(content.to_string(), Some(false)).get_token()
} else if let Ok(token) = old.extract::<PyRefMut<PyAddedToken>>() {
token.get_token()
} else {
return Err(PyTypeError::new_err(
"old_tokens must be a List[Union[str, AddedToken]]",
));
};

let new_token = if let Ok(content) = new.extract::<&str>() {
let mut updated_token = old_token.clone();
updated_token.content = content.to_string();
updated_token
} else if let Ok(token) = new.extract::<PyRefMut<PyAddedToken>>() {
token.get_token()
} else {
return Err(PyTypeError::new_err(
"new_tokens must be a List[Union[str, AddedToken]]",
));
};

processed_old_tokens.insert(old_token, new_token);
}
self.tokenizer.assign_tokens(&processed_old_tokens);
Ok(())
}

/// Add the given special tokens to the Tokenizer.
///
/// If these tokens are already part of the vocabulary, it just let the Tokenizer know about
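The binding above also accepts `AddedToken` objects on either side of the mapping. A hedged sketch of that path, following the extraction logic in this diff: a plain-string value inherits the old token's flags, while an `AddedToken` value brings its own. `<|my_eos|>` is a made-up name, and how keys are matched internally is not shown in this diff, so treat this as illustrative only.

```python
from tokenizers import AddedToken, Tokenizer

tokenizer = Tokenizer.from_pretrained("gpt2")

# Hypothetical replacement token; passing an AddedToken as the value lets us set
# its flags explicitly instead of inheriting them from the old token.
tokenizer.assign_tokens(
    {AddedToken("<|endoftext|>", special=True): AddedToken("<|my_eos|>", special=True, lstrip=True)}
)
```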
20 changes: 20 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -562,6 +562,26 @@ def test_setting_to_none(self):
tokenizer.pre_tokenizer = None
assert tokenizer.pre_tokenizer == None

def test_re_assign_tokens_bpe(self):
tokenizer = Tokenizer.from_pretrained("gpt2")
tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"})
assert tokenizer.decode([50256]) == "my_new_token"
assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
assert tokenizer.encode("my_new_token").ids == [50256]
assert tokenizer.encode("<|endoftext|>").ids == [27, 91, 437, 1659, 5239, 91, 29]
assert tokenizer.encode("<|endoftext|>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}

def test_re_assign_tokens_unigram(self):
tokenizer = Tokenizer.from_pretrained("t5-base")
tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})
assert tokenizer.decode([32099]) == "my_new_token"
assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
assert tokenizer.encode("my_new_token").ids == [32099]
assert tokenizer.encode("<extra_id_0>").ids == [27, 91, 437, 1659, 5239, 91, 29]
assert tokenizer.encode("<extra_id_0>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}


class TestTokenizerRepr:
def test_repr(self):
33 changes: 21 additions & 12 deletions tokenizers/src/models/unigram/model.rs
@@ -3,8 +3,11 @@ use super::{
trainer::UnigramTrainer,
trie::{Trie, TrieBuilder},
};
use crate::tokenizer::{Model, Result, Token};
use crate::utils::cache::Cache;
use crate::{
tokenizer::{Model, Result, Token},
AddedVocabulary,
};

use std::collections::HashMap;
use std::convert::TryInto;
@@ -81,7 +84,7 @@ pub enum UnigramError {
impl Default for Unigram {
fn default() -> Self {
let vocab = vec![("<unk>".to_string(), 0.0)];
Self::from(vocab, Some(0), false).unwrap()
Self::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap()
}
}

@@ -96,6 +99,7 @@ impl Unigram {
vocab: Vec<(String, f64)>,
unk_id: Option<usize>,
byte_fallback: bool,
added_tokens: &AddedVocabulary,
) -> Result<Self> {
let n = vocab.len();
let mut token_to_ids: TokenMap = HashMap::new();
@@ -114,11 +118,13 @@ impl Unigram {

let mut min_score = f64::INFINITY;
for (id, (token, score)) in vocab.iter().enumerate() {
token_to_ids.insert(token.to_string(), id as u32);
let bytes: Vec<u8> = token.bytes().collect();
builder.push(&bytes);
if score < &min_score {
min_score = *score;
if !added_tokens.is_special_token(token) {
token_to_ids.insert(token.to_string(), id as u32);
let bytes: Vec<u8> = token.bytes().collect();
builder.push(&bytes);
if score < &min_score {
min_score = *score;
}
}
}
let trie = builder.build();
@@ -480,7 +486,7 @@ mod tests {
#[test]
fn test_populate_nodes_unk() {
let pieces = vec![("<unk>".to_string(), 0.0)];
let model = Unigram::from(pieces, Some(0), false).unwrap();
let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();

let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
@@ -505,7 +511,7 @@ mod tests {
("ab".to_string(), 0.3),
("bc".to_string(), 0.4),
];
let model = Unigram::from(pieces, Some(0), false).unwrap();
let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();

let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
@@ -542,7 +548,8 @@ mod tests {
("abcd".to_string(), 10.0),
];

let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
let model =
Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();
let result = model.encode("abcd").unwrap();
assert_eq!(result, vec!["abcd"]);
}
@@ -564,7 +571,8 @@ mod tests {
("qr".to_string(), -0.5),
];

let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();
let mut model =
Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();

for is_optimized in &[true, false] {
model.set_optimized(*is_optimized);
@@ -611,7 +619,8 @@ mod tests {
("<0xC3>".to_string(), -0.01),
("<0xA9>".to_string(), -0.03),
];
let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
let unigram =
Unigram::from(sentencepieces, Some(0), true, &AddedVocabulary::default()).unwrap();
let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
assert_eq!(
tokens,
18 changes: 13 additions & 5 deletions tokenizers/src/models/unigram/serialization.rs
@@ -1,3 +1,5 @@
use crate::AddedVocabulary;

use super::model::Unigram;
use serde::{
de::{Error, MapAccess, Visitor},
@@ -69,21 +71,27 @@ impl<'de> Visitor<'de> for UnigramVisitor {
}
}
match (vocab, unk_id, byte_fallback) {
(Some(vocab), unk_id, byte_fallback) => Ok(Unigram::from(vocab, unk_id, byte_fallback)
.map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?),
(Some(vocab), unk_id, byte_fallback) => {
Ok(
Unigram::from(vocab, unk_id, byte_fallback, &AddedVocabulary::default())
.map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?,
)
}
(None, _, _) => Err(Error::custom("Missing vocab")),
}
}
}

#[cfg(test)]
mod test {
use crate::AddedVocabulary;

use super::*;

#[test]
fn test_serialization() {
let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
let model = Unigram::from(vocab, Some(0), false).unwrap();
let model = Unigram::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap();

let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
@@ -94,7 +102,7 @@ mod test {
#[test]
fn test_serialization_unk_id_not_zero() {
let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
let model = Unigram::from(vocab, Some(1), false).unwrap();
let model = Unigram::from(vocab, Some(1), false, &AddedVocabulary::default()).unwrap();

let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
@@ -105,7 +113,7 @@ mod test {
#[test]
fn test_serialization_no_unk_id() {
let vocab = vec![("a".to_string(), -0.5)];
let model = Unigram::from(vocab, None, false).unwrap();
let model = Unigram::from(vocab, None, false, &AddedVocabulary::default()).unwrap();

let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
10 changes: 7 additions & 3 deletions tokenizers/src/models/unigram/trainer.rs
@@ -2,6 +2,7 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram};
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};
use crate::AddedVocabulary;
use log::debug;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
@@ -182,6 +183,7 @@ impl UnigramTrainer {
special_tokens.into_iter().chain(pieces).collect(),
unk_id,
model.byte_fallback(),
&AddedVocabulary::default(),
)
}

@@ -567,7 +569,8 @@ impl UnigramTrainer {
if required_chars.len() as u32 > self.vocab_size {
return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
}
let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?;
let mut new_model =
Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
loop {
// Sub-EM iteration.
for _iter in 0..self.n_sub_iterations {
@@ -576,7 +579,8 @@ impl UnigramTrainer {

// Executes M step.
pieces = self.run_m_step(&pieces, &expected);
new_model = Unigram::from(pieces.clone(), Some(0), false)?;
new_model =
Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;

// Useful comment for checking compatibility with spm
debug!(
@@ -600,7 +604,7 @@

// Prunes pieces.
pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences);
new_model = Unigram::from(pieces.clone(), Some(0), false)?;
new_model = Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
}
self.finalize_progress(&progress, expected_updates);
