Replace the derived (untagged) `Deserialize` impl of `ModelWrapper` with a manual one.

Input whose "type" field names a known variant (BPE, WordPiece, WordLevel, Unigram) is
deserialized directly as that variant, so the variant's own error is reported instead of the
generic "data did not match any variant of untagged enum ModelWrapper". Any other input falls
back to untagged matching, with WordPiece still tried before WordLevel.

diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index b17c35c07..e12e70ffc 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -8,7 +8,7 @@ pub mod wordpiece;
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};
 
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
 use crate::models::bpe::{BpeTrainer, BPE};
 use crate::models::unigram::{Unigram, UnigramTrainer};
@@ -57,7 +57,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
     }
 }
 
-#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
+#[derive(Serialize, Debug, PartialEq, Clone)]
 #[serde(untagged)]
 pub enum ModelWrapper {
     BPE(BPE),
@@ -68,6 +68,73 @@ pub enum ModelWrapper {
     Unigram(Unigram),
 }
 
+impl<'de> Deserialize<'de> for ModelWrapper {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        pub struct Tagged {
+            #[serde(rename = "type")]
+            variant: EnumType,
+            #[serde(flatten)]
+            rest: serde_json::Value,
+        }
+        #[derive(Deserialize)]
+        pub enum EnumType {
+            BPE,
+            WordPiece,
+            WordLevel,
+            Unigram,
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum ModelHelper {
+            Tagged(Tagged),
+            Legacy(serde_json::Value),
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum ModelUntagged {
+            BPE(BPE),
+            // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
+            // with the versions not including the "type"), since WordLevel is a subset of WordPiece
+            WordPiece(WordPiece),
+            WordLevel(WordLevel),
+            Unigram(Unigram),
+        }
+
+        let helper = ModelHelper::deserialize(deserializer)?;
+        Ok(match helper {
+            ModelHelper::Tagged(model) => match model.variant {
+                EnumType::BPE => ModelWrapper::BPE(
+                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
+                ),
+                EnumType::WordPiece => ModelWrapper::WordPiece(
+                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
+                ),
+                EnumType::WordLevel => ModelWrapper::WordLevel(
+                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
+                ),
+                EnumType::Unigram => ModelWrapper::Unigram(
+                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
+                ),
+            },
+            ModelHelper::Legacy(value) => {
+                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
+                match untagged {
+                    ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
+                    ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
+                    ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
+                    ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
+                }
+            }
+        })
+    }
+}
+
 impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
 impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
 impl_enum_from!(BPE, ModelWrapper, BPE);
@@ -263,10 +330,7 @@ mod tests {
         let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
             serde_json::from_str(invalid);
         match reconstructed {
-            Err(err) => assert_eq!(
-                err.to_string(),
-                "data did not match any variant of untagged enum ModelWrapper"
-            ),
+            Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
             _ => panic!("Expected an error here"),
         }
     }