Skip to content

Commit

Permalink
Merges cannot handle tokens containing spaces. (#909)
Browse files Browse the repository at this point in the history
* Merges cannot handle tokens containing spaces.

This fixes the issue while keeping backward compatibility:
merges are now serialized as token pairs, since we don't want to split a merge on whitespace blindly.

* Update the tests.

* Fix clippy warnings.

* Add a test with spaces in the token/merge.
  • Loading branch information
Narsil authored Aug 7, 2024
1 parent ab9c7de commit 6a5fce9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
52 changes: 45 additions & 7 deletions tokenizers/src/models/bpe/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ impl Serialize for BPE {
.map(|(pair, (rank, _))| (pair, rank))
.collect();
merges.sort_unstable_by_key(|k| *k.1);
let merges_str = merges
let merges = merges
.into_iter()
.map(|(pair, _)| format!("{} {}", self.vocab_r[&pair.0], self.vocab_r[&pair.1]))
.map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
.collect::<Vec<_>>();
let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);

model.serialize_field("vocab", &ordered_vocab)?;
model.serialize_field("merges", &merges_str)?;
model.serialize_field("merges", &merges)?;

model.end()
}
Expand Down Expand Up @@ -81,7 +81,14 @@ impl<'de> Visitor<'de> for BPEVisitor {
{
let mut builder = BpeBuilder::new();
let mut vocab: Option<HashMap<String, u32>> = None;
let mut merges: Option<Vec<String>> = None;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum MergeType {
Tuple(Vec<(String, String)>),
Legacy(Vec<String>),
}
let mut merges: Option<MergeType> = None;
while let Some(key) = map.next_key::<String>()? {
match key.as_ref() {
"dropout" => {
Expand Down Expand Up @@ -134,8 +141,12 @@ impl<'de> Visitor<'de> for BPEVisitor {
}
}
if let (Some(vocab), Some(merges)) = (vocab, merges) {
let merges =
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?;
let merges = match merges {
MergeType::Tuple(merges) => merges,
MergeType::Legacy(merges) => {
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
}
};
builder = builder.vocab_and_merges(vocab, merges);
Ok(builder.build().map_err(Error::custom)?)
} else {
Expand Down Expand Up @@ -167,13 +178,40 @@ mod test {
.build()
.unwrap();

let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
let legacy = serde_json::from_str(legacy).unwrap();
assert_eq!(bpe, legacy);

let data = serde_json::to_string(&bpe).unwrap();
assert_eq!(
data,
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
);
let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(bpe, reconstructed);

// With a space in the token
let vocab: Vocab = [
("<unk>".into(), 0),
("a".into(), 1),
("b c d".into(), 2),
("ab c d".into(), 3),
]
.iter()
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())])
.unk_token("<unk>".to_string())
.ignore_merges(true)
.build()
.unwrap();
let data = serde_json::to_string(&bpe).unwrap();
assert_eq!(
data,
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"#
);
let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(bpe, reconstructed);
}

Expand Down
5 changes: 4 additions & 1 deletion tokenizers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,14 @@ mod tests {
.unwrap();

let model = ModelWrapper::BPE(bpe);
let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
let legacy = serde_json::from_str(legacy).unwrap();
assert_eq!(model, legacy);

let data = serde_json::to_string(&model).unwrap();
assert_eq!(
data,
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
);
let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(model, reconstructed);
Expand Down

0 comments on commit 6a5fce9

Please sign in to comment.