Skip to content

Commit

Permalink
serializer reuse logic as well
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Jan 31, 2025
1 parent e5245b5 commit 980fa20
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyTuple, PyType};
Expand All @@ -24,6 +25,7 @@ mod fields;
mod filter;
mod infer;
mod ob_type;
mod prebuilt;
pub mod ser;
mod shared;
mod type_serializers;
Expand All @@ -37,7 +39,7 @@ pub enum WarningsArg {
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
serializer: Arc<CombinedSerializer>,
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
Expand Down Expand Up @@ -92,7 +94,7 @@ impl SchemaSerializer {
let mut definitions_builder = DefinitionsBuilder::new();
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
serializer: Arc::new(serializer),
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
Expand Down
93 changes: 93 additions & 0 deletions src/serializers/prebuilt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::borrow::Cow;
use std::sync::Arc;

use pyo3::exceptions::PyValueError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PyType};

use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;
use crate::SchemaSerializer;

use super::extra::Extra;
use super::shared::{BuildSerializer, CombinedSerializer, TypeSerializer};

#[derive(Debug)]
pub struct PrebuiltSerializer {
serializer: Arc<CombinedSerializer>,
}

impl BuildSerializer for PrebuiltSerializer {
const EXPECTED_TYPE: &'static str = "prebuilt";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;

// note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
// because we don't want to fetch prebuilt serializers from parent classes
let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?;

// Ensure the class has completed its Pydantic setup
let is_complete: bool = class_dict
.get_as_req::<Bound<'_, PyBool>>(intern!(py, "__pydantic_complete__"))
.is_ok_and(|b| b.extract().unwrap_or(false));

if !is_complete {
return Err(PyValueError::new_err("Prebuilt serializer not found."));
}

// Retrieve the prebuilt validator if available
let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_serializer__"))?;
let schema_serializer: PyRef<SchemaSerializer> = prebuilt_serializer.extract()?;
let combined_serializer: Arc<CombinedSerializer> = schema_serializer.serializer.clone();

Ok(Self {
serializer: combined_serializer,
}
.into())
}
}

impl_py_gc_traverse!(PrebuiltSerializer { serializer });

impl TypeSerializer for PrebuiltSerializer {
fn to_python(
&self,
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
self.serializer.to_python(value, include, exclude, extra)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
self.serializer.json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
&self,
value: &Bound<'_, PyAny>,
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
self.serializer
.serde_serialize(value, serializer, include, exclude, extra)
}

fn get_name(&self) -> &str {
self.serializer.get_name()
}

fn retry_with_lax_check(&self) -> bool {
self.serializer.retry_with_lax_check()
}
}
14 changes: 13 additions & 1 deletion src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ combined_serializer! {
Function: super::type_serializers::function::FunctionPlainSerializer;
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
Fields: super::fields::GeneralFieldsSerializer;
// prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum
Prebuilt: super::prebuilt::PrebuiltSerializer;
}
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
Expand Down Expand Up @@ -195,7 +197,16 @@ impl CombinedSerializer {
}

let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?;
Self::find_serializer(type_.to_str()?, schema, config, definitions)
let type_ = type_.to_str()?;

// if we have a SchemaValidator on the type already, use it
if matches!(type_, "model" | "dataclass" | "typed-dict") {
if let Ok(prebuilt_serializer) = super::prebuilt::PrebuiltSerializer::build(schema, config, definitions) {
return Ok(prebuilt_serializer);
}
}

Self::find_serializer(type_, schema, config, definitions)
}
}

Expand All @@ -219,6 +230,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::None(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),
Expand Down

0 comments on commit 980fa20

Please sign in to comment.