diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 8635b63a5..7b9404ce4 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -1437,6 +1437,30 @@ def uuid_schema( ) +class NestedSchema(TypedDict, total=False): + type: Required[Literal['nested']] + cls: Required[Type[Any]] + # Should return `(CoreSchema, SchemaValidator, SchemaSerializer)` but this requires a forward ref + get_info: Required[Callable[[], Any]] + metadata: Dict[str, Any] + serialization: SerSchema + +def nested_schema( + *, + cls: Type[Any], + get_info: Callable[[], Any], + metadata: Dict[str, Any] | None = None, + serialization: SerSchema | None = None +) -> NestedSchema: + return _dict_not_none( + type='nested', + cls=cls, + get_info=get_info, + metadata=metadata, + serialization=serialization + ) + + class IncExSeqSerSchema(TypedDict, total=False): type: Required[Literal['include-exclude-sequence']] include: Set[int] @@ -3866,6 +3890,7 @@ def definition_reference_schema( DefinitionReferenceSchema, UuidSchema, ComplexSchema, + NestedSchema, ] elif False: CoreSchema: TypeAlias = Mapping[str, Any] @@ -3922,6 +3947,7 @@ def definition_reference_schema( 'definition-ref', 'uuid', 'complex', + 'nested', ] CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] diff --git a/src/py_gc.rs b/src/py_gc.rs index 8af285afb..ebe4bb18f 100644 --- a/src/py_gc.rs +++ b/src/py_gc.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use ahash::AHashMap; use enum_dispatch::enum_dispatch; @@ -58,6 +58,25 @@ impl PyGcTraverse for Option { } } +impl PyGcTraverse for Result { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + match self { + Ok(v) => T::py_gc_traverse(v, visit), + // FIXME(BoxyUwU): Lol + Err(_) => Ok(()), + } + } +} + +impl PyGcTraverse for OnceLock { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + match self.get() { + Some(item) => T::py_gc_traverse(item, visit), + None => Ok(()), + } + } +} + /// A crude alternative to a "derive" macro to help with building PyGcTraverse implementations macro_rules! impl_py_gc_traverse { ($name:ty { }) => { diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 8eb54c837..37a1f26b4 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -143,6 +143,7 @@ combined_serializer! { Recursive: super::type_serializers::definitions::DefinitionRefSerializer; Tuple: super::type_serializers::tuple::TupleSerializer; Complex: super::type_serializers::complex::ComplexSerializer; + Nested: super::type_serializers::nested::NestedSerializer; } } @@ -254,6 +255,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Nested(inner) => inner.py_gc_traverse(visit), } } } diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index dabd006a3..e5235f53e 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -16,6 +16,7 @@ pub mod json_or_python; pub mod list; pub mod literal; pub mod model; +pub mod nested; pub mod nullable; pub mod other; pub mod set_frozenset; diff --git a/src/serializers/type_serializers/nested.rs b/src/serializers/type_serializers/nested.rs new file mode 100644 index 000000000..dc0b64fa1 --- /dev/null +++ b/src/serializers/type_serializers/nested.rs @@ -0,0 +1,135 @@ +use std::{borrow::Cow, sync::OnceLock}; + +use pyo3::{ + intern, + types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType}, + Bound, Py, PyAny, PyObject, PyResult, Python, +}; + +use crate::{ + definitions::DefinitionsBuilder, + serializers::{ + shared::{BuildSerializer, TypeSerializer}, + CombinedSerializer, Extra, + }, + SchemaSerializer, +}; + +#[derive(Debug)] +pub struct NestedSerializer { + model: Py, + name: String, + get_serializer: Py, + serializer: OnceLock>>, +} + +impl_py_gc_traverse!(NestedSerializer { + model, + get_serializer, + serializer +}); + +impl BuildSerializer for NestedSerializer { + const EXPECTED_TYPE: &'static str = "nested"; + + fn build( + schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + + let get_serializer = schema + .get_item(intern!(py, "get_info"))? + .expect("Invalid core schema for `nested` type, no `get_info`") + .unbind(); + + let model = schema + .get_item(intern!(py, "cls"))? + .expect("Invalid core schema for `nested` type, no `model`") + .downcast::() + .expect("Invalid core schema for `nested` type, not a `PyType`") + .clone(); + + let name = model.getattr(intern!(py, "__name__"))?.extract()?; + + Ok(CombinedSerializer::Nested(NestedSerializer { + model: model.clone().unbind(), + name, + get_serializer, + serializer: OnceLock::new(), + })) + } +} + +impl NestedSerializer { + fn nested_serializer<'py>(&self, py: Python<'py>) -> PyResult<&Py> { + self.serializer + .get_or_init(|| { + Ok(self + .get_serializer + .bind(py) + .call((), None)? + .downcast::()? + .get_item(2)? + .downcast::()? + .clone() + .unbind()) + }) + .as_ref() + .map_err(|e| e.clone_ref(py)) + } +} + +impl TypeSerializer for NestedSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + mut extra: &Extra, + ) -> PyResult { + let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?; + + self.nested_serializer(value.py())? + .bind(value.py()) + .get() + .serializer + .to_python(value, include, exclude, guard.state()) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + self.nested_serializer(key.py())? + .bind(key.py()) + .get() + .serializer + .json_key(key, extra) + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + mut extra: &Extra, + ) -> Result { + use super::py_err_se_err; + + let mut guard = extra + .recursion_guard(value, self.model.as_ptr() as usize) + .map_err(py_err_se_err)?; + + self.nested_serializer(value.py()) + // FIXME(BoxyUwU): Don't unwrap this + .unwrap() + .bind(value.py()) + .get() + .serializer + .serde_serialize(value, serializer, include, exclude, guard.state()) + } + + fn get_name(&self) -> &str { + &self.name + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 5f88d5dc8..3b75776a2 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -49,6 +49,7 @@ mod list; mod literal; mod model; mod model_fields; +mod nested; mod none; mod nullable; mod set; @@ -584,6 +585,7 @@ pub fn build_validator( definitions::DefinitionRefValidator, definitions::DefinitionsValidatorBuilder, complex::ComplexValidator, + nested::NestedValidator, ) } @@ -738,6 +740,8 @@ pub enum CombinedValidator { // input dependent JsonOrPython(json_or_python::JsonOrPython), Complex(complex::ComplexValidator), + // Schema for reusing an existing validator + Nested(nested::NestedValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/src/validators/model.rs b/src/validators/model.rs index 2c0cef6fd..dad29e3e4 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator { let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; let sub_schema = schema.get_as_req(intern!(py, "schema"))?; - let validator = build_validator(&sub_schema, config.as_ref(), definitions)?; + let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?; let name = class.getattr(intern!(py, "__name__"))?.extract()?; Ok(Self { diff --git a/src/validators/nested.rs b/src/validators/nested.rs new file mode 100644 index 000000000..fb9f88059 --- /dev/null +++ b/src/validators/nested.rs @@ -0,0 +1,115 @@ +use std::sync::OnceLock; + +use pyo3::{ + intern, + types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyTupleMethods, PyType}, + Bound, Py, PyAny, PyObject, PyResult, Python, +}; + +use crate::{ + definitions::DefinitionsBuilder, + errors::{ErrorTypeDefaults, ValError, ValResult}, + input::Input, + recursion_guard::RecursionGuard, +}; + +use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator}; + +#[derive(Debug)] +pub struct NestedValidator { + cls: Py, + name: String, + get_validator: Py, + validator: OnceLock>>, +} + +impl_py_gc_traverse!(NestedValidator { + cls, + get_validator, + validator +}); + +impl BuildValidator for NestedValidator { + const EXPECTED_TYPE: &'static str = "nested"; + + fn build( + schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + + let get_validator = schema.get_item(intern!(py, "get_info"))?.unwrap().unbind(); + + let cls = schema + .get_item(intern!(py, "cls"))? + .unwrap() + .downcast::()? + .clone(); + + let name = cls.getattr(intern!(py, "__name__"))?.extract()?; + + Ok(CombinedValidator::Nested(NestedValidator { + cls: cls.clone().unbind(), + name, + get_validator: get_validator, + validator: OnceLock::new(), + })) + } +} + +impl NestedValidator { + fn nested_validator<'py>(&self, py: Python<'py>) -> PyResult<&Py> { + self.validator + .get_or_init(|| { + Ok(self + .get_validator + .bind(py) + .call((), None)? + .downcast::()? + .get_item(1)? + .downcast::()? + .clone() + .unbind()) + }) + .as_ref() + .map_err(|e| e.clone_ref(py)) + } +} + +impl Validator for NestedValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + let Some(id) = input.as_python().map(py_identity) else { + return self + .nested_validator(py)? + .bind(py) + .get() + .validator + .validate(py, input, state); + }; + + // Python objects can be cyclic, so need recursion guard + let Ok(mut guard) = RecursionGuard::new(state, id, self.cls.as_ptr() as usize) else { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + }; + + self.nested_validator(py)? + .bind(py) + .get() + .validator + .validate(py, input, guard.state()) + } + + fn get_name(&self) -> &str { + &self.name + } +} + +fn py_identity(obj: &Bound<'_, PyAny>) -> usize { + obj.as_ptr() as usize +}