Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify ser-as-any mechanism #1478

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
14 changes: 2 additions & 12 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::build_tools::py_schema_error_type;
use crate::errors::LocItem;
use crate::get_pydantic_version;
use crate::input::InputType;
use crate::serializers::{DuckTypingSerMode, Extra, SerMode, SerializationState};
use crate::serializers::{Extra, SerMode, SerializationState};
use crate::tools::{safe_repr, write_truncated_to_limited_bytes, SchemaDict};

use super::line_error::ValLineError;
Expand Down Expand Up @@ -323,17 +323,7 @@ impl ValidationError {
include_input: bool,
) -> PyResult<Bound<'py, PyString>> {
let state = SerializationState::new("iso8601", "utf8", "constants")?;
let extra = state.extra(
py,
&SerMode::Json,
true,
false,
false,
true,
None,
DuckTypingSerMode::SchemaBased,
None,
);
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None, false, None);
let serializer = ValidationErrorSerializer {
py,
line_errors: &self.line_errors,
Expand Down
17 changes: 12 additions & 5 deletions src/serializers/computed_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSe
use crate::tools::SchemaDict;

use super::errors::py_err_se_err;
use super::type_serializers::any::AnySerializer;
use super::Extra;

#[derive(Debug)]
Expand Down Expand Up @@ -156,10 +157,12 @@ impl ComputedField {

if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? {
let next_value = model.getattr(property_name_py)?;

let value = self
.serializer
.to_python(&next_value, next_include.as_ref(), next_exclude.as_ref(), extra)?;
let serializer = if extra.serialize_as_any {
AnySerializer::get()
} else {
&self.serializer
};
let value = serializer.to_python(&next_value, next_include.as_ref(), next_exclude.as_ref(), extra)?;
if extra.exclude_none && value.is_none(py) {
return Ok(());
}
Expand Down Expand Up @@ -198,7 +201,11 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> {
let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?;
let s = PydanticSerializer::new(
&next_value,
&self.computed_field.serializer,
if self.extra.serialize_as_any {
AnySerializer::get()
} else {
&self.computed_field.serializer
},
self.include,
self.exclude,
self.extra,
Expand Down
55 changes: 8 additions & 47 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,6 @@ pub(crate) struct SerializationState {
config: SerializationConfig,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DuckTypingSerMode {
// Don't check the type of the value, use the type of the schema
SchemaBased,
// Check the type of the value, use the type of the value
NeedsInference,
// We already checked the type of the value
// we don't want to infer again, but if we recurse down
// we do want to flip this back to NeedsInference for the
// fields / keys / items of any inner serializers
Inferred,
}

impl DuckTypingSerMode {
pub fn from_bool(serialize_as_any: bool) -> Self {
if serialize_as_any {
DuckTypingSerMode::NeedsInference
} else {
DuckTypingSerMode::SchemaBased
}
}

pub fn to_bool(self) -> bool {
match self {
DuckTypingSerMode::SchemaBased => false,
DuckTypingSerMode::NeedsInference => true,
DuckTypingSerMode::Inferred => true,
}
}

pub fn next_mode(self) -> Self {
match self {
DuckTypingSerMode::SchemaBased => DuckTypingSerMode::SchemaBased,
DuckTypingSerMode::NeedsInference => DuckTypingSerMode::Inferred,
DuckTypingSerMode::Inferred => DuckTypingSerMode::NeedsInference,
}
}
}

impl SerializationState {
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
let warnings = CollectWarnings::new(WarningsMode::None);
Expand All @@ -87,7 +48,7 @@ impl SerializationState {
round_trip: bool,
serialize_unknown: bool,
fallback: Option<&'py Bound<'_, PyAny>>,
duck_typing_ser_mode: DuckTypingSerMode,
serialize_as_any: bool,
context: Option<&'py Bound<'_, PyAny>>,
) -> Extra<'py> {
Extra::new(
Expand All @@ -103,7 +64,7 @@ impl SerializationState {
&self.rec_guard,
serialize_unknown,
fallback,
duck_typing_ser_mode,
serialize_as_any,
context,
)
}
Expand Down Expand Up @@ -136,7 +97,7 @@ pub(crate) struct Extra<'a> {
pub field_name: Option<&'a str>,
pub serialize_unknown: bool,
pub fallback: Option<&'a Bound<'a, PyAny>>,
pub duck_typing_ser_mode: DuckTypingSerMode,
pub serialize_as_any: bool,
pub context: Option<&'a Bound<'a, PyAny>>,
}

Expand All @@ -155,7 +116,7 @@ impl<'a> Extra<'a> {
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a Bound<'a, PyAny>>,
duck_typing_ser_mode: DuckTypingSerMode,
serialize_as_any: bool,
context: Option<&'a Bound<'a, PyAny>>,
) -> Self {
Self {
Expand All @@ -174,7 +135,7 @@ impl<'a> Extra<'a> {
field_name: None,
serialize_unknown,
fallback,
duck_typing_ser_mode,
serialize_as_any,
context,
}
}
Expand Down Expand Up @@ -234,7 +195,7 @@ pub(crate) struct ExtraOwned {
field_name: Option<String>,
serialize_unknown: bool,
pub fallback: Option<PyObject>,
duck_typing_ser_mode: DuckTypingSerMode,
serialize_as_any: bool,
pub context: Option<PyObject>,
}

Expand All @@ -255,7 +216,7 @@ impl ExtraOwned {
field_name: extra.field_name.map(ToString::to_string),
serialize_unknown: extra.serialize_unknown,
fallback: extra.fallback.map(|model| model.clone().into()),
duck_typing_ser_mode: extra.duck_typing_ser_mode,
serialize_as_any: extra.serialize_as_any,
context: extra.context.map(|model| model.clone().into()),
}
}
Expand All @@ -277,7 +238,7 @@ impl ExtraOwned {
field_name: self.field_name.as_deref(),
serialize_unknown: self.serialize_unknown,
fallback: self.fallback.as_ref().map(|m| m.bind(py)),
duck_typing_ser_mode: self.duck_typing_ser_mode,
serialize_as_any: self.serialize_as_any,
context: self.context.as_ref().map(|m| m.bind(py)),
}
}
Expand Down
60 changes: 20 additions & 40 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use serde::ser::SerializeMap;
use smallvec::SmallVec;

use crate::serializers::extra::SerCheck;
use crate::serializers::DuckTypingSerMode;
use crate::tools::truncate_safe_repr;
use crate::PydanticSerializationUnexpectedValue;

Expand All @@ -19,6 +18,7 @@ use super::filter::SchemaFilter;
use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer};
use super::shared::PydanticSerializer;
use super::shared::{CombinedSerializer, TypeSerializer};
use super::type_serializers::any::AnySerializer;

/// representation of a field for serialization
#[derive(Debug)]
Expand Down Expand Up @@ -177,12 +177,16 @@ impl GeneralFieldsSerializer {
if let Some(field) = op_field {
if let Some(ref serializer) = field.serializer {
if !exclude_default(&value, &field_extra, serializer)? {
let value = serializer.to_python(
&value,
next_include.as_ref(),
next_exclude.as_ref(),
&field_extra,
)?;
let value = if extra.serialize_as_any {
infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
} else {
serializer.to_python(
&value,
next_include.as_ref(),
next_exclude.as_ref(),
&field_extra,
)?
};
let output_key = field.get_key_py(output_dict.py(), &field_extra);
output_dict.set_item(output_key, value)?;
}
Expand All @@ -193,10 +197,10 @@ impl GeneralFieldsSerializer {
}
} else if self.mode == FieldsMode::TypedDictAllow {
let value = match &self.extra_serializer {
Some(serializer) => {
Some(serializer) if !extra.serialize_as_any => {
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
}
None => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?,
_ => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?,
};
output_dict.set_item(key, value)?;
} else if field_extra.check == SerCheck::Strict {
Expand Down Expand Up @@ -265,7 +269,11 @@ impl GeneralFieldsSerializer {
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
let s = PydanticSerializer::new(
&value,
serializer,
if extra.serialize_as_any {
AnySerializer::get()
} else {
serializer
},
next_include.as_ref(),
next_exclude.as_ref(),
&field_extra,
Expand Down Expand Up @@ -342,20 +350,6 @@ impl TypeSerializer for GeneralFieldsSerializer {
// If there is no model, we (a TypedDict) are the model
let model = extra.model.map_or_else(|| Some(value), Some);

// If there is no model, use duck typing ser logic for TypedDict
// If there is a model, skip this step, as BaseModel and dataclass duck typing
// is handled in their respective serializers
if extra.model.is_none() {
let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode();
let td_extra = Extra {
model,
duck_typing_ser_mode,
..*extra
};
if td_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred {
return infer_to_python(value, include, exclude, &td_extra);
}
}
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
main_extra_dict
} else {
Expand All @@ -374,10 +368,10 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
let value = match &self.extra_serializer {
Some(serializer) => {
Some(serializer) if !extra.serialize_as_any => {
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), extra)?
}
None => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), extra)?,
_ => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), extra)?,
};
output_dict.set_item(key, value)?;
}
Expand Down Expand Up @@ -411,20 +405,6 @@ impl TypeSerializer for GeneralFieldsSerializer {
// If there is no model, we (a TypedDict) are the model
let model = extra.model.map_or_else(|| Some(value), Some);

// If there is no model, use duck typing ser logic for TypedDict
// If there is a model, skip this step, as BaseModel and dataclass duck typing
// is handled in their respective serializers
if extra.model.is_none() {
let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode();
let td_extra = Extra {
model,
duck_typing_ser_mode,
..*extra
};
if td_extra.duck_typing_ser_mode == DuckTypingSerMode::Inferred {
return infer_serialize(value, serializer, include, exclude, &td_extra);
}
}
let expected_len = match self.mode {
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pub(crate) fn infer_to_python_known(
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
extra.duck_typing_ser_mode,
extra.serialize_as_any,
extra.context,
);
serializer.serializer.to_python(value, include, exclude, &extra)
Expand Down Expand Up @@ -505,7 +505,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
extra.duck_typing_ser_mode,
extra.serialize_as_any,
extra.context,
);
let pydantic_serializer =
Expand Down
Loading
Loading