diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs
index 152586efb..98cc15102 100644
--- a/tokenizers/src/models/bpe/serialization.rs
+++ b/tokenizers/src/models/bpe/serialization.rs
@@ -30,14 +30,14 @@ impl Serialize for BPE {
             .map(|(pair, (rank, _))| (pair, rank))
             .collect();
         merges.sort_unstable_by_key(|k| *k.1);
-        let merges_str = merges
+        let merges = merges
             .into_iter()
-            .map(|(pair, _)| format!("{} {}", self.vocab_r[&pair.0], self.vocab_r[&pair.1]))
+            .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
             .collect::<Vec<_>>();
         let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);

         model.serialize_field("vocab", &ordered_vocab)?;
-        model.serialize_field("merges", &merges_str)?;
+        model.serialize_field("merges", &merges)?;

         model.end()
     }
@@ -81,7 +81,14 @@ impl<'de> Visitor<'de> for BPEVisitor {
     {
         let mut builder = BpeBuilder::new();
         let mut vocab: Option<HashMap<String, u32>> = None;
-        let mut merges: Option<Vec<String>> = None;
+
+        #[derive(Debug, Deserialize)]
+        #[serde(untagged)]
+        enum MergeType {
+            Tuple(Vec<(String, String)>),
+            Legacy(Vec<String>),
+        }
+        let mut merges: Option<MergeType> = None;
         while let Some(key) = map.next_key::<String>()? {
             match key.as_ref() {
                 "dropout" => {
@@ -134,8 +141,12 @@ impl<'de> Visitor<'de> for BPEVisitor {
             }
         }
         if let (Some(vocab), Some(merges)) = (vocab, merges) {
-            let merges =
-                convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?;
+            let merges = match merges {
+                MergeType::Tuple(merges) => merges,
+                MergeType::Legacy(merges) => {
+                    convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
+                }
+            };
             builder = builder.vocab_and_merges(vocab, merges);
             Ok(builder.build().map_err(Error::custom)?)
         } else {
@@ -167,13 +178,40 @@ mod test {
             .build()
             .unwrap();

+        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
+        let legacy = serde_json::from_str(legacy).unwrap();
+        assert_eq!(bpe, legacy);
+
         let data = serde_json::to_string(&bpe).unwrap();
         assert_eq!(
             data,
-            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
         );
         let reconstructed = serde_json::from_str(&data).unwrap();
+        assert_eq!(bpe, reconstructed);
+
+        // With a space in the token
+        let vocab: Vocab = [
+            ("<unk>".into(), 0),
+            ("a".into(), 1),
+            ("b c d".into(), 2),
+            ("ab c d".into(), 3),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let bpe = BpeBuilder::default()
+            .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())])
+            .unk_token("<unk>".to_string())
+            .ignore_merges(true)
+            .build()
+            .unwrap();
+        let data = serde_json::to_string(&bpe).unwrap();
+        assert_eq!(
+            data,
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"#
+        );
+        let reconstructed = serde_json::from_str(&data).unwrap();
         assert_eq!(bpe, reconstructed);
     }
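The backward-compatibility mechanism in the hunk above is serde's `#[serde(untagged)]` enum: variants are tried in declaration order, so tuple-form merges match first and legacy space-joined strings fall through to the second variant. A minimal standalone sketch of that pattern, assuming only `serde` (with the `derive` feature) and `serde_json` as dependencies; the `MergeType` name mirrors the diff:

```rust
use serde::Deserialize;

// Untagged: serde attempts each variant in order until one deserializes.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum MergeType {
    // New format: a JSON array of two-element arrays.
    Tuple(Vec<(String, String)>),
    // Legacy format: space-joined pair strings like "a b".
    Legacy(Vec<String>),
}

fn main() {
    // Matches `Tuple`: each element is itself a two-element array.
    let new: MergeType = serde_json::from_str(r#"[["a","b c d"]]"#).unwrap();
    // Fails `Tuple` (elements are plain strings), falls through to `Legacy`.
    let old: MergeType = serde_json::from_str(r#"["a b"]"#).unwrap();
    println!("{new:?}\n{old:?}");
}
```

Variant order matters here: a legacy entry is a plain string and can never parse as a `(String, String)` pair, so trying `Tuple` first is unambiguous.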
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index e12e70ffc..cdfb731a8 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -312,11 +312,14 @@ mod tests {
             .unwrap();
         let model = ModelWrapper::BPE(bpe);

+        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
+        let legacy = serde_json::from_str(legacy).unwrap();
+        assert_eq!(model, legacy);
         let data = serde_json::to_string(&model).unwrap();
         assert_eq!(
             data,
-            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
         );
         let reconstructed = serde_json::from_str(&data).unwrap();
         assert_eq!(model, reconstructed);
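For context on why the tuple encoding is needed at all: once vocab tokens may themselves contain spaces, the legacy `"left right"` string cannot be split back into its two tokens reliably, which is exactly what the `"b c d"` test case above exercises. A small self-contained illustration of the ambiguity (not using the tokenizers crate):

```rust
fn main() {
    // A legacy merge string whose tokens contain spaces is ambiguous:
    // "a b c d" could mean ("a", "b c d"), ("a b", "c d"), or ("a b c", "d").
    let legacy = "a b c d";
    let guesses: Vec<(&str, &str)> = legacy
        .match_indices(' ')
        .map(|(i, _)| (&legacy[..i], &legacy[i + 1..]))
        .collect();
    for (left, right) in &guesses {
        println!("({left:?}, {right:?})"); // every split is a valid reading
    }
    // The tuple form ["a","b c d"] carries the token boundary explicitly,
    // so no splitting heuristic is needed on deserialization.
}
```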