Skip to content

Commit

Permalink
a
Browse files Browse the repository at this point in the history
  • Loading branch information
BoxyUwU committed Aug 27, 2024
1 parent b65d178 commit 26e3ec4
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 49 deletions.
8 changes: 7 additions & 1 deletion python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,18 +1440,24 @@ def uuid_schema(
class NestedModelSchema(TypedDict, total=False):
type: Required[Literal['nested-model']]
model: Required[Type[Any]]
# Should return `(CoreSchema, SchemaValidator, SchemaSerializer)` but this requires a forward ref
get_info: Required[Callable[[], Any]]
metadata: Any

serialization: SerSchema

def nested_model_schema(
*,
model: Type[Any],
get_info: Callable[[], Any],
metadata: Any = None,
serialization: SerSchema | None = None
) -> NestedModelSchema:
return _dict_not_none(
type='nested-model',
model=model,
get_info=get_info,
metadata=metadata,
serialization=serialization
)


Expand Down
11 changes: 10 additions & 1 deletion src/py_gc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
Expand Down Expand Up @@ -58,6 +58,15 @@ impl<T: PyGcTraverse> PyGcTraverse for Option<T> {
}
}

impl<T: PyGcTraverse> PyGcTraverse for OnceLock<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self.get() {
Some(item) => T::py_gc_traverse(item, visit),
None => Ok(()),
}
}
}

/// A crude alternative to a "derive" macro to help with building PyGcTraverse implementations
macro_rules! impl_py_gc_traverse {
($name:ty { }) => {
Expand Down
77 changes: 52 additions & 25 deletions src/serializers/type_serializers/nested_model.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::borrow::Cow;
use std::{borrow::Cow, sync::OnceLock};

use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python,
};

Expand All @@ -19,9 +19,15 @@ use crate::{
pub struct NestedModelSerializer {
model: Py<PyType>,
name: String,
get_serializer: Py<PyAny>,
serializer: OnceLock<Py<SchemaSerializer>>,
}

impl_py_gc_traverse!(NestedModelSerializer { model });
impl_py_gc_traverse!(NestedModelSerializer {
model,
get_serializer,
serializer
});

impl BuildSerializer for NestedModelSerializer {
const EXPECTED_TYPE: &'static str = "nested-model";
Expand All @@ -32,6 +38,12 @@ impl BuildSerializer for NestedModelSerializer {
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();

let get_serializer = schema
.get_item(intern!(py, "get_info"))?
.expect("Invalid core schema for `nested-model` type")
.unbind();

let model = schema
.get_item(intern!(py, "model"))?
.expect("Invalid core schema for `nested-model` type")
Expand All @@ -44,29 +56,30 @@ impl BuildSerializer for NestedModelSerializer {
Ok(CombinedSerializer::NestedModel(NestedModelSerializer {
model: model.clone().unbind(),
name,
get_serializer,
serializer: OnceLock::new(),
}))
}
}

impl NestedModelSerializer {
fn nested_serializer<'py>(&self, py: Python<'py>) -> Bound<'py, SchemaSerializer> {
self.model
.bind(py)
.call_method(intern!(py, "model_rebuild"), (), None)
.unwrap();

self.model
.getattr(py, intern!(py, "__pydantic_serializer__"))
.unwrap()
.downcast_bound::<SchemaSerializer>(py)
.unwrap()
fn nested_serializer<'py>(&self, py: Python<'py>) -> Py<SchemaSerializer> {
self.serializer
.get_or_init(|| {
self.get_serializer
.bind(py)
.call((), None)
.expect("Invalid core schema for `nested-model`")
.downcast::<PyTuple>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.get_item(2)
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.downcast::<SchemaSerializer>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.clone()
.unbind()
})
.clone()

// crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
// .downcast_bound::<SchemaSerializer>(py)
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
// .expect("Cached validator was not a `SchemaSerializer`")
// .clone()
}
}

Expand All @@ -76,16 +89,23 @@ impl TypeSerializer for NestedModelSerializer {
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
mut extra: &Extra,
) -> PyResult<PyObject> {
let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?;

self.nested_serializer(value.py())
.bind(value.py())
.get()
.serializer
.to_python(value, include, exclude, extra)
.to_python(value, include, exclude, guard.state())
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
self.nested_serializer(key.py()).get().serializer.json_key(key, extra)
self.nested_serializer(key.py())
.bind(key.py())
.get()
.serializer
.json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -94,12 +114,19 @@ impl TypeSerializer for NestedModelSerializer {
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
use super::py_err_se_err;

let mut guard = extra
.recursion_guard(value, self.model.as_ptr() as usize)
.map_err(py_err_se_err)?;

self.nested_serializer(value.py())
.bind(value.py())
.get()
.serializer
.serde_serialize(value, serializer, include, exclude, extra)
.serde_serialize(value, serializer, include, exclude, guard.state())
}

fn get_name(&self) -> &str {
Expand Down
84 changes: 62 additions & 22 deletions src/validators/nested_model.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
use std::sync::OnceLock;

use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
Bound, Py, PyObject, PyResult, Python,
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyTupleMethods, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python,
};

use crate::{definitions::DefinitionsBuilder, errors::ValResult, input::Input};
use crate::{
definitions::DefinitionsBuilder,
errors::{ErrorTypeDefaults, ValError, ValResult},
input::Input,
recursion_guard::RecursionGuard,
};

use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct NestedModelValidator {
model: Py<PyType>,
name: String,
get_validator: Py<PyAny>,
validator: OnceLock<Py<SchemaValidator>>,
}

impl_py_gc_traverse!(NestedModelValidator { model });
impl_py_gc_traverse!(NestedModelValidator {
model,
get_validator,
validator
});

impl BuildValidator for NestedModelValidator {
const EXPECTED_TYPE: &'static str = "nested-model";
Expand All @@ -25,6 +38,12 @@ impl BuildValidator for NestedModelValidator {
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
) -> PyResult<super::CombinedValidator> {
let py = schema.py();

let get_validator = schema
.get_item(intern!(py, "get_info"))?
.expect("Invalid core schema for `nested-model` type")
.unbind();

let model = schema
.get_item(intern!(py, "model"))?
.expect("Invalid core schema for `nested-model` type")
Expand All @@ -37,40 +56,61 @@ impl BuildValidator for NestedModelValidator {
Ok(CombinedValidator::NestedModel(NestedModelValidator {
model: model.clone().unbind(),
name,
get_validator: get_validator,
validator: OnceLock::new(),
}))
}
}

impl NestedModelValidator {
fn nested_validator<'py>(&self, py: Python<'py>) -> Py<SchemaValidator> {
self.validator
.get_or_init(|| {
self.get_validator
.bind(py)
.call((), None)
.expect("Invalid core schema for `nested-model`")
.downcast::<PyTuple>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.get_item(1)
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.downcast::<SchemaValidator>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.clone()
.unbind()
})
.clone()
}
}

impl Validator for NestedModelValidator {
fn validate<'py>(
&self,
py: Python<'py>,
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
self.model
.bind(py)
.call_method(intern!(py, "model_rebuild"), (), None)
.unwrap();

let validator = self
.model
.getattr(py, intern!(py, "__pydantic_validator__"))
.unwrap()
.downcast_bound::<SchemaValidator>(py)
.unwrap()
.clone();
let Some(id) = input.as_python().map(py_identity) else {
panic!("")
};

// let validator = crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
// .downcast_bound::<SchemaValidator>(py)
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
// .expect("Cached validator was not a `SchemaValidator`")
// .clone();
// Python objects can be cyclic, so need recursion guard
let Ok(mut guard) = RecursionGuard::new(state, id, self.model.as_ptr() as usize) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
};

validator.get().validator.validate(py, input, state)
self.nested_validator(py)
.bind(py)
.get()
.validator
.validate(py, input, guard.state())
}

fn get_name(&self) -> &str {
&self.name
}
}

fn py_identity(obj: &Bound<'_, PyAny>) -> usize {
obj.as_ptr() as usize
}

0 comments on commit 26e3ec4

Please sign in to comment.