diff --git a/generate_self_schema.py b/generate_self_schema.py index 2c190bbad..83167d038 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -54,8 +54,9 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core return type_dict_schema(obj, definitions) elif obj == Any or obj == type: return {'type': 'any'} - if isinstance(obj, type) and issubclass(obj, core_schema.Protocol): + elif isinstance(obj, type) and issubclass(obj, core_schema.Protocol): return {'type': 'callable'} + # elif isinstance(obj, ForwardRef): origin = get_origin(obj) assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py' @@ -151,6 +152,9 @@ def type_dict_schema( # noqa: C901 else: field_type = eval_forward_ref(field_type) + if fr_arg == 'SchemaValidator' or fr_arg == 'SchemaSerializer': + schema = {'type': 'is-instance', 'cls': Ident(fr_arg), 'cls_repr': f'pydantic_core.{fr_arg}'} + if schema is None: if get_origin(field_type) == core_schema.Required: required = True @@ -202,7 +206,7 @@ def main() -> None: definitions: dict[str, core_schema.CoreSchema] = {} choices = {} - for s in schema_union.__args__: + for s in get_args(schema_union): type_ = s.__annotations__['type'] m = re.search(r"Literal\['(.+?)']", type_.__forward_arg__) assert m, f'Unknown schema type: {type_}' @@ -217,9 +221,9 @@ def main() -> None: *definitions.values(), ], ) - python_code = ( - f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n' - ) + python_code = f"""# this file is auto-generated by generate_self_schema.py, DO NOT edit manually +from pydantic_core import SchemaValidator, SchemaSerializer +self_schema = {schema}\n""" try: from black import Mode, TargetVersion, format_file_contents except ImportError: @@ -236,5 +240,12 @@ def main() -> None: print(f'Self schema definition written to {SAVE_PATH}') +class Ident(str): + """Format a literal as a Ident in the output""" + + def __repr__(self) -> str: + return str(self) + + if __name__ == '__main__': main() diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 2d7061ffd..fd541e243 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -30,15 +30,17 @@ from typing import Literal if TYPE_CHECKING: - from pydantic_core import PydanticUndefined + from pydantic_core import PydanticUndefined, SchemaSerializer, SchemaValidator else: - # The initial build of pydantic_core requires PydanticUndefined to generate + # The initial build of pydantic_core requires some Rust structures to generate # the core schema; so we need to conditionally skip it. mypy doesn't like # this at all, hence the TYPE_CHECKING branch above. try: - from pydantic_core import PydanticUndefined + from pydantic_core import PydanticUndefined, SchemaSerializer, SchemaValidator except ImportError: - PydanticUndefined = object() + PydanticUndefined = 'PydanticUndefined' + SchemaValidator = 'SchemaValidator' + SchemaSerializer = 'SchemaSerializer' ExtraBehavior = Literal['allow', 'forbid', 'ignore'] @@ -3605,6 +3607,50 @@ def definition_reference_schema( return _dict_not_none(type='definition-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization) +class PrecompiledSchema(TypedDict, total=False): + type: Required[Literal['precompiled']] + schema: CoreSchema + validator: SchemaValidator + serializer: SchemaSerializer + ref: str + metadata: Any + + +def precompiled_schema( + schema: CoreSchema, + validator: SchemaValidator, + serializer: SchemaSerializer, + ref: str | None = None, + metadata: Any = None, +) -> PrecompiledSchema: + """ + Returns a schema that points to a schema stored in "definitions", this is useful for nested recursive + models and also when you want to define validators separately from the main schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema_definition = core_schema.definition_reference_schema('list-schema') + schema = core_schema.definitions_schema( + schema=schema_definition, + definitions=[ + core_schema.list_schema(items_schema=schema_definition, ref='list-schema'), + ], + ) + v = SchemaValidator(schema) + assert v.validate_python([()]) == [[]] + ``` + + Args: + schema_ref: The schema ref to use for the definition reference schema + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='precompiled', schema=schema, validator=validator, serializer=serializer, ref=ref, metadata=metadata + ) + + MYPY = False # See https://github.com/python/mypy/issues/14034 for details, in summary mypy is extremely slow to process this # union which kills performance not just for pydantic, but even for code using pydantic @@ -3658,6 +3704,7 @@ def definition_reference_schema( DefinitionsSchema, DefinitionReferenceSchema, UuidSchema, + PrecompiledSchema, ] elif False: CoreSchema: TypeAlias = Mapping[str, Any] @@ -3713,6 +3760,7 @@ def definition_reference_schema( 'definitions', 'definition-ref', 'uuid', + 'precompiled', ] CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] diff --git a/src/definitions.rs b/src/definitions.rs index 0d01fd2ae..1eb813015 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -3,16 +3,20 @@ /// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar. /// We use DefinitionsBuilder to collect the references / definitions into a single vector /// and then get a definition from a reference using an integer id (just for performance of not using a HashMap) -use std::collections::hash_map::Entry; +use std::{ + collections::hash_map::Entry, + fmt::Debug, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, OnceLock, + }, +}; -use pyo3::prelude::*; +use pyo3::{prelude::*, PyTraverseError, PyVisit}; use ahash::AHashMap; -use crate::build_tools::py_schema_err; - -// An integer id for the reference -pub type ReferenceId = usize; +use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse}; /// Definitions are validators and serializers that are /// shared by reference. @@ -24,91 +28,227 @@ pub type ReferenceId = usize; /// They get indexed by a ReferenceId, which are integer identifiers /// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer} /// gets build. -pub type Definitions = [T]; +#[derive(Clone)] +pub struct Definitions(AHashMap, Definition>); -#[derive(Clone, Debug)] -struct Definition { - pub id: ReferenceId, - pub value: Option, +impl Definitions { + pub fn values(&self) -> impl Iterator> { + self.0.values() + } +} + +/// Internal type which contains a definition to be filled +pub struct Definition(Arc>); + +impl Definition { + pub fn get(&self) -> Option<&T> { + self.0.value.get() + } +} + +struct DefinitionInner { + value: OnceLock, + name: LazyName, +} + +/// Reference to a definition. +pub struct DefinitionRef { + name: Arc, + value: Definition, +} + +// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone) +impl Clone for DefinitionRef { + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + value: self.value.clone(), + } + } +} + +impl DefinitionRef { + pub fn id(&self) -> usize { + Arc::as_ptr(&self.value.0) as usize + } + + pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str { + match self.value.0.value.get() { + Some(value) => self.value.0.name.get_or_init(|| init(value)), + None => "...", + } + } + + pub fn get(&self) -> Option<&T> { + self.value.0.value.get() + } +} + +impl Debug for DefinitionRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To avoid possible infinite recursion from recursive definitions, + // a DefinitionRef just displays debug as its name + self.name.fmt(f) + } +} + +impl Debug for Definitions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Formatted as a list for backwards compatibility; in principle + // this could be formatted as a map. Maybe change in a future + // minor release of pydantic. + write![f, "["]?; + let mut first = true; + for def in self.0.values() { + write![f, "{sep}{def:?}", sep = if first { "" } else { ", " }]?; + first = false; + } + write![f, "]"]?; + Ok(()) + } +} + +impl Clone for Definition { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Debug for Definition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0.value.get() { + Some(value) => value.fmt(f), + None => "...".fmt(f), + } + } +} + +impl PyGcTraverse for DefinitionRef { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + if let Some(value) = self.value.0.value.get() { + value.py_gc_traverse(visit)?; + } + Ok(()) + } +} + +impl PyGcTraverse for Definitions { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + for value in self.0.values() { + if let Some(value) = value.0.value.get() { + value.py_gc_traverse(visit)?; + } + } + Ok(()) + } } #[derive(Clone, Debug)] pub struct DefinitionsBuilder { - definitions: AHashMap>, + definitions: Definitions, } -impl DefinitionsBuilder { +impl DefinitionsBuilder { pub fn new() -> Self { Self { - definitions: AHashMap::new(), + definitions: Definitions(AHashMap::new()), } } /// Get a ReferenceId for the given reference string. - // This ReferenceId can later be used to retrieve a definition - pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId { - let next_id = self.definitions.len(); + pub fn get_definition(&mut self, reference: &str) -> DefinitionRef { // We either need a String copy or two hashmap lookups // Neither is better than the other // We opted for the easier outward facing API - match self.definitions.entry(reference.to_string()) { - Entry::Occupied(entry) => entry.get().id, - Entry::Vacant(entry) => { - entry.insert(Definition { - id: next_id, - value: None, - }); - next_id - } + let name = Arc::new(reference.to_string()); + let value = match self.definitions.0.entry(name.clone()) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner { + value: OnceLock::new(), + name: LazyName::new(), + }))), + }; + DefinitionRef { + name, + value: value.clone(), } } /// Add a definition, returning the ReferenceId that maps to it - pub fn add_definition(&mut self, reference: String, value: T) -> PyResult { - let next_id = self.definitions.len(); - match self.definitions.entry(reference.clone()) { - Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) { - Some(_) => py_schema_err!("Duplicate ref: `{}`", reference), - None => Ok(entry.get().id), - }, - Entry::Vacant(entry) => { - entry.insert(Definition { - id: next_id, - value: Some(value), - }); - Ok(next_id) + pub fn add_definition(&mut self, reference: String, value: T) -> PyResult> { + let name = Arc::new(reference); + let value = match self.definitions.0.entry(name.clone()) { + Entry::Occupied(entry) => { + let definition = entry.into_mut(); + match definition.0.value.set(value) { + Ok(()) => definition.clone(), + Err(_) => return py_schema_err!("Duplicate ref: `{}`", name), + } + } + Entry::Vacant(entry) => entry + .insert(Definition(Arc::new(DefinitionInner { + value: OnceLock::from(value), + name: LazyName::new(), + }))) + .clone(), + }; + Ok(DefinitionRef { name, value }) + } + + /// Consume this Definitions into a vector of items, indexed by each items ReferenceId + pub fn finish(self) -> PyResult> { + for (reference, def) in &self.definitions.0 { + if def.0.value.get().is_none() { + return py_schema_err!("Definitions error: definition `{}` was never filled", reference); } } + Ok(self.definitions) } +} - /// Retrieve an item definition using a ReferenceId - /// If the definition doesn't yet exist (as happens in recursive types) then we create it - /// At the end (in finish()) we check that there are no undefined definitions - pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> { - let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) { - Some(v) => v, - None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id), - }; - match def.value.as_ref() { - Some(v) => Ok(v), - None => py_schema_err!( - "Definitions error: attempted to use `{}` before it was filled", - reference - ), +struct LazyName { + initialized: OnceLock, + in_recursion: AtomicBool, +} + +impl LazyName { + fn new() -> Self { + Self { + initialized: OnceLock::new(), + in_recursion: AtomicBool::new(false), } } - /// Consume this Definitions into a vector of items, indexed by each items ReferenceId - pub fn finish(self) -> PyResult> { - // We need to create a vec of defs according to the order in their ids - let mut defs: Vec<(usize, T)> = Vec::new(); - for (reference, def) in self.definitions { - match def.value { - None => return py_schema_err!("Definitions error: definition {} was never filled", reference), - Some(v) => defs.push((def.id, v)), - } + /// Gets the validator name, returning the default in the case of recursion loops + fn get_or_init(&self, init: impl FnOnce() -> String) -> &str { + if let Some(s) = self.initialized.get() { + return s.as_str(); + } + + if self + .in_recursion + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { + return "..."; + } + let result = self.initialized.get_or_init(init).as_str(); + self.in_recursion.store(false, Ordering::SeqCst); + result + } +} + +impl Debug for LazyName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.initialized.get().map_or("...", String::as_str).fmt(f) + } +} + +impl Clone for LazyName { + fn clone(&self) -> Self { + Self { + initialized: OnceLock::new(), + in_recursion: AtomicBool::new(false), } - defs.sort_by_key(|(id, _)| *id); - Ok(defs.into_iter().map(|(_, v)| v).collect()) } } diff --git a/src/py_gc.rs b/src/py_gc.rs index 02df02e13..8af285afb 100644 --- a/src/py_gc.rs +++ b/src/py_gc.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use ahash::AHashMap; use enum_dispatch::enum_dispatch; use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit}; @@ -35,6 +37,12 @@ impl PyGcTraverse for AHashMap { } } +impl PyGcTraverse for Arc { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + T::py_gc_traverse(self, visit) + } +} + impl PyGcTraverse for Box { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { T::py_gc_traverse(self, visit) diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 9972a82c4..65c5a1ba9 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -10,8 +10,6 @@ use serde::ser::Error; use super::config::SerializationConfig; use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER}; use super::ob_type::ObTypeLookup; -use super::shared::CombinedSerializer; -use crate::definitions::Definitions; use crate::recursion_guard::RecursionGuard; /// this is ugly, would be much better if extra could be stored in `SerializationState` @@ -48,7 +46,6 @@ impl SerializationState { Extra::new( py, mode, - &[], by_alias, &self.warnings, false, @@ -72,7 +69,6 @@ impl SerializationState { #[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct Extra<'a> { pub mode: &'a SerMode, - pub definitions: &'a Definitions, pub ob_type_lookup: &'a ObTypeLookup, pub warnings: &'a CollectWarnings, pub by_alias: bool, @@ -98,7 +94,6 @@ impl<'a> Extra<'a> { pub fn new( py: Python<'a>, mode: &'a SerMode, - definitions: &'a Definitions, by_alias: bool, warnings: &'a CollectWarnings, exclude_unset: bool, @@ -112,7 +107,6 @@ impl<'a> Extra<'a> { ) -> Self { Self { mode, - definitions, ob_type_lookup: ObTypeLookup::cached(py), warnings, by_alias, @@ -156,7 +150,6 @@ impl SerCheck { #[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct ExtraOwned { mode: SerMode, - definitions: Vec, warnings: CollectWarnings, by_alias: bool, exclude_unset: bool, @@ -176,7 +169,6 @@ impl ExtraOwned { pub fn new(extra: &Extra) -> Self { Self { mode: extra.mode.clone(), - definitions: extra.definitions.to_vec(), warnings: extra.warnings.clone(), by_alias: extra.by_alias, exclude_unset: extra.exclude_unset, @@ -196,7 +188,6 @@ impl ExtraOwned { pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> { Extra { mode: &self.mode, - definitions: &self.definitions, ob_type_lookup: ObTypeLookup::cached(py), warnings: &self.warnings, by_alias: self.by_alias, diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 6dbc076fe..c1f828a30 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -5,7 +5,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use pyo3::{PyTraverseError, PyVisit}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; use config::SerializationConfig; @@ -26,11 +26,12 @@ mod ob_type; mod shared; mod type_serializers; -#[pyclass(module = "pydantic_core._pydantic_core")] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaSerializer { serializer: CombinedSerializer, - definitions: Vec, + schema: PyObject, + definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, } @@ -54,7 +55,6 @@ impl SchemaSerializer { Extra::new( py, mode, - &self.definitions, by_alias, warnings, exclude_unset, @@ -78,6 +78,7 @@ impl SchemaSerializer { let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { serializer, + schema: schema.into(), definitions: definitions_builder.finish()?, expected_json_size: AtomicUsize::new(1024), config: SerializationConfig::from_config(config)?, @@ -184,9 +185,8 @@ impl SchemaSerializer { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { self.serializer.py_gc_traverse(&visit)?; - for slot in &self.definitions { - slot.py_gc_traverse(&visit)?; - } + visit.call(&self.schema)?; + self.definitions.py_gc_traverse(&visit)?; Ok(()) } } diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index b9b0c1fe1..96a9e7fd5 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -13,7 +13,7 @@ use serde_json::ser::PrettyFormatter; use crate::build_tools::py_schema_err; use crate::build_tools::py_schema_error_type; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::py_gc::PyGcTraverse; use crate::tools::{py_err, SchemaDict}; @@ -141,6 +141,7 @@ combined_serializer! { Recursive: super::type_serializers::definitions::DefinitionRefSerializer; TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer; TupleVariable: super::type_serializers::tuple::TupleVariableSerializer; + Precompiled: super::type_serializers::precompiled::PrecompiledSerializer; } } @@ -250,6 +251,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::TuplePositional(inner) => inner.py_gc_traverse(visit), CombinedSerializer::TupleVariable(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Precompiled(inner) => inner.py_gc_traverse(visit), } } } @@ -293,7 +295,7 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug { fn get_name(&self) -> &str; /// Used by union serializers to decide if it's worth trying again while allowing subclasses - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { false } diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 124f962ad..787e267dd 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use ahash::AHashMap; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use super::{ @@ -179,7 +179,7 @@ impl TypeSerializer for DataclassSerializer { &self.name } - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { true } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 4614bbc56..cf92df244 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; -use crate::definitions::Definitions; +use crate::definitions::DefinitionRef; use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; @@ -41,7 +41,7 @@ impl BuildSerializer for DefinitionsSerializerBuilder { #[derive(Debug, Clone)] pub struct DefinitionRefSerializer { - serializer_id: usize, + definition: DefinitionRef, } impl BuildSerializer for DefinitionRefSerializer { @@ -52,9 +52,9 @@ impl BuildSerializer for DefinitionRefSerializer { _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let schema_ref: String = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - let serializer_id = definitions.get_reference_id(&schema_ref); - Ok(Self { serializer_id }.into()) + let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; + let definition = definitions.get_definition(schema_ref); + Ok(Self { definition }.into()) } } @@ -68,10 +68,10 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - let value_id = extra.rec_guard.add(value, self.serializer_id)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); + let comb_serializer = self.definition.get().unwrap(); + let value_id = extra.rec_guard.add(value, self.definition.id())?; let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.serializer_id); + extra.rec_guard.pop(value_id, self.definition.id()); r } @@ -87,10 +87,13 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let value_id = extra.rec_guard.add(value, self.serializer_id).map_err(py_err_se_err)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); + let comb_serializer = self.definition.get().unwrap(); + let value_id = extra + .rec_guard + .add(value, self.definition.id()) + .map_err(py_err_se_err)?; let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.serializer_id); + extra.rec_guard.pop(value_id, self.definition.id()); r } @@ -98,8 +101,7 @@ impl TypeSerializer for DefinitionRefSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - let comb_serializer = definitions.get(self.serializer_id).unwrap(); - comb_serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.definition.get().unwrap().retry_with_lax_check() } } diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index b942b5b86..b97762a36 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -15,6 +15,7 @@ pub mod literal; pub mod model; pub mod nullable; pub mod other; +pub mod precompiled; pub mod set_frozenset; pub mod simple; pub mod string; diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index c5b252fbf..8a2eeb4e1 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -13,7 +13,7 @@ use super::{ }; use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::serializers::errors::PydanticSerializationUnexpectedValue; use crate::tools::SchemaDict; @@ -228,7 +228,7 @@ impl TypeSerializer for ModelSerializer { &self.name } - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { true } } diff --git a/src/serializers/type_serializers/nullable.rs b/src/serializers/type_serializers/nullable.rs index 837d6c5f1..23349ec81 100644 --- a/src/serializers/type_serializers/nullable.rs +++ b/src/serializers/type_serializers/nullable.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use super::{infer_json_key_known, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, TypeSerializer}; @@ -75,7 +75,7 @@ impl TypeSerializer for NullableSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() } } diff --git a/src/serializers/type_serializers/precompiled.rs b/src/serializers/type_serializers/precompiled.rs new file mode 100644 index 000000000..e38f45a49 --- /dev/null +++ b/src/serializers/type_serializers/precompiled.rs @@ -0,0 +1,80 @@ +use std::borrow::Cow; + +use pyo3::types::PyDict; +use pyo3::{intern, prelude::*}; + +use crate::build_tools::py_schema_err; +use crate::definitions::DefinitionsBuilder; +use crate::serializers::shared::TypeSerializer; +use crate::serializers::Extra; +use crate::tools::SchemaDict; +use crate::SchemaSerializer; + +use super::{BuildSerializer, CombinedSerializer}; + +#[derive(Debug, Clone)] +pub struct PrecompiledSerializer { + serializer: Py, +} + +impl BuildSerializer for PrecompiledSerializer { + const EXPECTED_TYPE: &'static str = "precompiled"; + + fn build( + schema: &PyDict, + _config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let sub_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?; + let serializer: PyRef = schema.get_as_req(intern!(py, "serializer"))?; + + // TODO DEBUG THIS LATER + // if !serializer.schema.is(sub_schema) { + // return py_schema_err!("precompiled schema mismatch"); + // } + + Ok(CombinedSerializer::Precompiled(PrecompiledSerializer { + serializer: serializer.into(), + })) + } +} + +impl_py_gc_traverse!(PrecompiledSerializer { serializer }); + +impl TypeSerializer for PrecompiledSerializer { + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + self.serializer + .get() + .serializer + .to_python(value, include, exclude, extra) + } + + fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { + self.serializer.get().serializer.json_key(key, extra) + } + + fn serde_serialize( + &self, + value: &PyAny, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result { + self.serializer + .get() + .serializer + .serde_serialize(value, serializer, include, exclude, extra) + } + + fn get_name(&self) -> &str { + self.serializer.get().serializer.get_name() + } +} diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 70818959e..788620408 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -4,7 +4,7 @@ use pyo3::types::{PyDict, PyList, PyTuple}; use std::borrow::Cow; use crate::build_tools::py_schema_err; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use crate::PydanticSerializationUnexpectedValue; @@ -87,7 +87,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -116,7 +116,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.json_key(key, &new_extra) { @@ -153,7 +153,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -174,10 +174,8 @@ impl TypeSerializer for UnionSerializer { &self.name } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.choices - .iter() - .any(|choice| choice.retry_with_lax_check(definitions)) + fn retry_with_lax_check(&self) -> bool { + self.choices.iter().any(CombinedSerializer::retry_with_lax_check) } } diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index 148c05052..d20c273a1 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use crate::validators::DefaultType; @@ -67,8 +67,8 @@ impl TypeSerializer for WithDefaultSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() } fn get_default(&self, py: Python) -> PyResult> { diff --git a/src/validators/any.rs b/src/validators/any.rs index eddde1725..625eb4adf 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -34,11 +34,7 @@ impl Validator for AnyValidator { Ok(input.to_object(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -46,7 +42,7 @@ impl Validator for AnyValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 2c0fe4a0a..aa0870043 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -15,7 +15,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Parameter { positional: bool, name: String, @@ -24,7 +24,7 @@ struct Parameter { validator: CombinedValidator, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ArgumentsValidator { parameters: Vec, positional_params_count: usize, @@ -332,29 +332,25 @@ impl Validator for ArgumentsValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.parameters .iter() - .any(|p| p.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|p| p.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { self.parameters - .iter_mut() - .try_for_each(|parameter| parameter.validator.complete(definitions))?; - if let Some(v) = &mut self.var_args_validator { - v.complete(definitions)?; + .iter() + .try_for_each(|parameter| parameter.validator.complete())?; + if let Some(v) = &self.var_args_validator { + v.complete()?; } - if let Some(v) = &mut self.var_kwargs_validator { - v.complete(definitions)?; + if let Some(v) = &self.var_kwargs_validator { + v.complete()?; }; Ok(()) } diff --git a/src/validators/bool.rs b/src/validators/bool.rs index d87c1c1d7..3a38cf3e5 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -42,11 +42,7 @@ impl Validator for BoolValidator { Ok(input.validate_bool(strict)?.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -54,7 +50,7 @@ impl Validator for BoolValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 2084f916e..0f662af1c 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -50,11 +50,7 @@ impl Validator for BytesValidator { Ok(either_bytes.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -62,7 +58,7 @@ impl Validator for BytesValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -112,11 +108,7 @@ impl Validator for BytesConstrainedValidator { Ok(either_bytes.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -124,7 +116,7 @@ impl Validator for BytesConstrainedValidator { "constrained-bytes" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/call.rs b/src/validators/call.rs index 24c7f4111..3f6eb4d35 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -11,7 +11,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CallValidator { function: PyObject, arguments_validator: Box, @@ -98,28 +98,23 @@ impl Validator for CallValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if let Some(return_validator) = &self.return_validator { - if return_validator.different_strict_behavior(definitions, ultra_strict) { + if return_validator.different_strict_behavior(ultra_strict) { return true; } } - self.arguments_validator - .different_strict_behavior(definitions, ultra_strict) + self.arguments_validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.arguments_validator.complete(definitions)?; - match &mut self.return_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + self.arguments_validator.complete()?; + match &self.return_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/callable.rs b/src/validators/callable.rs index 9b565e3eb..83eb37cbe 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -36,11 +36,7 @@ impl Validator for CallableValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -48,7 +44,7 @@ impl Validator for CallableValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/chain.rs b/src/validators/chain.rs index 001947d1f..c0f356fa0 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ChainValidator { steps: Vec, name: String, @@ -83,21 +83,15 @@ impl Validator for ChainValidator { steps_iter.try_fold(value, |v, step| step.validate(py, v.into_ref(py), state)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.steps - .iter() - .any(|v| v.different_strict_behavior(definitions, ultra_strict)) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.steps.iter().any(|v| v.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.steps.iter_mut().try_for_each(|v| v.complete(definitions)) + fn complete(&self) -> PyResult<()> { + self.steps.iter().try_for_each(CombinedValidator::complete) } } diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index 1e8258090..0d9931c62 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -57,7 +57,7 @@ impl CustomError { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CustomErrorValidator { validator: Box, custom_error: CustomError, @@ -99,19 +99,15 @@ impl Validator for CustomErrorValidator { .map_err(|_| self.custom_error.as_val_error(input)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 117596b9f..dff7735a3 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -19,7 +19,7 @@ use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Field { kw_only: bool, name: String, @@ -30,7 +30,7 @@ struct Field { frozen: bool, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DataclassArgsValidator { fields: Vec, positional_count: usize, @@ -426,28 +426,22 @@ impl Validator for DataclassArgsValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.fields .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|f| f.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.validator_name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|field| field.validator.complete(definitions)) + fn complete(&self) -> PyResult<()> { + self.fields.iter().try_for_each(|field| field.validator.complete()) } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DataclassValidator { strict: bool, validator: Box, @@ -588,13 +582,9 @@ impl Validator for DataclassValidator { Ok(obj.to_object(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -604,8 +594,8 @@ impl Validator for DataclassValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/date.rs b/src/validators/date.rs index a771a5045..3549f66f0 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -96,11 +96,7 @@ impl Validator for DateValidator { Ok(date.try_into_py(py)?) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -108,7 +104,7 @@ impl Validator for DateValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 7596b7aca..baf4ca467 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -125,11 +125,7 @@ impl Validator for DateTimeValidator { Ok(datetime.try_into_py(py)?) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -137,7 +133,7 @@ impl Validator for DateTimeValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 2564e096a..211befe07 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -230,11 +230,7 @@ impl Validator for DecimalValidator { Ok(decimal.into()) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { true } @@ -242,7 +238,7 @@ impl Validator for DecimalValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 3a35fce4c..16aea8cd4 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -1,7 +1,12 @@ +use std::cell::RefCell; + +use ahash::HashSet; +use ahash::HashSetExt; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; +use crate::definitions::DefinitionRef; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; @@ -39,17 +44,12 @@ impl BuildValidator for DefinitionsValidatorBuilder { #[derive(Debug, Clone)] pub struct DefinitionRefValidator { - validator_id: usize, - inner_name: String, - // we have to record the answers to `Question`s as we can't access the validator when `ask()` is called + definition: DefinitionRef, } impl DefinitionRefValidator { - pub fn new(validator_id: usize) -> Self { - Self { - validator_id, - inner_name: "...".to_string(), - } + pub fn new(definition: DefinitionRef) -> Self { + Self { definition } } } @@ -61,15 +61,10 @@ impl BuildValidator for DefinitionRefValidator { _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let schema_ref: String = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - - let validator_id = definitions.get_reference_id(&schema_ref); + let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - Ok(Self { - validator_id, - inner_name: "...".to_string(), - } - .into()) + let definition = definitions.get_definition(schema_ref); + Ok(Self::new(definition).into()) } } @@ -82,21 +77,22 @@ impl Validator for DefinitionRefValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + let validator = self.definition.get().unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) } else { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } - let output = validate(self.validator_id, py, input, state); - state.recursion_guard.remove(id, self.validator_id); + let output = validator.validate(py, input, state); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } } else { - validate(self.validator_id, py, input, state) + validator.validate(py, input, state) } } @@ -108,69 +104,51 @@ impl Validator for DefinitionRefValidator { field_value: &'data PyAny, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + let validator = self.definition.get().unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) } else { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } - let output = validate_assignment(self.validator_id, py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.validator_id); + let output = validator.validate_assignment(py, obj, field_name, field_value, state); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } } else { - validate_assignment(self.validator_id, py, obj, field_name, field_value, state) + validator.validate_assignment(py, obj, field_name, field_value, state) } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if let Some(definitions) = definitions { - // have to unwrap here, because we can't return an error from this function, should be okay - let validator = definitions.get_definition(self.validator_id).unwrap(); - validator.different_strict_behavior(None, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + thread_local! { + static RECURSION_SET: RefCell>> = RefCell::new(None); + } + + let id = self as *const _ as usize; + // have to unwrap here, because we can't return an error from this function, should be okay + let validator: &CombinedValidator = self.definition.get().unwrap(); + if RECURSION_SET.with( + |set: &RefCell>>| { + set.borrow_mut().get_or_insert_with(HashSet::new).insert(id) + }, + ) { + let different_strict_behavior = validator.different_strict_behavior(ultra_strict); + RECURSION_SET.with(|set| set.borrow_mut().get_or_insert_with(HashSet::new).remove(&id)); + different_strict_behavior } else { false } } fn get_name(&self) -> &str { - &self.inner_name + self.definition.get_or_init_name(|v| v.get_name().into()) } - /// don't need to call complete on the inner validator here, complete_validators takes care of that. - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - let validator = definitions.get_definition(self.validator_id)?; - self.inner_name = validator.get_name().to_string(); + fn complete(&self) -> PyResult<()> { Ok(()) } } - -fn validate<'data>( - validator_id: usize, - py: Python<'data>, - input: &'data impl Input<'data>, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate(py, input, state) -} - -#[allow(clippy::too_many_arguments)] -fn validate_assignment<'data>( - validator_id: usize, - py: Python<'data>, - obj: &'data PyAny, - field_name: &'data str, - field_value: &'data PyAny, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate_assignment(py, obj, field_name, field_value, state) -} diff --git a/src/validators/dict.rs b/src/validators/dict.rs index dc8f03937..250145290 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -16,7 +16,7 @@ use super::any::AnyValidator; use super::list::length_check; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DictValidator { strict: bool, key_validator: Box, @@ -92,14 +92,9 @@ impl Validator for DictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.key_validator.different_strict_behavior(definitions, true) - || self.value_validator.different_strict_behavior(definitions, true) + self.key_validator.different_strict_behavior(true) || self.value_validator.different_strict_behavior(true) } else { true } @@ -109,9 +104,9 @@ impl Validator for DictValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.key_validator.complete(definitions)?; - self.value_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.key_validator.complete()?; + self.value_validator.complete() } } diff --git a/src/validators/float.rs b/src/validators/float.rs index f0eb41750..2e9434d9f 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -76,11 +76,7 @@ impl Validator for FloatValidator { Ok(either_float.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { true } @@ -88,7 +84,7 @@ impl Validator for FloatValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -179,11 +175,7 @@ impl Validator for ConstrainedFloatValidator { Ok(either_float.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { true } @@ -191,7 +183,7 @@ impl Validator for ConstrainedFloatValidator { "constrained-float" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index ad7708324..4b4cdcb6f 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -10,7 +10,7 @@ use super::set::set_build; use super::validation_state::ValidationState; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FrozenSetValidator { strict: bool, item_validator: Box, @@ -48,13 +48,9 @@ impl Validator for FrozenSetValidator { Ok(f_set.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.item_validator.different_strict_behavior(definitions, true) + self.item_validator.different_strict_behavior(true) } else { true } @@ -64,7 +60,7 @@ impl Validator for FrozenSetValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.item_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.item_validator.complete() } } diff --git a/src/validators/function.rs b/src/validators/function.rs index be0d6374f..adb143696 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; @@ -111,14 +113,9 @@ macro_rules! impl_validator { self._validate(validate, py, obj, state) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator - .different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -128,14 +125,14 @@ macro_rules! impl_validator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } }; } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionBeforeValidator { validator: Box, func: PyObject, @@ -168,7 +165,7 @@ impl FunctionBeforeValidator { impl_validator!(FunctionBeforeValidator); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionAfterValidator { validator: Box, func: PyObject, @@ -255,11 +252,7 @@ impl Validator for FunctionPlainValidator { r.map_err(|e| convert_err(py, e, input)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { // best guess, should we change this? !ultra_strict } @@ -268,14 +261,14 @@ impl Validator for FunctionPlainValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionWrapValidator { - validator: Box, + validator: Arc, func: PyObject, config: PyObject, name: String, @@ -299,7 +292,7 @@ impl BuildValidator for FunctionWrapValidator { let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false); let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false); Ok(Self { - validator: Box::new(validator), + validator: Arc::new(validator), func: function_info.function.clone(), config: match config { Some(c) => c.into(), @@ -350,7 +343,7 @@ impl Validator for FunctionWrapValidator { validator: InternalValidator::new( py, "ValidatorCallable", - &self.validator, + self.validator.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -376,7 +369,7 @@ impl Validator for FunctionWrapValidator { validator: InternalValidator::new( py, "AssignmentValidatorCallable", - &self.validator, + self.validator.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -387,13 +380,9 @@ impl Validator for FunctionWrapValidator { self._validate(Py::new(py, handler)?.into_ref(py), py, obj, state) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -403,13 +392,13 @@ impl Validator for FunctionWrapValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct ValidatorCallable { validator: InternalValidator, } @@ -441,7 +430,7 @@ impl ValidatorCallable { } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct AssignmentValidatorCallable { updated_field_name: String, updated_field_value: Py, diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 0cff7e28e..111b5c101 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::sync::Arc; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -14,7 +15,7 @@ use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputT #[derive(Debug, Clone)] pub struct GeneratorValidator { - item_validator: Option>, + item_validator: Option>, min_length: Option, max_length: Option, name: String, @@ -30,7 +31,7 @@ impl BuildValidator for GeneratorValidator { config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let item_validator = get_items_schema(schema, config, definitions)?; + let item_validator = get_items_schema(schema, config, definitions)?.map(Arc::new); let name = match item_validator { Some(ref v) => format!("{}[{}]", Self::EXPECTED_TYPE, v.get_name()), None => format!("{}[any]", Self::EXPECTED_TYPE), @@ -67,7 +68,7 @@ impl Validator for GeneratorValidator { InternalValidator::new( py, "ValidatorIterator", - v, + v.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -85,13 +86,9 @@ impl Validator for GeneratorValidator { Ok(v_iterator.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if let Some(ref v) = self.item_validator { - v.different_strict_behavior(definitions, ultra_strict) + v.different_strict_behavior(ultra_strict) } else { false } @@ -101,16 +98,16 @@ impl Validator for GeneratorValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + match &self.item_validator { + Some(v) => v.complete(), None => Ok(()), } } } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct ValidatorIterator { iterator: GenericIterator, validator: Option, @@ -217,13 +214,11 @@ impl ValidatorIterator { } } -/// Cloneable validator wrapper for use in generators in functions, this can be passed back to python +/// Owned validator wrapper for use in generators in functions, this can be passed back to python /// mid-validation -#[derive(Clone)] pub struct InternalValidator { name: String, - validator: CombinedValidator, - definitions: Vec, + validator: Arc, // TODO, do we need data? data: Option>, strict: Option, @@ -246,7 +241,7 @@ impl InternalValidator { pub fn new( py: Python, name: &str, - validator: &CombinedValidator, + validator: Arc, state: &ValidationState, hide_input_in_errors: bool, validation_error_cause: bool, @@ -254,8 +249,7 @@ impl InternalValidator { let extra = state.extra(); Self { name: name.to_string(), - validator: validator.clone(), - definitions: state.definitions.to_vec(), + validator, data: extra.data.map(|d| d.into_py(py)), strict: extra.strict, from_attributes: extra.from_attributes, @@ -285,7 +279,7 @@ impl InternalValidator { context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; - let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); + let mut state = ValidationState::new(extra, &mut self.recursion_guard); self.validator .validate_assignment(py, model, field_name, field_value, &mut state) .map_err(|e| { @@ -316,7 +310,7 @@ impl InternalValidator { context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; - let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); + let mut state = ValidationState::new(extra, &mut self.recursion_guard); self.validator.validate(py, input, &mut state).map_err(|e| { ValidationError::from_val_error( py, @@ -333,7 +327,6 @@ impl InternalValidator { impl_py_gc_traverse!(InternalValidator { validator, - definitions, data, context, self_instance diff --git a/src/validators/int.rs b/src/validators/int.rs index 3fba2199d..0903e1998 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -54,11 +54,7 @@ impl Validator for IntValidator { Ok(either_int.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -66,7 +62,7 @@ impl Validator for IntValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -151,11 +147,7 @@ impl Validator for ConstrainedIntValidator { Ok(either_int.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -163,7 +155,7 @@ impl Validator for ConstrainedIntValidator { "constrained-int" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index 78705482c..e64d0717c 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -83,11 +83,7 @@ impl Validator for IsInstanceValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -95,7 +91,7 @@ impl Validator for IsInstanceValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index d0f5a6cfe..0866fa1e7 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -62,11 +62,7 @@ impl Validator for IsSubclassValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -74,7 +70,7 @@ impl Validator for IsSubclassValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/json.rs b/src/validators/json.rs index 5eda007be..fd832f874 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -9,7 +9,7 @@ use crate::tools::SchemaDict; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct JsonValidator { validator: Option>, name: String, @@ -61,13 +61,9 @@ impl Validator for JsonValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if let Some(ref v) = self.validator { - v.different_strict_behavior(definitions, ultra_strict) + v.different_strict_behavior(ultra_strict) } else { false } @@ -77,9 +73,9 @@ impl Validator for JsonValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.validator { - Some(ref mut v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + match &self.validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index 828532fe5..cd952bed1 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -11,7 +11,7 @@ use super::InputType; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct JsonOrPython { json: Box, python: Box, @@ -63,21 +63,16 @@ impl Validator for JsonOrPython { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.json.different_strict_behavior(definitions, ultra_strict) - || self.python.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.json.different_strict_behavior(ultra_strict) || self.python.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.json.complete(definitions)?; - self.python.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.json.complete()?; + self.python.complete() } } diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index 9681cf689..b5cec61be 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct LaxOrStrictValidator { strict: bool, lax_validator: Box, @@ -68,13 +68,9 @@ impl Validator for LaxOrStrictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.strict_validator.different_strict_behavior(definitions, true) + self.strict_validator.different_strict_behavior(true) } else { true } @@ -84,8 +80,8 @@ impl Validator for LaxOrStrictValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.lax_validator.complete(definitions)?; - self.strict_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.lax_validator.complete()?; + self.strict_validator.complete() } } diff --git a/src/validators/list.rs b/src/validators/list.rs index ffd7a118e..8e931657f 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -1,3 +1,5 @@ +use std::sync::OnceLock; + use pyo3::prelude::*; use pyo3::types::PyDict; @@ -7,26 +9,26 @@ use crate::tools::SchemaDict; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ListValidator { strict: bool, item_validator: Option>, min_length: Option, max_length: Option, - name: String, + name: OnceLock, } pub fn get_items_schema( schema: &PyDict, config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, -) -> PyResult>> { +) -> PyResult> { match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { Some(d) => { let validator = build_validator(d, config, definitions)?; match validator { CombinedValidator::Any(_) => Ok(None), - _ => Ok(Some(Box::new(validator))), + _ => Ok(Some(validator)), } } None => Ok(None), @@ -98,15 +100,13 @@ impl BuildValidator for ListValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; - let inner_name = item_validator.as_ref().map_or("any", |v| v.get_name()); - let name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); + let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new); Ok(Self { strict: crate::build_tools::is_strict(schema, config)?, item_validator, min_length: schema.get_as(pyo3::intern!(py, "min_length"))?, max_length: schema.get_as(pyo3::intern!(py, "max_length"))?, - name, + name: OnceLock::new(), } .into()) } @@ -138,14 +138,10 @@ impl Validator for ListValidator { Ok(output.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), + Some(ref v) => v.different_strict_behavior(true), None => false, } } else { @@ -154,14 +150,27 @@ impl Validator for ListValidator { } fn get_name(&self) -> &str { - &self.name + // The logic here is a little janky, it's done to try to cache the formatted name + // while also trying to render definitions correctly when possible. + // + // Probably an opportunity for a future refactor + match self.name.get() { + Some(s) => s.as_str(), + None => { + let name = self.item_validator.as_ref().map_or("any", |v| v.get_name()); + if name == "..." { + // when inner name is not initialized yet, don't cache it here + "list[...]" + } else { + self.name.get_or_init(|| format!("list[{name}]")).as_str() + } + } + } } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - if let Some(ref mut v) = self.item_validator { - v.complete(definitions)?; - let inner_name = v.get_name(); - self.name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); + fn complete(&self) -> PyResult<()> { + if let Some(v) = &self.item_validator { + v.complete()?; } Ok(()) } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index de394affb..25cb94bd9 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -22,7 +22,7 @@ struct BoolLiteral { } #[derive(Debug, Clone)] -pub struct LiteralLookup { +pub struct LiteralLookup { // Specialized lookups for ints, bools and strings because they // (1) are easy to convert between Rust and Python // (2) hashing them in Rust is very fast @@ -35,7 +35,7 @@ pub struct LiteralLookup { pub values: Vec, } -impl LiteralLookup { +impl LiteralLookup { pub fn new<'py>(py: Python<'py>, expected: impl Iterator) -> PyResult { let mut expected_int = AHashMap::new(); let mut expected_str: AHashMap = AHashMap::new(); @@ -135,7 +135,7 @@ impl LiteralLookup { } } -impl PyGcTraverse for LiteralLookup { +impl PyGcTraverse for LiteralLookup { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { self.expected_py.py_gc_traverse(visit)?; self.values.py_gc_traverse(visit)?; @@ -198,11 +198,7 @@ impl Validator for LiteralValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -210,7 +206,7 @@ impl Validator for LiteralValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 4ee677663..bb6f16229 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -9,7 +9,7 @@ use pyo3::types::{PyAny, PyDict, PyTuple, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type, SchemaError}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::errors::{LocItem, ValError, ValResult, ValidationError}; use crate::input::{Input, InputType, StringMapping}; use crate::py_gc::PyGcTraverse; @@ -46,6 +46,7 @@ mod model; mod model_fields; mod none; mod nullable; +mod precompiled; mod set; mod string; mod time; @@ -97,11 +98,11 @@ impl PySome { } } -#[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] +#[derive(Debug)] pub struct SchemaValidator { validator: CombinedValidator, - definitions: Vec, + definitions: Definitions, schema: PyObject, #[pyo3(get)] title: PyObject, @@ -115,11 +116,11 @@ impl SchemaValidator { pub fn py_new(py: Python, schema: &PyAny, config: Option<&PyDict>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let mut validator = build_validator(schema, config, &mut definitions_builder)?; - validator.complete(&definitions_builder)?; - let mut definitions = definitions_builder.clone().finish()?; - for val in &mut definitions { - val.complete(&definitions_builder)?; + let validator = build_validator(schema, config, &mut definitions_builder)?; + let definitions = definitions_builder.finish()?; + validator.complete()?; + for val in definitions.values() { + val.get().unwrap().complete()?; } let config_title = match config { Some(c) => c.get_item("title"), @@ -141,9 +142,10 @@ impl SchemaValidator { }) } - pub fn __reduce__(&self, py: Python) -> PyResult { - let args = (self.schema.as_ref(py),); - let cls = Py::new(py, self.clone())?.getattr(py, "__class__")?; + pub fn __reduce__(slf: &PyCell) -> PyResult { + let py = slf.py(); + let args = (slf.try_borrow()?.schema.to_object(py),); + let cls = slf.getattr("__class__")?; Ok((cls, args).into_py(py)) } @@ -266,7 +268,7 @@ impl SchemaValidator { }; let guard = &mut RecursionGuard::default(); - let mut state = ValidationState::new(extra, &self.definitions, guard); + let mut state = ValidationState::new(extra, guard); self.validator .validate_assignment(py, obj, field_name, field_value, &mut state) .map_err(|e| self.prepare_validation_err(py, e, InputType::Python)) @@ -284,7 +286,7 @@ impl SchemaValidator { self_instance: None, }; let recursion_guard = &mut RecursionGuard::default(); - let mut state = ValidationState::new(extra, &self.definitions, recursion_guard); + let mut state = ValidationState::new(extra, recursion_guard); let r = self.validator.default_value(py, None::, &mut state); match r { Ok(maybe_default) => match maybe_default { @@ -307,9 +309,6 @@ impl SchemaValidator { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { self.validator.py_gc_traverse(&visit)?; visit.call(&self.schema)?; - for slot in &self.definitions { - slot.py_gc_traverse(&visit)?; - } Ok(()) } } @@ -332,7 +331,6 @@ impl SchemaValidator { { let mut state = ValidationState::new( Extra::new(strict, from_attributes, context, self_instance, input_type), - &self.definitions, recursion_guard, ); self.validator.validate(py, input, &mut state) @@ -371,7 +369,6 @@ impl<'py> SelfValidator<'py> { let mut recursion_guard = RecursionGuard::default(); let mut state = ValidationState::new( Extra::new(strict, None, None, None, InputType::Python), - &self.validator.definitions, &mut recursion_guard, ); match self.validator.validator.validate(py, schema, &mut state) { @@ -388,14 +385,14 @@ impl<'py> SelfValidator<'py> { let mut definitions_builder = DefinitionsBuilder::new(); - let mut validator = match build_validator(self_schema, None, &mut definitions_builder) { + let validator = match build_validator(self_schema, None, &mut definitions_builder) { Ok(v) => v, Err(err) => return py_schema_err!("Error building self-schema:\n {}", err), }; - validator.complete(&definitions_builder)?; - let mut definitions = definitions_builder.clone().finish()?; - for val in &mut definitions { - val.complete(&definitions_builder)?; + let definitions = definitions_builder.finish()?; + validator.complete()?; + for val in definitions.values() { + val.get().unwrap().complete()?; } Ok(SchemaValidator { validator, @@ -546,6 +543,8 @@ pub fn build_validator<'a>( // recursive (self-referencing) models definitions::DefinitionRefValidator, definitions::DefinitionsValidatorBuilder, + // precompiled models + precompiled::PrecompiledValidator, ) } @@ -603,7 +602,7 @@ impl<'a> Extra<'a> { } } -#[derive(Debug, Clone)] +#[derive(Debug)] #[enum_dispatch(PyGcTraverse)] pub enum CombinedValidator { // typed dict e.g. heterogeneous dicts or simply a model @@ -694,12 +693,14 @@ pub enum CombinedValidator { DefinitionRef(definitions::DefinitionRefValidator), // input dependent JsonOrPython(json_or_python::JsonOrPython), + // reusing a sub-schema + Precompiled(precompiled::PrecompiledValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, /// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait #[enum_dispatch(CombinedValidator)] -pub trait Validator: Send + Sync + Clone + Debug { +pub trait Validator: Send + Sync + Debug { /// Do the actual validation for this schema/type fn validate<'data>( &self, @@ -734,17 +735,13 @@ pub trait Validator: Send + Sync + Clone + Debug { /// whether the validator behaves differently in strict mode, and in ultra strict mode /// implementations should return true if any of their sub-validators return true - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool; + fn different_strict_behavior(&self, ultra_strict: bool) -> bool; /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; /// this method must be implemented for any validator which holds references to other validators, - /// it is used by `DefinitionRefValidator` to set its name - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()>; + /// it is used by `UnionValidator` to calculate strictness + fn complete(&self) -> PyResult<()>; } diff --git a/src/validators/model.rs b/src/validators/model.rs index 2ec7185a9..1459f56f5 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -50,7 +50,7 @@ impl Revalidate { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ModelValidator { revalidate: Revalidate, validator: Box, @@ -206,13 +206,9 @@ impl Validator for ModelValidator { Ok(model.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -222,8 +218,8 @@ impl Validator for ModelValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index f2654c33e..30d937cb8 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -20,7 +20,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuild use std::ops::ControlFlow; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Field { name: String, lookup_key: LookupKey, @@ -31,7 +31,7 @@ struct Field { impl_py_gc_traverse!(Field { validator }); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ModelFieldsValidator { fields: Vec, model_name: String, @@ -415,26 +415,20 @@ impl Validator for ModelFieldsValidator { Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.fields .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|f| f.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + self.fields.iter().try_for_each(|f| f.validator.complete())?; + match &self.extras_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/none.rs b/src/validators/none.rs index 36be70acb..f241be9d8 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -36,11 +36,7 @@ impl Validator for NoneValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -48,7 +44,7 @@ impl Validator for NoneValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 4b408f206..7f4cf19fc 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -9,7 +9,7 @@ use crate::tools::SchemaDict; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct NullableValidator { validator: Box, name: String, @@ -45,19 +45,15 @@ impl Validator for NullableValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/precompiled.rs b/src/validators/precompiled.rs new file mode 100644 index 000000000..17e52fd25 --- /dev/null +++ b/src/validators/precompiled.rs @@ -0,0 +1,65 @@ +use pyo3::types::PyDict; +use pyo3::{intern, prelude::*}; + +use crate::build_tools::py_schema_err; +use crate::definitions::DefinitionsBuilder; +use crate::errors::ValResult; +use crate::input::Input; +use crate::tools::SchemaDict; +use crate::SchemaValidator; + +use super::{BuildValidator, CombinedValidator, ValidationState, Validator}; + +#[derive(Debug)] +pub struct PrecompiledValidator { + validator: Py, +} + +impl BuildValidator for PrecompiledValidator { + const EXPECTED_TYPE: &'static str = "precompiled"; + + fn build( + schema: &PyDict, + _config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let sub_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?; + let validator: PyRef = schema.get_as_req(intern!(py, "validator"))?; + + // TODO DEBUG THIS LATER + // if !validator.schema.is(sub_schema) { + // return py_schema_err!("precompiled schema mismatch"); + // } + + Ok(CombinedValidator::Precompiled(PrecompiledValidator { + validator: validator.into(), + })) + } +} + +impl_py_gc_traverse!(PrecompiledValidator { validator }); + +impl Validator for PrecompiledValidator { + fn validate<'data>( + &self, + py: Python<'data>, + input: &'data impl Input<'data>, + state: &mut ValidationState, + ) -> ValResult<'data, PyObject> { + self.validator.get().validator.validate(py, input, state) + } + + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.validator.get().validator.different_strict_behavior(ultra_strict) + } + + fn get_name(&self) -> &str { + self.validator.get().validator.get_name() + } + + fn complete(&self) -> PyResult<()> { + // No need to complete a precompiled validator + Ok(()) + } +} diff --git a/src/validators/set.rs b/src/validators/set.rs index e5e2cecf3..626572139 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -8,7 +8,7 @@ use crate::tools::SchemaDict; use super::list::min_length_check; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SetValidator { strict: bool, item_validator: Box, @@ -70,13 +70,9 @@ impl Validator for SetValidator { Ok(set.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.item_validator.different_strict_behavior(definitions, true) + self.item_validator.different_strict_behavior(true) } else { true } @@ -86,7 +82,7 @@ impl Validator for SetValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.item_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.item_validator.complete() } } diff --git a/src/validators/string.rs b/src/validators/string.rs index 6b646224d..4eab4602a 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct StrValidator { strict: bool, coerce_numbers_to_str: bool, @@ -51,11 +51,7 @@ impl Validator for StrValidator { Ok(either_str.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -63,7 +59,7 @@ impl Validator for StrValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -150,11 +146,7 @@ impl Validator for StrConstrainedValidator { Ok(py_string.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -162,7 +154,7 @@ impl Validator for StrConstrainedValidator { "constrained-str" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/time.rs b/src/validators/time.rs index 7bbd7e511..f5e2be7c7 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -78,11 +78,7 @@ impl Validator for TimeValidator { Ok(time.try_into_py(py)?) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -90,7 +86,7 @@ impl Validator for TimeValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index 106d5a64a..21340f2f0 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -101,11 +101,7 @@ impl Validator for TimeDeltaValidator { Ok(py_timedelta.into()) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -113,7 +109,7 @@ impl Validator for TimeDeltaValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 5c2c09bec..cfa239e3b 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::list::{get_items_schema, min_length_check}; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TupleVariableValidator { strict: bool, item_validator: Option>, @@ -27,7 +27,7 @@ impl BuildValidator for TupleVariableValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; + let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new); let inner_name = item_validator.as_ref().map_or("any", |v| v.get_name()); let name = format!("tuple[{inner_name}, ...]"); Ok(Self { @@ -60,14 +60,10 @@ impl Validator for TupleVariableValidator { Ok(PyTuple::new(py, &output).into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), + Some(ref v) => v.different_strict_behavior(true), None => false, } } else { @@ -79,15 +75,15 @@ impl Validator for TupleVariableValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + match &self.item_validator { + Some(v) => v.complete(), None => Ok(()), } } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TuplePositionalValidator { strict: bool, items_validators: Vec, @@ -242,20 +238,12 @@ impl Validator for TuplePositionalValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - if self - .items_validators - .iter() - .any(|v| v.different_strict_behavior(definitions, true)) - { + if self.items_validators.iter().any(|v| v.different_strict_behavior(true)) { true } else if let Some(ref v) = self.extras_validator { - v.different_strict_behavior(definitions, true) + v.different_strict_behavior(true) } else { false } @@ -268,12 +256,10 @@ impl Validator for TuplePositionalValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.items_validators - .iter_mut() - .try_for_each(|v| v.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + self.items_validators.iter().try_for_each(CombinedValidator::complete)?; + match &self.extras_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 56e4a8225..1a5b52dc6 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -20,7 +20,7 @@ use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -#[derive(Debug, Clone)] +#[derive(Debug)] struct TypedDictField { name: String, lookup_key: LookupKey, @@ -31,7 +31,7 @@ struct TypedDictField { impl_py_gc_traverse!(TypedDictField { validator }); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TypedDictValidator { fields: Vec, extra_behavior: ExtraBehavior, @@ -307,26 +307,20 @@ impl Validator for TypedDictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.fields .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|f| f.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + self.fields.iter().try_for_each(|f| f.validator.complete())?; + match &self.extras_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/union.rs b/src/validators/union.rs index 4d3b0bd78..79a21e78d 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -1,5 +1,6 @@ use std::fmt::Write; use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString, PyTuple}; @@ -18,11 +19,11 @@ use super::custom_error::CustomError; use super::literal::LiteralLookup; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone, Copy)] +#[derive(Debug)] enum UnionMode { Smart { - strict_required: bool, - ultra_strict_required: bool, + strict_required: AtomicBool, + ultra_strict_required: AtomicBool, }, LeftToRight, } @@ -31,8 +32,23 @@ impl UnionMode { // construct smart with some default values const fn default_smart() -> Self { Self::Smart { - strict_required: true, - ultra_strict_required: false, + strict_required: AtomicBool::new(true), + ultra_strict_required: AtomicBool::new(false), + } + } +} + +impl Clone for UnionMode { + fn clone(&self) -> Self { + match self { + Self::Smart { + strict_required, + ultra_strict_required, + } => Self::Smart { + strict_required: AtomicBool::new(strict_required.load(Ordering::SeqCst)), + ultra_strict_required: AtomicBool::new(ultra_strict_required.load(Ordering::SeqCst)), + }, + Self::LeftToRight => Self::LeftToRight, } } } @@ -49,7 +65,7 @@ impl FromStr for UnionMode { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct UnionValidator { mode: UnionMode, choices: Vec<(CombinedValidator, Option)>, @@ -216,44 +232,46 @@ impl Validator for UnionValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - match self.mode { + match &self.mode { UnionMode::Smart { strict_required, ultra_strict_required, - } => self.validate_smart(py, input, state, strict_required, ultra_strict_required), + } => self.validate_smart( + py, + input, + state, + strict_required.load(Ordering::SeqCst), + ultra_strict_required.load(Ordering::SeqCst), + ), UnionMode::LeftToRight => self.validate_left_to_right(py, input, state), } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.choices .iter() - .any(|(v, _)| v.different_strict_behavior(definitions, ultra_strict)) + .any(|(v, _)| v.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.choices.iter_mut().try_for_each(|(v, _)| v.complete(definitions))?; + fn complete(&self) -> PyResult<()> { + self.choices.iter().try_for_each(|(v, _)| v.complete())?; if let UnionMode::Smart { - ref mut strict_required, - ref mut ultra_strict_required, - } = self.mode + strict_required, + ultra_strict_required, + } = &self.mode { - *strict_required = self - .choices - .iter() - .any(|(v, _)| v.different_strict_behavior(Some(definitions), false)); - *ultra_strict_required = self - .choices - .iter() - .any(|(v, _)| v.different_strict_behavior(Some(definitions), true)); + strict_required.store( + self.choices.iter().any(|(v, _)| v.different_strict_behavior(false)), + Ordering::SeqCst, + ); + ultra_strict_required.store( + self.choices.iter().any(|(v, _)| v.different_strict_behavior(true)), + Ordering::SeqCst, + ); } Ok(()) @@ -357,7 +375,7 @@ impl PyGcTraverse for Discriminator { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TaggedUnionValidator { discriminator: Discriminator, lookup: LiteralLookup, @@ -476,26 +494,19 @@ impl Validator for TaggedUnionValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.lookup .values .iter() - .any(|v| v.different_strict_behavior(definitions, ultra_strict)) + .any(|v| v.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.lookup - .values - .iter_mut() - .try_for_each(|validator| validator.complete(definitions)) + fn complete(&self) -> PyResult<()> { + self.lookup.values.iter().try_for_each(CombinedValidator::complete) } } diff --git a/src/validators/url.rs b/src/validators/url.rs index 0afc76e59..4584ae652 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -92,11 +92,7 @@ impl Validator for UrlValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -104,7 +100,7 @@ impl Validator for UrlValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -232,11 +228,7 @@ impl Validator for MultiHostUrlValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -244,7 +236,7 @@ impl Validator for MultiHostUrlValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index ca924ce66..94d302438 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -122,11 +122,7 @@ impl Validator for UuidValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -134,7 +130,7 @@ impl Validator for UuidValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 6cf5ce313..79ec8b87a 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -1,25 +1,16 @@ -use crate::{definitions::Definitions, recursion_guard::RecursionGuard}; +use crate::recursion_guard::RecursionGuard; -use super::{CombinedValidator, Extra}; +use super::Extra; pub struct ValidationState<'a> { pub recursion_guard: &'a mut RecursionGuard, - pub definitions: &'a Definitions, // deliberately make Extra readonly extra: Extra<'a>, } impl<'a> ValidationState<'a> { - pub fn new( - extra: Extra<'a>, - definitions: &'a Definitions, - recursion_guard: &'a mut RecursionGuard, - ) -> Self { - Self { - recursion_guard, - definitions, - extra, - } + pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionGuard) -> Self { + Self { recursion_guard, extra } } pub fn with_new_extra<'r, R: 'r>( @@ -31,7 +22,6 @@ impl<'a> ValidationState<'a> { // but lifetimes get in a tangle. Maybe someone brave wants to have a go at unpicking lifetimes. let mut new_state = ValidationState { recursion_guard: self.recursion_guard, - definitions: self.definitions, extra, }; f(&mut new_state) diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index d68590766..36b275dd1 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -66,7 +66,7 @@ enum OnError { Default, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct WithDefaultValidator { default: DefaultType, on_error: OnError, @@ -182,20 +182,16 @@ impl Validator for WithDefaultValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index d4b53cbe4..b8dcd9fe3 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -39,6 +39,10 @@ def args(*args, **kwargs): return args, kwargs +INT_SCHEMA = core_schema.int_schema() +INT_VALIDATOR = SchemaValidator(INT_SCHEMA) +INT_SERIALIZER = SchemaSerializer(INT_SCHEMA) + all_schema_functions = [ (core_schema.any_schema, args(), {'type': 'any'}), (core_schema.any_schema, args(metadata=['foot', 'spa']), {'type': 'any', 'metadata': ['foot', 'spa']}), @@ -289,6 +293,11 @@ def args(*args, **kwargs): (core_schema.uuid_schema, args(), {'type': 'uuid'}), (core_schema.decimal_schema, args(), {'type': 'decimal'}), (core_schema.decimal_schema, args(multiple_of=5, gt=1.2), {'type': 'decimal', 'multiple_of': 5, 'gt': 1.2}), + ( + core_schema.precompiled_schema, + args(schema=INT_SCHEMA, validator=INT_VALIDATOR, serializer=INT_SERIALIZER), + {'type': 'precompiled', 'schema': INT_SCHEMA, 'validator': INT_VALIDATOR, 'serializer': INT_SERIALIZER}, + ), ] diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index b836eb7a1..2d676ac17 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -1,3 +1,4 @@ +import datetime import platform from dataclasses import dataclass from typing import List, Optional @@ -243,7 +244,7 @@ class Branch: def test_invalid_schema(): - with pytest.raises(SchemaError, match='Definitions error: attempted to use `Branch` before it was filled'): + with pytest.raises(SchemaError, match='Definitions error: definition `Branch` was never filled'): SchemaValidator( { 'type': 'list', @@ -895,3 +896,192 @@ class Model: 'url': f'https://errors.pydantic.dev/{pydantic_version}/v/dataclass_type', } ] + + +def test_cyclic_data() -> None: + cyclic_data = {} + cyclic_data['b'] = {'a': cyclic_data} + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('a'), + [ + core_schema.typed_dict_schema( + { + 'b': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('b')) + ) + }, + ref='a', + ), + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('a')) + ) + }, + ref='b', + ), + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python(cyclic_data) + + assert exc_info.value.title == 'typed-dict' + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'recursion_loop', + 'loc': ('b', 'a'), + 'msg': 'Recursion error - cyclic reference detected', + 'input': cyclic_data, + } + ] + + +def test_cyclic_data_threeway() -> None: + cyclic_data = {} + cyclic_data['b'] = {'c': {'a': cyclic_data}} + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('a'), + [ + core_schema.typed_dict_schema( + { + 'b': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('b')) + ) + }, + ref='a', + ), + core_schema.typed_dict_schema( + { + 'c': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('c')) + ) + }, + ref='b', + ), + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('a')) + ) + }, + ref='c', + ), + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python(cyclic_data) + + assert exc_info.value.title == 'typed-dict' + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'recursion_loop', + 'loc': ('b', 'c', 'a'), + 'msg': 'Recursion error - cyclic reference detected', + 'input': cyclic_data, + } + ] + + +def test_complex_recursive_type() -> None: + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('JsonType'), + [ + core_schema.nullable_schema( + core_schema.union_schema( + [ + core_schema.list_schema(core_schema.definition_reference_schema('JsonType')), + core_schema.dict_schema( + core_schema.str_schema(), core_schema.definition_reference_schema('JsonType') + ), + core_schema.str_schema(), + core_schema.int_schema(), + core_schema.float_schema(), + core_schema.bool_schema(), + ] + ), + ref='JsonType', + ) + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python({'a': datetime.date(year=1992, month=12, day=11)}) + + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'list_type', + 'loc': ('list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]',), + 'msg': 'Input should be a valid list', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'list_type', + 'loc': ('dict[str,...]', 'a', 'list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]'), + 'msg': 'Input should be a valid list', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'dict_type', + 'loc': ('dict[str,...]', 'a', 'dict[str,...]'), + 'msg': 'Input should be a valid dictionary', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'string_type', + 'loc': ('dict[str,...]', 'a', 'str'), + 'msg': 'Input should be a valid string', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'int_type', + 'loc': ('dict[str,...]', 'a', 'int'), + 'msg': 'Input should be a valid integer', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'float_type', + 'loc': ('dict[str,...]', 'a', 'float'), + 'msg': 'Input should be a valid number', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'bool_type', + 'loc': ('dict[str,...]', 'a', 'bool'), + 'msg': 'Input should be a valid boolean', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'string_type', + 'loc': ('str',), + 'msg': 'Input should be a valid string', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'int_type', + 'loc': ('int',), + 'msg': 'Input should be a valid integer', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'float_type', + 'loc': ('float',), + 'msg': 'Input should be a valid number', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'bool_type', + 'loc': ('bool',), + 'msg': 'Input should be a valid boolean', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + ]