From cd5cd657cc8a57f6310b3fbb19d9657136087eec Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 2 Apr 2023 16:00:28 +0100 Subject: [PATCH] Serialize support `Path` and add `fallback` function, JSON improvements (#514) * serialize infer() support Path * add fallback function to to_json and to_python_jsonable * hardening json fallback logic * support fallback passing through ExtraOwned * fix test on windows --- pydantic_core/_pydantic_core.pyi | 6 +- src/errors/validation_exception.rs | 6 +- src/serializers/errors.rs | 7 ++- src/serializers/extra.rs | 69 +++++++++++++++++++--- src/serializers/infer.rs | 57 ++++++++++++++---- src/serializers/mod.rs | 90 ++++++++--------------------- src/serializers/ob_type.rs | 8 +++ tests/serializers/test_any.py | 86 +++++++++++++++++++++++++-- tests/serializers/test_functions.py | 32 ++++++++++ tests/test_errors.py | 8 +++ tests/test_json.py | 86 +++++++++++++++++++++++++-- 11 files changed, 357 insertions(+), 98 deletions(-) diff --git a/pydantic_core/_pydantic_core.pyi b/pydantic_core/_pydantic_core.pyi index 3e4406e09..49b785897 100644 --- a/pydantic_core/_pydantic_core.pyi +++ b/pydantic_core/_pydantic_core.pyi @@ -1,6 +1,6 @@ import decimal import sys -from typing import Any +from typing import Any, Callable from pydantic_core import ErrorDetails, InitErrorDetails from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType @@ -80,6 +80,7 @@ class SchemaSerializer: exclude_none: bool = False, round_trip: bool = False, warnings: bool = True, + fallback: 'Callable[[Any], Any] | None' = None, ) -> Any: ... def to_json( self, @@ -94,6 +95,7 @@ class SchemaSerializer: exclude_none: bool = False, round_trip: bool = False, warnings: bool = True, + fallback: 'Callable[[Any], Any] | None' = None, ) -> bytes: ... def to_json( @@ -107,6 +109,7 @@ def to_json( timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', bytes_mode: Literal['utf8', 'base64'] = 'utf8', serialize_unknown: bool = False, + fallback: 'Callable[[Any], Any] | None' = None, ) -> bytes: ... def to_jsonable_python( value: Any, @@ -118,6 +121,7 @@ def to_jsonable_python( timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', bytes_mode: Literal['utf8', 'base64'] = 'utf8', serialize_unknown: bool = False, + fallback: 'Callable[[Any], Any] | None' = None, ) -> Any: ... class Url: diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 76336f4f0..92e285862 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -13,7 +13,7 @@ use serde::{Serialize, Serializer}; use serde_json::ser::PrettyFormatter; use crate::build_tools::{py_error_type, safe_repr, SchemaDict}; -use crate::serializers::GeneralSerializeContext; +use crate::serializers::{SerMode, SerializationState}; use crate::PydanticCustomError; use super::line_error::ValLineError; @@ -140,8 +140,8 @@ impl ValidationError { indent: Option, include_context: Option, ) -> PyResult<&'py PyString> { - let general_ser_context = GeneralSerializeContext::new(); - let extra = general_ser_context.extra(py, true); + let state = SerializationState::new(None, None); + let extra = state.extra(py, &SerMode::Json, None, None, Some(true), None); let serializer = ValidationErrorSerializer { py, line_errors: &self.line_errors, diff --git a/src/serializers/errors.rs b/src/serializers/errors.rs index 8b99c1a46..4dcebba20 100644 --- a/src/serializers/errors.rs +++ b/src/serializers/errors.rs @@ -6,7 +6,8 @@ use pyo3::prelude::*; use serde::ser; /// `UNEXPECTED_TYPE_SER` is a special prefix to denote a `PydanticSerializationUnexpectedValue` error. -pub(super) static UNEXPECTED_TYPE_SER: &str = "__PydanticSerializationUnexpectedValue__"; +pub(super) static UNEXPECTED_TYPE_SER_MARKER: &str = "__PydanticSerializationUnexpectedValue__"; +pub(super) static SERIALIZATION_ERR_MARKER: &str = "__PydanticSerializationError__"; // convert a `PyErr` or `PyDowncastError` into a serde serialization error pub(super) fn py_err_se_err(py_error: E) -> T { @@ -16,12 +17,14 @@ pub(super) fn py_err_se_err(py_error: E) -> T { /// convert a serde serialization error into a `PyErr` pub(super) fn se_err_py_err(error: serde_json::Error) -> PyErr { let s = error.to_string(); - if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER) { + if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER_MARKER) { if msg.is_empty() { PydanticSerializationUnexpectedValue::new_err(None) } else { PydanticSerializationUnexpectedValue::new_err(Some(msg.to_string())) } + } else if let Some(msg) = s.strip_prefix(SERIALIZATION_ERR_MARKER) { + PydanticSerializationError::new_err(msg.to_string()) } else { let msg = format!("Error serializing to JSON: {s}"); PydanticSerializationError::new_err(msg) diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index a099b11ad..21ec69884 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -8,13 +8,62 @@ use pyo3::{intern, AsPyPointer}; use ahash::AHashSet; use serde::ser::Error; -use crate::build_tools::py_err; - use super::config::SerializationConfig; -use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER}; +use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER}; use super::ob_type::ObTypeLookup; use super::shared::CombinedSerializer; +/// this is ugly, would be much better if extra could be stored in `SerializationState` +/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work +pub(crate) struct SerializationState { + warnings: CollectWarnings, + rec_guard: SerRecursionGuard, + config: SerializationConfig, +} + +impl SerializationState { + pub fn new(timedelta_mode: Option<&str>, bytes_mode: Option<&str>) -> Self { + let warnings = CollectWarnings::new(None); + let rec_guard = SerRecursionGuard::default(); + let config = SerializationConfig::from_args(timedelta_mode, bytes_mode).unwrap(); + Self { + warnings, + rec_guard, + config, + } + } + + pub fn extra<'py>( + &'py self, + py: Python<'py>, + mode: &'py SerMode, + exclude_none: Option, + round_trip: Option, + serialize_unknown: Option, + fallback: Option<&'py PyAny>, + ) -> Extra<'py> { + Extra::new( + py, + mode, + &[], + None, + &self.warnings, + None, + None, + exclude_none, + round_trip, + &self.config, + &self.rec_guard, + serialize_unknown, + fallback, + ) + } + + pub fn final_check(&self, py: Python) -> PyResult<()> { + self.warnings.final_check(py) + } +} + /// Useful things which are passed around by type_serializers #[derive(Clone)] #[cfg_attr(debug_assertions, derive(Debug))] @@ -38,6 +87,7 @@ pub(crate) struct Extra<'a> { pub model: Option<&'a PyAny>, pub field_name: Option<&'a str>, pub serialize_unknown: bool, + pub fallback: Option<&'a PyAny>, } impl<'a> Extra<'a> { @@ -55,6 +105,7 @@ impl<'a> Extra<'a> { config: &'a SerializationConfig, rec_guard: &'a SerRecursionGuard, serialize_unknown: Option, + fallback: Option<&'a PyAny>, ) -> Self { Self { mode, @@ -72,6 +123,7 @@ impl<'a> Extra<'a> { model: None, field_name: None, serialize_unknown: serialize_unknown.unwrap_or(false), + fallback, } } @@ -111,9 +163,10 @@ pub(crate) struct ExtraOwned { config: SerializationConfig, rec_guard: SerRecursionGuard, check: SerCheck, - model: Option>, + model: Option, field_name: Option, serialize_unknown: bool, + fallback: Option, } impl ExtraOwned { @@ -133,6 +186,7 @@ impl ExtraOwned { model: extra.model.map(|v| v.into()), field_name: extra.field_name.map(|v| v.to_string()), serialize_unknown: extra.serialize_unknown, + fallback: extra.fallback.map(|v| v.into()), } } @@ -153,6 +207,7 @@ impl ExtraOwned { model: self.model.as_ref().map(|m| m.as_ref(py)), field_name: self.field_name.as_ref().map(|n| n.as_ref()), serialize_unknown: self.serialize_unknown, + fallback: self.fallback.as_ref().map(|m| m.as_ref(py)), } } } @@ -248,7 +303,7 @@ impl CollectWarnings { // note: I think this should never actually happen since we use `to_python(..., mode='json')` during // JSON serialisation to "try" union branches, but it's here for completeness/correctness // in particular, in future we could allow errors instead of warnings on fallback - Err(S::Error::custom(UNEXPECTED_TYPE_SER)) + Err(S::Error::custom(UNEXPECTED_TYPE_SER_MARKER)) } else { self.fallback_warning(field_type, value); Ok(()) @@ -315,9 +370,9 @@ impl SerRecursionGuard { let id = value.as_ptr() as usize; let mut info = self.info.borrow_mut(); if !info.ids.insert(id) { - py_err!(PyValueError; "Circular reference detected (id repeated)") + Err(PyValueError::new_err("Circular reference detected (id repeated)")) } else if info.depth > Self::MAX_DEPTH { - py_err!(PyValueError; "Circular reference detected (depth exceeded)") + Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) } else { info.depth += 1; Ok(id) diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index c0d790790..14bbf64e8 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -8,9 +8,10 @@ use pyo3::types::{ PyTime, PyTuple, }; -use serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer}; +use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; use crate::build_tools::{py_err, safe_repr}; +use crate::serializers::errors::SERIALIZATION_ERR_MARKER; use crate::serializers::filter::SchemaFilter; use crate::url::{PyMultiHostUrl, PyUrl}; @@ -179,12 +180,18 @@ pub(crate) fn infer_to_python_known( } PyList::new(py, items).into_py(py) } + ObType::Path => value.str()?.into_py(py), ObType::Unknown => { - return if extra.serialize_unknown { - Ok(serialize_unknown(value).into_py(py)) + if let Some(fallback) = extra.fallback { + let next_value = fallback.call1((value,))?; + let next_result = infer_to_python(next_value, include, exclude, extra); + extra.rec_guard.pop(value_id); + return next_result; + } else if extra.serialize_unknown { + serialize_unknown(value).into_py(py) } else { - Err(unknown_type_error(value)) - }; + return Err(unknown_type_error(value)); + } } }, _ => match ob_type { @@ -232,6 +239,16 @@ pub(crate) fn infer_to_python_known( ); iter.into_py(py) } + ObType::Unknown => { + if let Some(fallback) = extra.fallback { + let next_value = fallback.call1((value,))?; + let next_result = infer_to_python(next_value, include, exclude, extra); + extra.rec_guard.pop(value_id); + return next_result; + } else { + value.into_py(py) + } + } _ => value.into_py(py), }, }; @@ -432,11 +449,25 @@ pub(crate) fn infer_serialize_known( } seq.end() } + ObType::Path => { + let s = value.str().map_err(py_err_se_err)?.to_str().map_err(py_err_se_err)?; + serializer.serialize_str(s) + } ObType::Unknown => { - if extra.serialize_unknown { + if let Some(fallback) = extra.fallback { + let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; + let next_result = infer_serialize(next_value, serializer, include, exclude, extra); + extra.rec_guard.pop(value_id); + return next_result; + } else if extra.serialize_unknown { serializer.serialize_str(&serialize_unknown(value)) } else { - return Err(py_err_se_err(unknown_type_error(value))); + let msg = format!( + "{}Unable to serialize unknown type: {}", + SERIALIZATION_ERR_MARKER, + safe_repr(value) + ); + return Err(S::Error::custom(msg)); } } }; @@ -452,9 +483,9 @@ fn serialize_unknown(value: &PyAny) -> Cow { if let Ok(s) = value.str() { s.to_string_lossy() } else if let Ok(name) = value.get_type().name() { - format!("<{name} object cannot be serialized to JSON>").into() + format!("").into() } else { - "".into() + "".into() } } @@ -531,8 +562,14 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: &ObType, key: &'py PyAny, extra let k = key.getattr(intern!(key.py(), "value"))?; infer_json_key(k, extra) } + ObType::Path => Ok(key.str()?.to_string_lossy()), ObType::Unknown => { - if extra.serialize_unknown { + if let Some(fallback) = extra.fallback { + let next_key = fallback.call1((key,))?; + // totally unnecessary step to placate rust's lifetime rules + let next_key = next_key.to_object(key.py()).into_ref(key.py()); + infer_json_key(next_key, extra) + } else if extra.serialize_unknown { Ok(serialize_unknown(key)) } else { Err(unknown_type_error(key)) diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 80a151e7f..621f014c1 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -9,8 +9,8 @@ use crate::validators::SelfValidator; use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; -pub(crate) use extra::Extra; -use extra::{CollectWarnings, SerMode, SerRecursionGuard}; +use extra::{CollectWarnings, SerRecursionGuard}; +pub(crate) use extra::{Extra, SerMode, SerializationState}; pub use shared::CombinedSerializer; use shared::{to_json_bytes, BuildSerializer, TypeSerializer}; @@ -50,6 +50,9 @@ impl SchemaSerializer { } #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (value, *, mode = None, include = None, exclude = None, by_alias = None, + exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true, + fallback = None))] pub fn to_python( &self, py: Python, @@ -63,6 +66,7 @@ impl SchemaSerializer { exclude_none: Option, round_trip: Option, warnings: Option, + fallback: Option<&PyAny>, ) -> PyResult { let mode: SerMode = mode.into(); let warnings = CollectWarnings::new(warnings); @@ -80,6 +84,7 @@ impl SchemaSerializer { &self.config, &rec_guard, None, + fallback, ); let v = self.serializer.to_python(value, include, exclude, &extra)?; warnings.final_check(py)?; @@ -87,6 +92,9 @@ impl SchemaSerializer { } #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = None, + exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true, + fallback = None))] pub fn to_json( &mut self, py: Python, @@ -100,6 +108,7 @@ impl SchemaSerializer { exclude_none: Option, round_trip: Option, warnings: Option, + fallback: Option<&PyAny>, ) -> PyResult { let warnings = CollectWarnings::new(warnings); let rec_guard = SerRecursionGuard::default(); @@ -116,6 +125,7 @@ impl SchemaSerializer { &self.config, &rec_guard, None, + fallback, ); let bytes = to_json_bytes( value, @@ -160,7 +170,7 @@ impl SchemaSerializer { #[allow(clippy::too_many_arguments)] #[pyfunction] #[pyo3(signature = (value, *, indent = None, include = None, exclude = None, exclude_none = false, round_trip = false, - timedelta_mode = None, bytes_mode = None, serialize_unknown = false))] + timedelta_mode = None, bytes_mode = None, serialize_unknown = false, fallback = None))] pub fn to_json( py: Python, value: &PyAny, @@ -172,27 +182,20 @@ pub fn to_json( timedelta_mode: Option<&str>, bytes_mode: Option<&str>, serialize_unknown: Option, + fallback: Option<&PyAny>, ) -> PyResult { - let warnings = CollectWarnings::new(None); - let rec_guard = SerRecursionGuard::default(); - let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?; - let extra = Extra::new( + let state = SerializationState::new(timedelta_mode, bytes_mode); + let extra = state.extra( py, &SerMode::Json, - &[], - None, - &warnings, - None, - None, exclude_none, round_trip, - &config, - &rec_guard, serialize_unknown, + fallback, ); let serializer = type_serializers::any::AnySerializer::default().into(); let bytes = to_json_bytes(value, &serializer, include, exclude, &extra, indent, 1024)?; - warnings.final_check(py)?; + state.final_check(py)?; let py_bytes = PyBytes::new(py, &bytes); Ok(py_bytes.into()) } @@ -200,7 +203,7 @@ pub fn to_json( #[allow(clippy::too_many_arguments)] #[pyfunction] #[pyo3(signature = (value, *, include = None, exclude = None, exclude_none = false, round_trip = false, - timedelta_mode = None, bytes_mode = None, serialize_unknown = false))] + timedelta_mode = None, bytes_mode = None, serialize_unknown = false, fallback = None))] pub fn to_jsonable_python( py: Python, value: &PyAny, @@ -211,63 +214,18 @@ pub fn to_jsonable_python( timedelta_mode: Option<&str>, bytes_mode: Option<&str>, serialize_unknown: Option, + fallback: Option<&PyAny>, ) -> PyResult { - let warnings = CollectWarnings::new(None); - let rec_guard = SerRecursionGuard::default(); - let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?; - let extra = Extra::new( + let state = SerializationState::new(timedelta_mode, bytes_mode); + let extra = state.extra( py, &SerMode::Json, - &[], - None, - &warnings, - None, - None, exclude_none, round_trip, - &config, - &rec_guard, serialize_unknown, + fallback, ); let v = infer::infer_to_python(value, include, exclude, &extra)?; - warnings.final_check(py)?; + state.final_check(py)?; Ok(v) } - -/// this is ugly, but would be much better if extra could be stored in `GeneralSerializeContext` -/// then `GeneralSerializeContext` got a `serialize_infer` method, but I couldn't get it to work -pub(crate) struct GeneralSerializeContext { - warnings: CollectWarnings, - rec_guard: SerRecursionGuard, - config: SerializationConfig, -} - -impl GeneralSerializeContext { - pub fn new() -> Self { - let warnings = CollectWarnings::new(None); - let rec_guard = SerRecursionGuard::default(); - let config = SerializationConfig::from_args(None, None).unwrap(); - Self { - warnings, - rec_guard, - config, - } - } - - pub fn extra<'py>(&'py self, py: Python<'py>, serialize_unknown: bool) -> Extra<'py> { - Extra::new( - py, - &SerMode::Json, - &[], - None, - &self.warnings, - None, - None, - None, - None, - &self.config, - &self.rec_guard, - Some(serialize_unknown), - ) - } -} diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index edc88b82a..cb0bb0f71 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -43,6 +43,8 @@ pub struct ObTypeLookup { enum_type: usize, // generator generator: usize, + // path + path: usize, } static TYPE_LOOKUP: GILOnceCell = GILOnceCell::new(); @@ -80,6 +82,7 @@ impl ObTypeLookup { multi_host_url: PyMultiHostUrl::new(lib_url, None).into_py(py).as_ref(py).get_type_ptr() as usize, enum_type: py.import("enum").unwrap().getattr("Enum").unwrap().get_type_ptr() as usize, generator: py.import("types").unwrap().getattr("GeneratorType").unwrap().as_ptr() as usize, + path: py.import("pathlib").unwrap().getattr("Path").unwrap().as_ptr() as usize, } } @@ -126,6 +129,7 @@ impl ObTypeLookup { ObType::Model => is_pydantic_model(op_value), ObType::Enum => self.enum_type == ob_type, ObType::Generator => self.generator == ob_type, + ObType::Path => self.path == ob_type, ObType::Unknown => false, }; @@ -212,6 +216,8 @@ impl ObTypeLookup { ObType::Enum } else if ob_type == self.generator || is_generator(op_value) { ObType::Generator + } else if ob_type == self.path { + ObType::Path } else { // this allows for subtypes of the supported class types, // if `ob_type` didn't match any member of self, we try again with the next base type pointer @@ -306,6 +312,8 @@ pub enum ObType { Enum, // generator type Generator, + // Path + Path, // unknown type Unknown, } diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 23d2fd527..81717540f 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -1,12 +1,15 @@ import dataclasses import json +import sys +from collections import namedtuple from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from enum import Enum +from pathlib import Path from typing import ClassVar import pytest -from dirty_equals import IsList +from dirty_equals import HasRepr, IsList from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema, to_json @@ -90,6 +93,18 @@ def test_set_member_db(any_serializer): (MyDataclass(1, 'foo', 2), b'{"a":1,"b":"foo"}'), (MyModel(a=1, b='foo'), b'{"a":1,"b":"foo"}'), ([MyDataclass(1, 'a', 2), MyModel(a=2, b='b')], b'[{"a":1,"b":"a"},{"a":2,"b":"b"}]'), + pytest.param( + Path('/foo/bar/spam.svg'), + b'"/foo/bar/spam.svg"', + marks=pytest.mark.skipif(sys.platform == 'win32', reason='Path output different on windows'), + ), + pytest.param( + Path(r'C:\\foo\\bar\\spam.svg'), + b'"C:\\\\foo\\\\bar\\\\spam.svg"', + marks=pytest.mark.skipif(sys.platform != 'win32', reason='Path output different on windows'), + ), + # I'm open to adding custom logic to make namedtuples behave like dataclasses or models + (namedtuple('Point', ['x', 'y'])(1, 2), b'[1,2]'), ], ) def test_any_json(any_serializer, value, expected_json): @@ -255,11 +270,12 @@ def test_exclude_unset(any_serializer): assert any_serializer.to_python(m2, exclude_unset=True) == {'bar': 2, 'spam': 3} -def test_unknown_type(any_serializer): - class Foobar: - def __repr__(self): - return '' +class Foobar: + def __repr__(self): + return '' + +def test_unknown_type(any_serializer: SchemaSerializer): f = Foobar() assert any_serializer.to_python(f) == f @@ -270,6 +286,66 @@ def __repr__(self): any_serializer.to_json(f) +def test_unknown_type_fallback(any_serializer: SchemaSerializer): + def fallback_func(obj): + return f'fallback:{obj!r}' + + f = Foobar() + assert any_serializer.to_python(f) == f + + assert any_serializer.to_python(f, mode='json', fallback=fallback_func) == 'fallback:' + assert any_serializer.to_python(f, fallback=fallback_func) == 'fallback:' + assert any_serializer.to_json(f, fallback=fallback_func) == b'"fallback:"' + + +def test_fallback_cycle_same(any_serializer: SchemaSerializer): + def fallback_func(obj): + return obj + + f = Foobar() + assert any_serializer.to_python(f) == f + + with pytest.raises(ValueError, match=r'Circular reference detected \(id repeated\)'): + any_serializer.to_python(f, mode='json', fallback=fallback_func) + + # because when recursion is detected and we're in mode python, we just return the value + assert any_serializer.to_python(f, fallback=fallback_func) == f + + with pytest.raises(ValueError, match=r'Circular reference detected \(id repeated\)'): + any_serializer.to_json(f, fallback=fallback_func) + + +class FoobarCount: + def __init__(self, v): + self.v = v + + def __repr__(self): + return f'' + + +def test_fallback_cycle_change(any_serializer: SchemaSerializer): + v = 1 + + def fallback_func(obj): + nonlocal v + v += 1 + return FoobarCount(v) + + f = FoobarCount(0) + assert any_serializer.to_python(f) == f + + with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'): + any_serializer.to_python(f, mode='json', fallback=fallback_func) + + f = FoobarCount(0) + v = 0 + # because when recursion is detected and we're in mode python, we just return the value + assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr('') + + with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'): + any_serializer.to_json(f, fallback=fallback_func) + + class MyEnum(Enum): a = 1 b = 'b' diff --git a/tests/serializers/test_functions.py b/tests/serializers/test_functions.py index bf5e2671e..cddd247dc 100644 --- a/tests/serializers/test_functions.py +++ b/tests/serializers/test_functions.py @@ -294,6 +294,38 @@ def f(value, serializer, _info): assert s.to_json('foo') == b'"result=3 repr=SerializationCallable(serializer=int)"' +class Foobar: + def __str__(self): + return 'foobar!' + + +def test_function_wrap_fallback(): + def f(value, serializer, _info): + return f'result={serializer(value)}' + + def fallback(v): + return f'fallback:{v}' + + s = SchemaSerializer( + core_schema.any_schema( + serialization=core_schema.general_wrap_serializer_function_ser_schema(f, core_schema.any_schema()) + ) + ) + assert s.to_python('foo') == 'result=foo' + assert s.to_python('foo', mode='json') == 'result=foo' + assert s.to_json('foo') == b'"result=foo"' + + assert s.to_python(Foobar()) == 'result=foobar!' + with pytest.raises(PydanticSerializationError, match='Error calling function `f`'): + assert s.to_python(Foobar(), mode='json') == 'result=foobar!' + with pytest.raises(PydanticSerializationError, match='Error calling function `f`'): + assert s.to_json(Foobar()) == b'"result=foobar!"' + + assert s.to_python(Foobar(), fallback=fallback) == 'result=fallback:foobar!' + assert s.to_python(Foobar(), mode='json', fallback=fallback) == 'result=fallback:foobar!' + assert s.to_json(Foobar(), fallback=fallback) == b'"result=fallback:foobar!"' + + def test_deque(): def serialize_deque(value, serializer, info: core_schema.SerializationInfo): items = [] diff --git a/tests/test_errors.py b/tests/test_errors.py index 023dd03b2..a5d9dca36 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -424,6 +424,14 @@ def test_error_on_repr(): assert exc_info.value.errors() == [ {'type': 'int_type', 'loc': (), 'msg': 'Input should be a valid integer', 'input': IsInstance(BadRepr)} ] + assert json.loads(exc_info.value.json()) == [ + { + 'type': 'int_type', + 'loc': [], + 'msg': 'Input should be a valid integer', + 'input': '', + } + ] def test_error_json(): diff --git a/tests/test_json.py b/tests/test_json.py index 230eb3ec1..d3d265a22 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -182,19 +182,27 @@ def __str__(self): return 'Foobar.__str__' +def fallback_func(v): + return f'fallback:{type(v).__name__}' + + def test_to_json(): assert to_json([1, 2]) == b'[1,2]' assert to_json([1, 2], indent=2) == b'[\n 1,\n 2\n]' assert to_json([1, b'x']) == b'[1,"x"]' + # kwargs required + with pytest.raises(TypeError, match=r'to_json\(\) takes 1 positional arguments but 2 were given'): + to_json([1, 2], 2) + + +def test_to_json_fallback(): with pytest.raises(PydanticSerializationError, match=r'Unable to serialize unknown type: <.+\.Foobar'): to_json(Foobar()) assert to_json(Foobar(), serialize_unknown=True) == b'"Foobar.__str__"' - - # kwargs required - with pytest.raises(TypeError, match=r'to_json\(\) takes 1 positional arguments but 2 were given'): - to_json([1, 2], 2) + assert to_json(Foobar(), serialize_unknown=True, fallback=fallback_func) == b'"fallback:Foobar"' + assert to_json(Foobar(), fallback=fallback_func) == b'"fallback:Foobar"' def test_to_jsonable_python(): @@ -203,7 +211,77 @@ def test_to_jsonable_python(): assert to_jsonable_python([1, b'x']) == [1, 'x'] assert to_jsonable_python([0, 1, 2, 3, 4], exclude={1, 3}) == [0, 2, 4] + +def test_to_jsonable_python_fallback(): with pytest.raises(PydanticSerializationError, match=r'Unable to serialize unknown type: <.+\.Foobar'): to_jsonable_python(Foobar()) assert to_jsonable_python(Foobar(), serialize_unknown=True) == 'Foobar.__str__' + assert to_jsonable_python(Foobar(), serialize_unknown=True, fallback=fallback_func) == 'fallback:Foobar' + assert to_jsonable_python(Foobar(), fallback=fallback_func) == 'fallback:Foobar' + + +def test_cycle_same(): + def fallback_func_passthrough(obj): + return obj + + f = Foobar() + + with pytest.raises(ValueError, match=r'Circular reference detected \(id repeated\)'): + to_jsonable_python(f, fallback=fallback_func_passthrough) + + with pytest.raises(ValueError, match=r'Circular reference detected \(id repeated\)'): + to_json(f, fallback=fallback_func_passthrough) + + +def test_cycle_change(): + def fallback_func_change_id(obj): + return Foobar() + + f = Foobar() + + with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'): + to_jsonable_python(f, fallback=fallback_func_change_id) + + with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'): + to_json(f, fallback=fallback_func_change_id) + + +class FoobarHash: + def __str__(self): + return 'Foobar.__str__' + + def __hash__(self): + return 1 + + +def test_json_key_fallback(): + x = {FoobarHash(): 1} + + assert to_jsonable_python(x, serialize_unknown=True) == {'Foobar.__str__': 1} + assert to_jsonable_python(x, fallback=fallback_func) == {'fallback:FoobarHash': 1} + assert to_json(x, serialize_unknown=True) == b'{"Foobar.__str__":1}' + assert to_json(x, fallback=fallback_func) == b'{"fallback:FoobarHash":1}' + + +class BadRepr: + def __repr__(self): + raise ValueError('bad repr') + + def __hash__(self): + return 1 + + +def test_bad_repr(): + b = BadRepr() + + error_msg = '^Unable to serialize unknown type: $' + with pytest.raises(PydanticSerializationError, match=error_msg): + to_jsonable_python(b) + + assert to_jsonable_python(b, serialize_unknown=True) == '' + + with pytest.raises(PydanticSerializationError, match=error_msg): + to_json(b) + + assert to_json(b, serialize_unknown=True) == b'""'