From 1276e375b2480f8aceb25f16e4648f0dc6d99b6e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 2 Apr 2023 21:18:38 +0100 Subject: [PATCH] allow function-wrap-schema to reuse parent schema (#515) * allow function-wrap-schema to reuse parent schema * fix filtering inside wrap serializers * test for json_return_type='path' * filtering on json function after * more tests --- Cargo.lock | 2 +- Cargo.toml | 2 +- pydantic_core/core_schema.py | 12 +- src/build_tools.rs | 12 ++ src/serializers/shared.rs | 12 +- src/serializers/type_serializers/function.rs | 63 ++++---- src/validators/function.rs | 14 +- tests/serializers/test_functions.py | 147 ++++++++++++++++++- tests/test_typing.py | 2 +- 9 files changed, 211 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cd646b4ef..bee3755ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -225,7 +225,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "0.21.0" +version = "0.22.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index fb70fedd8..f5c7c6788 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "0.21.0" +version = "0.22.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 12e5b38c8..20d648ef7 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -203,6 +203,7 @@ class FieldPlainSerializerFunctionSchema(TypedDict): # must match `src/serializers/ob_type.rs::ObType` JsonReturnTypes = Literal[ + 'none', 'int', 'int_subclass', 'bool', @@ -228,6 +229,7 @@ class FieldPlainSerializerFunctionSchema(TypedDict): 'dataclass', 'model', 'enum', + 'path', ] WhenUsed = Literal['always', 'unless-none', 'json', 'json-unless-none'] @@ -333,20 +335,20 @@ class FieldWrapSerializerFunctionSchema(TypedDict): class WrapSerializerFunctionSerSchema(TypedDict, total=False): type: Required[Literal['function-wrap']] function: Required[Union[GeneralWrapSerializerFunctionSchema, FieldWrapSerializerFunctionSchema]] - schema: Required[CoreSchema] + schema: CoreSchema # if ommited, the schema on which this serializer is defined is used json_return_type: JsonReturnTypes when_used: WhenUsed # default: 'always' def general_wrap_serializer_function_ser_schema( function: GeneralWrapSerializerFunction, - schema: CoreSchema, *, + schema: CoreSchema | None = None, json_return_type: JsonReturnTypes | None = None, when_used: WhenUsed = 'always', ) -> WrapSerializerFunctionSerSchema: """ - Returns a schema for serialization with a function. + Returns a schema for serialization with a general wrap function. Args: function: The function to use for serialization @@ -368,13 +370,13 @@ def general_wrap_serializer_function_ser_schema( def field_wrap_serializer_function_ser_schema( function: FieldWrapSerializerFunction, - schema: CoreSchema, *, + schema: CoreSchema | None = None, json_return_type: JsonReturnTypes | None = None, when_used: WhenUsed = 'always', ) -> WrapSerializerFunctionSerSchema: """ - Returns a schema to serialize a field from a model, TypedDict or dataclass. + Returns a schema to serialize a field from a model, TypedDict or dataclass using a wrap function. Args: function: The function to use for serialization diff --git a/src/build_tools.rs b/src/build_tools.rs index 26fe011bd..9f69517eb 100644 --- a/src/build_tools.rs +++ b/src/build_tools.rs @@ -99,6 +99,18 @@ pub fn is_strict(schema: &PyDict, config: Option<&PyDict>) -> PyResult { Ok(schema_or_config_same(schema, config, intern!(py, "strict"))?.unwrap_or(false)) } +pub fn destructure_function_schema(schema: &PyDict) -> PyResult<(bool, &PyAny)> { + let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?; + let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?; + let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?; + let is_field_serializer = match func_type { + "field" => true, + "general" => false, + _ => unreachable!(), + }; + Ok((is_field_serializer, function)) +} + enum SchemaErrorEnum { Message(String), ValidationError(ValidationError), diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index dca8b3f92..fd0c4f812 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -148,18 +148,22 @@ impl CombinedSerializer { let op_ser_type: Option<&str> = ser_schema.get_as(type_key)?; match op_ser_type { Some("function-plain") => { - // `function` is a special case, not included in `find_serializer` since it means something - // different in `schema.type` + // `function-plain` is a special case, not included in `find_serializer` since it means + // something different in `schema.type` + // NOTE! we use the `schema` here, not `ser_schema` return super::type_serializers::function::FunctionPlainSerializer::build( - ser_schema, + schema, config, build_context, ) .map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err)); } Some("function-wrap") => { + // `function-wrap` is also a special case, not included in `find_serializer` since it mean + // something different in `schema.type` + // NOTE! we use the `schema` here, not `ser_schema` return super::type_serializers::function::FunctionWrapSerializer::build( - ser_schema, + schema, config, build_context, ) diff --git a/src/serializers/type_serializers/function.rs b/src/serializers/type_serializers/function.rs index d59397cff..6d4188abe 100644 --- a/src/serializers/type_serializers/function.rs +++ b/src/serializers/type_serializers/function.rs @@ -10,7 +10,7 @@ use pyo3::types::PyString; use serde::ser::Error; use crate::build_context::BuildContext; -use crate::build_tools::{function_name, py_error_type, SchemaDict}; +use crate::build_tools::{destructure_function_schema, function_name, py_error_type, SchemaDict}; use crate::serializers::extra::{ExtraOwned, SerMode}; use crate::serializers::filter::AnyFilter; use crate::{PydanticOmit, PydanticSerializationUnexpectedValue}; @@ -71,18 +71,6 @@ impl BuildSerializer for FunctionPlainSerializerBuilder { } } -fn destructure_function_schema(schema: &PyDict) -> PyResult<(bool, &PyAny)> { - let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?; - let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?; - let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?; - let is_field_serializer = match func_type { - "field" => true, - "general" => false, - _ => unreachable!(), - }; - Ok((is_field_serializer, function)) -} - #[derive(Debug, Clone)] pub struct FunctionPlainSerializer { func: PyObject, @@ -95,13 +83,19 @@ pub struct FunctionPlainSerializer { impl BuildSerializer for FunctionPlainSerializer { const EXPECTED_TYPE: &'static str = "function-plain"; + + /// NOTE! `schema` here is the actual `CoreSchema`, not `schema.serialization` as in the other builders + /// (done this way to match `FunctionWrapSerializer` which requires the full schema) fn build( schema: &PyDict, _config: Option<&PyDict>, _build_context: &mut BuildContext, ) -> PyResult { let py = schema.py(); - let (is_field_serializer, function) = destructure_function_schema(schema)?; + + let ser_schema: &PyDict = schema.get_as_req(intern!(py, "serialization"))?; + + let (is_field_serializer, function) = destructure_function_schema(ser_schema)?; let function_name = function_name(function)?; let name = format!("plain_function[{function_name}]"); @@ -109,8 +103,8 @@ impl BuildSerializer for FunctionPlainSerializer { func: function.into_py(py), function_name, name, - json_return_ob_type: get_json_return_type(schema)?, - when_used: WhenUsed::new(schema, WhenUsed::Always)?, + json_return_ob_type: get_json_return_type(ser_schema)?, + when_used: WhenUsed::new(ser_schema, WhenUsed::Always)?, is_field_serializer, } .into()) @@ -157,9 +151,10 @@ macro_rules! function_type_serializer { Ok(v) => { let next_value = v.as_ref(py); match extra.mode { + // None for include/exclude here, as filtering should be done SerMode::Json => match self.json_return_ob_type { - Some(ref ob_type) => infer_to_python_known(ob_type, next_value, include, exclude, extra), - None => infer_to_python(next_value, include, exclude, extra), + Some(ref ob_type) => infer_to_python_known(ob_type, next_value, None, None, extra), + None => infer_to_python(next_value, None, None, extra), }, _ => Ok(next_value.to_object(py)), } @@ -222,11 +217,12 @@ macro_rules! function_type_serializer { match self.call(value, include, exclude, extra) { Ok(v) => { let next_value = v.as_ref(py); + // None for include/exclude here, as filtering should be done match self.json_return_ob_type { Some(ref ob_type) => { - infer_serialize_known(ob_type, next_value, serializer, include, exclude, extra) + infer_serialize_known(ob_type, next_value, serializer, None, None, extra) } - None => infer_serialize(next_value, serializer, include, exclude, extra), + None => infer_serialize(next_value, serializer, None, None, extra), } } Err(err) => match err.value(py).extract::() { @@ -281,26 +277,41 @@ pub struct FunctionWrapSerializer { impl BuildSerializer for FunctionWrapSerializer { const EXPECTED_TYPE: &'static str = "function-wrap"; + + /// NOTE! `schema` here is the actual `CoreSchema`, not `schema.serialization` as in the other builders + /// (done this way since we need the `CoreSchema`) fn build( schema: &PyDict, config: Option<&PyDict>, build_context: &mut BuildContext, ) -> PyResult { let py = schema.py(); + let ser_schema: &PyDict = schema.get_as_req(intern!(py, "serialization"))?; - let (is_field_serializer, function) = destructure_function_schema(schema)?; + let (is_field_serializer, function) = destructure_function_schema(ser_schema)?; let function_name = function_name(function)?; - let serializer_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?; - let serializer = CombinedSerializer::build(serializer_schema, config, build_context)?; + // try to get `schema.serialization.schema`, otherwise use `schema` with `serialization` key removed + let inner_schema: &PyDict = if let Some(s) = ser_schema.get_as(intern!(py, "schema"))? { + s + } else { + // we copy the schema so we can modify it without affecting the original + let schema_copy = schema.copy()?; + // remove the serialization key from the schema so we don't recurse + schema_copy.del_item(intern!(py, "serialization"))?; + schema_copy + }; + + let serializer = CombinedSerializer::build(inner_schema, config, build_context)?; + let name = format!("wrap_function[{function_name}, {}]", serializer.get_name()); Ok(Self { serializer: Box::new(serializer), func: function.into_py(py), function_name, name, - json_return_ob_type: get_json_return_type(schema)?, - when_used: WhenUsed::new(schema, WhenUsed::Always)?, + json_return_ob_type: get_json_return_type(ser_schema)?, + when_used: WhenUsed::new(ser_schema, WhenUsed::Always)?, is_field_serializer, } .into()) @@ -385,7 +396,7 @@ impl SerializationCallable { Err(PydanticOmit::new_err()) } } else { - let v = self.serializer.to_python(value, None, None, &extra)?; + let v = self.serializer.to_python(value, include, exclude, &extra)?; extra.warnings.final_check(py)?; Ok(Some(v)) } diff --git a/src/validators/function.rs b/src/validators/function.rs index 900a1f60c..e7b43d0e3 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -3,7 +3,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; -use crate::build_tools::{function_name, py_err, SchemaDict}; +use crate::build_tools::{destructure_function_schema, function_name, py_err, SchemaDict}; use crate::errors::{ ErrorType, LocItem, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, }; @@ -14,18 +14,6 @@ use crate::recursion_guard::RecursionGuard; use super::generator::InternalValidator; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; -fn destructure_function_schema(schema: &PyDict) -> PyResult<(bool, &PyAny)> { - let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?; - let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?; - let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?; - let is_field_validator = match func_type { - "field" => true, - "general" => false, - _ => unreachable!(), - }; - Ok((is_field_validator, function)) -} - macro_rules! impl_build { ($impl_name:ident, $name:literal) => { impl BuildValidator for $impl_name { diff --git a/tests/serializers/test_functions.py b/tests/serializers/test_functions.py index cddd247dc..393d53eb5 100644 --- a/tests/serializers/test_functions.py +++ b/tests/serializers/test_functions.py @@ -1,5 +1,7 @@ import json +import sys from collections import deque +from pathlib import Path import pytest @@ -280,13 +282,51 @@ def test_wrong_return_type(): s.to_json(123) +def test_wrong_return_type_str(): + def f(value, _): + if value == 666: + return value + else: + return str(value) + + s = SchemaSerializer( + core_schema.any_schema( + serialization=core_schema.general_plain_serializer_function_ser_schema(f, json_return_type='str_subclass') + ) + ) + assert s.to_python(123) == '123' + assert s.to_python(123, mode='json') == '123' + assert s.to_json(123) == b'"123"' + + assert s.to_python(666) == 666 + + with pytest.raises(TypeError, match="^'int' object cannot be converted to 'PyString'$"): + s.to_python(666, mode='json') + + m = "^Error serializing to JSON: 'int' object cannot be converted to 'PyString'$" + with pytest.raises(PydanticSerializationError, match=m): + s.to_json(666) + + def test_function_wrap(): def f(value, serializer, _info): return f'result={serializer(len(value))} repr={serializer!r}' + s = SchemaSerializer( + core_schema.int_schema(serialization=core_schema.general_wrap_serializer_function_ser_schema(f)) + ) + assert s.to_python('foo') == 'result=3 repr=SerializationCallable(serializer=int)' + assert s.to_python('foo', mode='json') == 'result=3 repr=SerializationCallable(serializer=int)' + assert s.to_json('foo') == b'"result=3 repr=SerializationCallable(serializer=int)"' + + +def test_function_wrap_custom_schema(): + def f(value, serializer, _info): + return f'result={serializer(len(value))} repr={serializer!r}' + s = SchemaSerializer( core_schema.any_schema( - serialization=core_schema.general_wrap_serializer_function_ser_schema(f, core_schema.int_schema()) + serialization=core_schema.general_wrap_serializer_function_ser_schema(f, schema=core_schema.int_schema()) ) ) assert s.to_python('foo') == 'result=3 repr=SerializationCallable(serializer=int)' @@ -308,7 +348,7 @@ def fallback(v): s = SchemaSerializer( core_schema.any_schema( - serialization=core_schema.general_wrap_serializer_function_ser_schema(f, core_schema.any_schema()) + serialization=core_schema.general_wrap_serializer_function_ser_schema(f, schema=core_schema.any_schema()) ) ) assert s.to_python('foo') == 'result=foo' @@ -344,7 +384,7 @@ def serialize_deque(value, serializer, info: core_schema.SerializationInfo): s = SchemaSerializer( core_schema.any_schema( serialization=core_schema.general_wrap_serializer_function_ser_schema( - serialize_deque, core_schema.any_schema() + serialize_deque, schema=core_schema.any_schema() ) ) ) @@ -372,7 +412,7 @@ def serialize_custom_mapping(value, serializer, _info): s = SchemaSerializer( core_schema.any_schema( serialization=core_schema.general_wrap_serializer_function_ser_schema( - serialize_custom_mapping, core_schema.int_schema() + serialize_custom_mapping, schema=core_schema.int_schema() ) ) ) @@ -382,3 +422,102 @@ def serialize_custom_mapping(value, serializer, _info): assert s.to_python({'a': 1, 'b': 2}, mode='json', include={'a'}) == 'a=1' assert s.to_json({'a': 1, 'b': 2}) == b'"a=1 b=2"' assert s.to_json({'a': 1, 'b': 2}, exclude={'b'}) == b'"a=1"' + + +def test_function_wrap_model(): + calls = 0 + + def wrap_function(value, handler, _info): + nonlocal calls + calls += 1 + return handler(value) + + class MyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + s = SchemaSerializer( + core_schema.model_schema( + MyModel, + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field(core_schema.any_schema()), + 'b': core_schema.typed_dict_field(core_schema.any_schema()), + 'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True), + } + ), + serialization=core_schema.general_wrap_serializer_function_ser_schema(wrap_function), + ) + ) + m = MyModel(a=1, b=b'foobar', c='excluded') + assert calls == 0 + assert s.to_python(m) == {'a': 1, 'b': b'foobar'} + assert calls == 1 + assert s.to_python(m, mode='json') == {'a': 1, 'b': 'foobar'} + assert calls == 2 + assert s.to_json(m) == b'{"a":1,"b":"foobar"}' + assert calls == 3 + + assert s.to_python(m, exclude={'b'}) == {'a': 1} + assert calls == 4 + assert s.to_python(m, mode='json', exclude={'b'}) == {'a': 1} + assert calls == 5 + assert s.to_json(m, exclude={'b'}) == b'{"a":1}' + assert calls == 6 + + +def test_function_plain_model(): + calls = 0 + + def wrap_function(value, _info): + nonlocal calls + calls += 1 + return value.__dict__ + + class MyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + s = SchemaSerializer( + core_schema.model_schema( + MyModel, + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field(core_schema.any_schema()), + 'b': core_schema.typed_dict_field(core_schema.any_schema()), + 'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True), + } + ), + serialization=core_schema.general_plain_serializer_function_ser_schema(wrap_function), + ) + ) + m = MyModel(a=1, b=b'foobar', c='not excluded') + assert calls == 0 + assert s.to_python(m) == {'a': 1, 'b': b'foobar', 'c': 'not excluded'} + assert calls == 1 + assert s.to_python(m, mode='json') == {'a': 1, 'b': 'foobar', 'c': 'not excluded'} + assert calls == 2 + assert s.to_json(m) == b'{"a":1,"b":"foobar","c":"not excluded"}' + assert calls == 3 + + assert s.to_python(m, exclude={'b'}) == {'a': 1, 'b': b'foobar', 'c': 'not excluded'} + assert calls == 4 + assert s.to_python(m, mode='json', exclude={'b'}) == {'a': 1, 'b': 'foobar', 'c': 'not excluded'} + assert calls == 5 + assert s.to_json(m, exclude={'b'}) == b'{"a":1,"b":"foobar","c":"not excluded"}' + assert calls == 6 + + +@pytest.mark.skipif(sys.platform == 'win32', reason='Path output different on windows') +def test_wrap_return_type(): + def to_path(value, handler, _info): + return Path(handler(value)).with_suffix('.new') + + s = SchemaSerializer( + core_schema.str_schema( + serialization=core_schema.general_wrap_serializer_function_ser_schema(to_path, json_return_type='path') + ) + ) + assert s.to_python('foobar') == Path('foobar.new') + assert s.to_python('foobar', mode='json') == 'foobar.new' + assert s.to_json('foobar') == b'"foobar.new"' diff --git a/tests/test_typing.py b/tests/test_typing.py index b261cc790..fc789d37b 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -237,7 +237,7 @@ def f( s = SchemaSerializer( core_schema.any_schema( serialization=core_schema.general_wrap_serializer_function_ser_schema( - f, core_schema.str_schema(), when_used='json' + f, schema=core_schema.str_schema(), when_used='json' ) ) )