Skip to content

Commit

Permalink
Give error when initializing tokenizer with too high stride (#1306)
Browse files Browse the repository at this point in the history
* Split `get_n_added_tokens` into separate method

* Modify `TokenizerImpl.with_truncation()` to raise an error if given bad parameters

* Return Python error if `tokenizer.with_truncation()` fails

* Add dummy variable assignment for `no_truncation()` case

* Unrelated fmt fix.

---------

Co-authored-by: Nicolas Patry <[email protected]>
  • Loading branch information
boyleconnor and Narsil authored Jul 28, 2023
1 parent bb38f39 commit c2664ae
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
7 changes: 4 additions & 3 deletions bindings/python/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,15 +712,16 @@ impl PyTokenizer {
}
}

self.tokenizer.with_truncation(Some(params));

if let Err(error_message) = self.tokenizer.with_truncation(Some(params)) {
return Err(PyError(error_message.to_string()).into_pyerr::<exceptions::PyValueError>());
}
Ok(())
}

/// Disable truncation
#[pyo3(text_signature = "(self)")]
fn no_truncation(&mut self) {
    // `with_truncation(None)` can only fail when validating `Some(params)`;
    // clearing truncation never errors, so the `Result` is deliberately
    // discarded with `let _ =` to silence the unused-result lint.
    let _ = self.tokenizer.with_truncation(None);
}

/// Get the currently set truncation parameters
Expand Down
14 changes: 7 additions & 7 deletions tokenizers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
let mut holes = vec![];
let result = if let Some(max) = self.vocab_r.iter().map(|(key, _)| key).max() {
let iter = (0..*max + 1).filter_map(|i| {
if let Some(token) = self.vocab_r.get(&i){
Some((token, i))
}else{
holes.push(i);
None
}
if let Some(token) = self.vocab_r.get(&i) {
Some((token, i))
} else {
holes.push(i);
None
}
});
serializer.collect_map(iter)
} else {
serializer.collect_map(std::iter::empty::<(&str, u32)>())
};

if !holes.is_empty(){
if !holes.is_empty() {
warn!("The OrderedVocab you are attempting to save contains holes for indices {:?}, your vocabulary could be corrupted !", holes);
println!("The OrderedVocab you are attempting to save contains holes for indices {:?}, your vocabulary could be corrupted !", holes);
}
Expand Down
34 changes: 27 additions & 7 deletions tokenizers/src/tokenizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,10 @@ impl DerefMut for Tokenizer {
}
}

/// Error returned when a set of `TruncationParams` is rejected as invalid
/// (e.g. a `stride` too large for the effective max length); carries the
/// human-readable message, which `thiserror` renders via `Display`.
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct TruncationParamError(String);

/// A `Tokenizer` is capable of encoding/decoding any text.
#[derive(Clone, Debug)]
pub struct TokenizerImpl<M, N, PT, PP, D> {
Expand Down Expand Up @@ -595,9 +599,21 @@ where
}

/// Set the truncation parameters
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &mut Self {
///
/// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
if let Some(trunc_params) = &trunc {
let n_added_tokens = self.get_n_added_tokens(false);
let effective_max_length = trunc_params.max_length - n_added_tokens;
if effective_max_length <= trunc_params.stride {
return Err(Box::new(TruncationParamError(format!(
"tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ",
trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens
))));
}
}
self.truncation = trunc;
self
Ok(self)
}

/// Get the currently set truncation parameters
Expand Down Expand Up @@ -902,11 +918,7 @@ where
// 1. First we truncate if needed
let (encoding, pair_encoding) = {
if let Some(trunc) = &self.truncation {
let n_added_tokens = if let Some(processor) = &self.post_processor {
processor.added_tokens(pair_encoding.is_some())
} else {
0
};
let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some());

if add_special_tokens && n_added_tokens > 0 {
let params = TruncationParams {
Expand Down Expand Up @@ -950,6 +962,14 @@ where

Ok(final_encoding)
}

/// Number of special tokens the configured post-processor would add for a
/// single sequence (`is_pair == false`) or a sequence pair
/// (`is_pair == true`); `0` when no post-processor is set.
fn get_n_added_tokens(&self, is_pair: bool) -> usize {
    self.post_processor
        .as_ref()
        .map_or(0, |processor| processor.added_tokens(is_pair))
}
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
Expand Down

0 comments on commit c2664ae

Please sign in to comment.