diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 8fc1a4230..ed524e09b 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyTuple, PyType}; @@ -24,6 +25,7 @@ mod fields; mod filter; mod infer; mod ob_type; +mod prebuilt; pub mod ser; mod shared; mod type_serializers; @@ -37,7 +39,7 @@ pub enum WarningsArg { #[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaSerializer { - serializer: CombinedSerializer, + serializer: Arc, definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, @@ -92,7 +94,7 @@ impl SchemaSerializer { let mut definitions_builder = DefinitionsBuilder::new(); let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { - serializer, + serializer: Arc::new(serializer), definitions: definitions_builder.finish()?, expected_json_size: AtomicUsize::new(1024), config: SerializationConfig::from_config(config)?, diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs new file mode 100644 index 000000000..5b81985e0 --- /dev/null +++ b/src/serializers/prebuilt.rs @@ -0,0 +1,93 @@ +use std::borrow::Cow; +use std::sync::Arc; + +use pyo3::exceptions::PyValueError; +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::{PyBool, PyDict, PyType}; + +use crate::definitions::DefinitionsBuilder; +use crate::tools::SchemaDict; +use crate::SchemaSerializer; + +use super::extra::Extra; +use super::shared::{BuildSerializer, CombinedSerializer, TypeSerializer}; + +#[derive(Debug)] +pub struct PrebuiltSerializer { + serializer: Arc, +} + +impl BuildSerializer for PrebuiltSerializer { + const EXPECTED_TYPE: &'static str = "prebuilt"; + + fn build( + schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; + + // note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) + // because we don't want to fetch prebuilt serializers from parent classes + let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?; + + // Ensure the class has completed its Pydantic setup + let is_complete: bool = class_dict + .get_as_req::>(intern!(py, "__pydantic_complete__")) + .is_ok_and(|b| b.extract().unwrap_or(false)); + + if !is_complete { + return Err(PyValueError::new_err("Prebuilt serializer not found.")); + } + + // Retrieve the prebuilt validator if available + let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_serializer__"))?; + let schema_serializer: PyRef = prebuilt_serializer.extract()?; + let combined_serializer: Arc = schema_serializer.serializer.clone(); + + Ok(Self { + serializer: combined_serializer, + } + .into()) + } +} + +impl_py_gc_traverse!(PrebuiltSerializer { serializer }); + +impl TypeSerializer for PrebuiltSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + self.serializer.to_python(value, include, exclude, extra) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + self.serializer.json_key(key, extra) + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + self.serializer + .serde_serialize(value, serializer, include, exclude, extra) + } + + fn get_name(&self) -> &str { + self.serializer.get_name() + } + + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() + } +} diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index f37810657..741ce19c2 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -84,6 +84,8 @@ combined_serializer! { Function: super::type_serializers::function::FunctionPlainSerializer; FunctionWrap: super::type_serializers::function::FunctionWrapSerializer; Fields: super::fields::GeneralFieldsSerializer; + // prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum + Prebuilt: super::prebuilt::PrebuiltSerializer; } // `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer` // but aren't actually used for serialization, e.g. their `build` method must return another serializer @@ -195,7 +197,16 @@ impl CombinedSerializer { } let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?; - Self::find_serializer(type_.to_str()?, schema, config, definitions) + let type_ = type_.to_str()?; + + // if we have a SchemaValidator on the type already, use it + if matches!(type_, "model" | "dataclass" | "typed-dict") { + if let Ok(prebuilt_serializer) = super::prebuilt::PrebuiltSerializer::build(schema, config, definitions) { + return Ok(prebuilt_serializer); + } + } + + Self::find_serializer(type_, schema, config, definitions) } } @@ -219,6 +230,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit), CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit), CombinedSerializer::None(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),