Skip to content

Commit

Permalink
allow function-wrap-schema to reuse parent schema (#515)
Browse files Browse the repository at this point in the history
* allow function-wrap-schema to reuse parent schema

* fix filtering inside wrap serializers

* test for json_return_type='path'

* filtering on json function after

* more tests
  • Loading branch information
samuelcolvin authored Apr 2, 2023
1 parent cd5cd65 commit 1276e37
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 55 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pydantic-core"
version = "0.21.0"
version = "0.22.0"
edition = "2021"
license = "MIT"
homepage = "https://github.com/pydantic/pydantic-core"
Expand Down
12 changes: 7 additions & 5 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class FieldPlainSerializerFunctionSchema(TypedDict):

# must match `src/serializers/ob_type.rs::ObType`
JsonReturnTypes = Literal[
'none',
'int',
'int_subclass',
'bool',
Expand All @@ -228,6 +229,7 @@ class FieldPlainSerializerFunctionSchema(TypedDict):
'dataclass',
'model',
'enum',
'path',
]

WhenUsed = Literal['always', 'unless-none', 'json', 'json-unless-none']
Expand Down Expand Up @@ -333,20 +335,20 @@ class FieldWrapSerializerFunctionSchema(TypedDict):
class WrapSerializerFunctionSerSchema(TypedDict, total=False):
type: Required[Literal['function-wrap']]
function: Required[Union[GeneralWrapSerializerFunctionSchema, FieldWrapSerializerFunctionSchema]]
schema: Required[CoreSchema]
schema: CoreSchema # if ommited, the schema on which this serializer is defined is used
json_return_type: JsonReturnTypes
when_used: WhenUsed # default: 'always'


def general_wrap_serializer_function_ser_schema(
function: GeneralWrapSerializerFunction,
schema: CoreSchema,
*,
schema: CoreSchema | None = None,
json_return_type: JsonReturnTypes | None = None,
when_used: WhenUsed = 'always',
) -> WrapSerializerFunctionSerSchema:
"""
Returns a schema for serialization with a function.
Returns a schema for serialization with a general wrap function.
Args:
function: The function to use for serialization
Expand All @@ -368,13 +370,13 @@ def general_wrap_serializer_function_ser_schema(

def field_wrap_serializer_function_ser_schema(
function: FieldWrapSerializerFunction,
schema: CoreSchema,
*,
schema: CoreSchema | None = None,
json_return_type: JsonReturnTypes | None = None,
when_used: WhenUsed = 'always',
) -> WrapSerializerFunctionSerSchema:
"""
Returns a schema to serialize a field from a model, TypedDict or dataclass.
Returns a schema to serialize a field from a model, TypedDict or dataclass using a wrap function.
Args:
function: The function to use for serialization
Expand Down
12 changes: 12 additions & 0 deletions src/build_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ pub fn is_strict(schema: &PyDict, config: Option<&PyDict>) -> PyResult<bool> {
Ok(schema_or_config_same(schema, config, intern!(py, "strict"))?.unwrap_or(false))
}

pub fn destructure_function_schema(schema: &PyDict) -> PyResult<(bool, &PyAny)> {
let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?;
let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?;
let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?;
let is_field_serializer = match func_type {
"field" => true,
"general" => false,
_ => unreachable!(),
};
Ok((is_field_serializer, function))
}

enum SchemaErrorEnum {
Message(String),
ValidationError(ValidationError),
Expand Down
12 changes: 8 additions & 4 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,22 @@ impl CombinedSerializer {
let op_ser_type: Option<&str> = ser_schema.get_as(type_key)?;
match op_ser_type {
Some("function-plain") => {
// `function` is a special case, not included in `find_serializer` since it means something
// different in `schema.type`
// `function-plain` is a special case, not included in `find_serializer` since it means
// something different in `schema.type`
// NOTE! we use the `schema` here, not `ser_schema`
return super::type_serializers::function::FunctionPlainSerializer::build(
ser_schema,
schema,
config,
build_context,
)
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
}
Some("function-wrap") => {
// `function-wrap` is also a special case, not included in `find_serializer` since it mean
// something different in `schema.type`
// NOTE! we use the `schema` here, not `ser_schema`
return super::type_serializers::function::FunctionWrapSerializer::build(
ser_schema,
schema,
config,
build_context,
)
Expand Down
63 changes: 37 additions & 26 deletions src/serializers/type_serializers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use pyo3::types::PyString;
use serde::ser::Error;

use crate::build_context::BuildContext;
use crate::build_tools::{function_name, py_error_type, SchemaDict};
use crate::build_tools::{destructure_function_schema, function_name, py_error_type, SchemaDict};
use crate::serializers::extra::{ExtraOwned, SerMode};
use crate::serializers::filter::AnyFilter;
use crate::{PydanticOmit, PydanticSerializationUnexpectedValue};
Expand Down Expand Up @@ -71,18 +71,6 @@ impl BuildSerializer for FunctionPlainSerializerBuilder {
}
}

fn destructure_function_schema(schema: &PyDict) -> PyResult<(bool, &PyAny)> {
let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?;
let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?;
let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?;
let is_field_serializer = match func_type {
"field" => true,
"general" => false,
_ => unreachable!(),
};
Ok((is_field_serializer, function))
}

#[derive(Debug, Clone)]
pub struct FunctionPlainSerializer {
func: PyObject,
Expand All @@ -95,22 +83,28 @@ pub struct FunctionPlainSerializer {

impl BuildSerializer for FunctionPlainSerializer {
const EXPECTED_TYPE: &'static str = "function-plain";

/// NOTE! `schema` here is the actual `CoreSchema`, not `schema.serialization` as in the other builders
/// (done this way to match `FunctionWrapSerializer` which requires the full schema)
fn build(
schema: &PyDict,
_config: Option<&PyDict>,
_build_context: &mut BuildContext<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();
let (is_field_serializer, function) = destructure_function_schema(schema)?;

let ser_schema: &PyDict = schema.get_as_req(intern!(py, "serialization"))?;

let (is_field_serializer, function) = destructure_function_schema(ser_schema)?;
let function_name = function_name(function)?;

let name = format!("plain_function[{function_name}]");
Ok(Self {
func: function.into_py(py),
function_name,
name,
json_return_ob_type: get_json_return_type(schema)?,
when_used: WhenUsed::new(schema, WhenUsed::Always)?,
json_return_ob_type: get_json_return_type(ser_schema)?,
when_used: WhenUsed::new(ser_schema, WhenUsed::Always)?,
is_field_serializer,
}
.into())
Expand Down Expand Up @@ -157,9 +151,10 @@ macro_rules! function_type_serializer {
Ok(v) => {
let next_value = v.as_ref(py);
match extra.mode {
// None for include/exclude here, as filtering should be done
SerMode::Json => match self.json_return_ob_type {
Some(ref ob_type) => infer_to_python_known(ob_type, next_value, include, exclude, extra),
None => infer_to_python(next_value, include, exclude, extra),
Some(ref ob_type) => infer_to_python_known(ob_type, next_value, None, None, extra),
None => infer_to_python(next_value, None, None, extra),
},
_ => Ok(next_value.to_object(py)),
}
Expand Down Expand Up @@ -222,11 +217,12 @@ macro_rules! function_type_serializer {
match self.call(value, include, exclude, extra) {
Ok(v) => {
let next_value = v.as_ref(py);
// None for include/exclude here, as filtering should be done
match self.json_return_ob_type {
Some(ref ob_type) => {
infer_serialize_known(ob_type, next_value, serializer, include, exclude, extra)
infer_serialize_known(ob_type, next_value, serializer, None, None, extra)
}
None => infer_serialize(next_value, serializer, include, exclude, extra),
None => infer_serialize(next_value, serializer, None, None, extra),
}
}
Err(err) => match err.value(py).extract::<PydanticSerializationUnexpectedValue>() {
Expand Down Expand Up @@ -281,26 +277,41 @@ pub struct FunctionWrapSerializer {

impl BuildSerializer for FunctionWrapSerializer {
const EXPECTED_TYPE: &'static str = "function-wrap";

/// NOTE! `schema` here is the actual `CoreSchema`, not `schema.serialization` as in the other builders
/// (done this way since we need the `CoreSchema`)
fn build(
schema: &PyDict,
config: Option<&PyDict>,
build_context: &mut BuildContext<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();
let ser_schema: &PyDict = schema.get_as_req(intern!(py, "serialization"))?;

let (is_field_serializer, function) = destructure_function_schema(schema)?;
let (is_field_serializer, function) = destructure_function_schema(ser_schema)?;
let function_name = function_name(function)?;

let serializer_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
let serializer = CombinedSerializer::build(serializer_schema, config, build_context)?;
// try to get `schema.serialization.schema`, otherwise use `schema` with `serialization` key removed
let inner_schema: &PyDict = if let Some(s) = ser_schema.get_as(intern!(py, "schema"))? {
s
} else {
// we copy the schema so we can modify it without affecting the original
let schema_copy = schema.copy()?;
// remove the serialization key from the schema so we don't recurse
schema_copy.del_item(intern!(py, "serialization"))?;
schema_copy
};

let serializer = CombinedSerializer::build(inner_schema, config, build_context)?;

let name = format!("wrap_function[{function_name}, {}]", serializer.get_name());
Ok(Self {
serializer: Box::new(serializer),
func: function.into_py(py),
function_name,
name,
json_return_ob_type: get_json_return_type(schema)?,
when_used: WhenUsed::new(schema, WhenUsed::Always)?,
json_return_ob_type: get_json_return_type(ser_schema)?,
when_used: WhenUsed::new(ser_schema, WhenUsed::Always)?,
is_field_serializer,
}
.into())
Expand Down Expand Up @@ -385,7 +396,7 @@ impl SerializationCallable {
Err(PydanticOmit::new_err())
}
} else {
let v = self.serializer.to_python(value, None, None, &extra)?;
let v = self.serializer.to_python(value, include, exclude, &extra)?;
extra.warnings.final_check(py)?;
Ok(Some(v))
}
Expand Down
14 changes: 1 addition & 13 deletions src/validators/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyString};

use crate::build_tools::{function_name, py_err, SchemaDict};
use crate::build_tools::{destructure_function_schema, function_name, py_err, SchemaDict};
use crate::errors::{
ErrorType, LocItem, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError,
};
Expand All @@ -14,18 +14,6 @@ use crate::recursion_guard::RecursionGuard;
use super::generator::InternalValidator;
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

fn destructure_function_schema(schema: &PyDict) -> PyResult<(bool, &PyAny)> {
let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?;
let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?;
let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?;
let is_field_validator = match func_type {
"field" => true,
"general" => false,
_ => unreachable!(),
};
Ok((is_field_validator, function))
}

macro_rules! impl_build {
($impl_name:ident, $name:literal) => {
impl BuildValidator for $impl_name {
Expand Down
Loading

0 comments on commit 1276e37

Please sign in to comment.