Skip to content

Commit

Permalink
Better serialization error (#1595)
Browse files Browse the repository at this point in the history
* Updating the deserialization error for models.

* Update tokenizers/src/models/mod.rs

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
Narsil and ArthurZucker authored Aug 6, 2024
1 parent 2d27761 commit fe41687
Showing 1 changed file with 70 additions and 6 deletions.
76 changes: 70 additions & 6 deletions tokenizers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub mod wordpiece;
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::models::bpe::{BpeTrainer, BPE};
use crate::models::unigram::{Unigram, UnigramTrainer};
Expand Down Expand Up @@ -57,7 +57,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
}
}

#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
#[derive(Serialize, Debug, PartialEq, Clone)]
#[serde(untagged)]
pub enum ModelWrapper {
BPE(BPE),
Expand All @@ -68,6 +68,73 @@ pub enum ModelWrapper {
Unigram(Unigram),
}

impl<'de> Deserialize<'de> for ModelWrapper {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
pub struct Tagged {
#[serde(rename = "type")]
variant: EnumType,
#[serde(flatten)]
rest: serde_json::Value,
}
#[derive(Deserialize)]
pub enum EnumType {
BPE,
WordPiece,
WordLevel,
Unigram,
}

#[derive(Deserialize)]
#[serde(untagged)]
pub enum ModelHelper {
Tagged(Tagged),
Legacy(serde_json::Value),
}

#[derive(Deserialize)]
#[serde(untagged)]
pub enum ModelUntagged {
BPE(BPE),
// WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
// with the versions not including the "type"), since WordLevel is a subset of WordPiece
WordPiece(WordPiece),
WordLevel(WordLevel),
Unigram(Unigram),
}

let helper = ModelHelper::deserialize(deserializer)?;
Ok(match helper {
ModelHelper::Tagged(model) => match model.variant {
EnumType::BPE => ModelWrapper::BPE(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
EnumType::WordPiece => ModelWrapper::WordPiece(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
EnumType::WordLevel => ModelWrapper::WordLevel(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
EnumType::Unigram => ModelWrapper::Unigram(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
},
ModelHelper::Legacy(value) => {
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
match untagged {
ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
}
}
})
}
}

impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
impl_enum_from!(BPE, ModelWrapper, BPE);
Expand Down Expand Up @@ -263,10 +330,7 @@ mod tests {
let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
serde_json::from_str(invalid);
match reconstructed {
Err(err) => assert_eq!(
err.to_string(),
"data did not match any variant of untagged enum ModelWrapper"
),
Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
_ => panic!("Expected an error here"),
}
}
Expand Down

0 comments on commit fe41687

Please sign in to comment.