Skip to content

Commit

Permalink
Serialize support Path and add fallback function, JSON improvemen…
Browse files Browse the repository at this point in the history
…ts (#514)

* serialize infer() support Path

* add fallback function to to_json and to_jsonable_python

* hardening json fallback logic

* support fallback passing through ExtraOwned

* fix test on windows
  • Loading branch information
samuelcolvin authored Apr 2, 2023
1 parent 330911a commit cd5cd65
Show file tree
Hide file tree
Showing 11 changed files with 357 additions and 98 deletions.
6 changes: 5 additions & 1 deletion pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import decimal
import sys
from typing import Any
from typing import Any, Callable

from pydantic_core import ErrorDetails, InitErrorDetails
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType
Expand Down Expand Up @@ -80,6 +80,7 @@ class SchemaSerializer:
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
fallback: 'Callable[[Any], Any] | None' = None,
) -> Any: ...
def to_json(
self,
Expand All @@ -94,6 +95,7 @@ class SchemaSerializer:
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
fallback: 'Callable[[Any], Any] | None' = None,
) -> bytes: ...

def to_json(
Expand All @@ -107,6 +109,7 @@ def to_json(
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
serialize_unknown: bool = False,
fallback: 'Callable[[Any], Any] | None' = None,
) -> bytes: ...
def to_jsonable_python(
value: Any,
Expand All @@ -118,6 +121,7 @@ def to_jsonable_python(
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
serialize_unknown: bool = False,
fallback: 'Callable[[Any], Any] | None' = None,
) -> Any: ...

class Url:
Expand Down
6 changes: 3 additions & 3 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use serde::{Serialize, Serializer};
use serde_json::ser::PrettyFormatter;

use crate::build_tools::{py_error_type, safe_repr, SchemaDict};
use crate::serializers::GeneralSerializeContext;
use crate::serializers::{SerMode, SerializationState};
use crate::PydanticCustomError;

use super::line_error::ValLineError;
Expand Down Expand Up @@ -140,8 +140,8 @@ impl ValidationError {
indent: Option<usize>,
include_context: Option<bool>,
) -> PyResult<&'py PyString> {
let general_ser_context = GeneralSerializeContext::new();
let extra = general_ser_context.extra(py, true);
let state = SerializationState::new(None, None);
let extra = state.extra(py, &SerMode::Json, None, None, Some(true), None);
let serializer = ValidationErrorSerializer {
py,
line_errors: &self.line_errors,
Expand Down
7 changes: 5 additions & 2 deletions src/serializers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use pyo3::prelude::*;
use serde::ser;

/// `UNEXPECTED_TYPE_SER` is a special prefix to denote a `PydanticSerializationUnexpectedValue` error.
pub(super) static UNEXPECTED_TYPE_SER: &str = "__PydanticSerializationUnexpectedValue__";
pub(super) static UNEXPECTED_TYPE_SER_MARKER: &str = "__PydanticSerializationUnexpectedValue__";
pub(super) static SERIALIZATION_ERR_MARKER: &str = "__PydanticSerializationError__";

// convert a `PyErr` or `PyDowncastError` into a serde serialization error
pub(super) fn py_err_se_err<T: ser::Error, E: fmt::Display>(py_error: E) -> T {
Expand All @@ -16,12 +17,14 @@ pub(super) fn py_err_se_err<T: ser::Error, E: fmt::Display>(py_error: E) -> T {
/// convert a serde serialization error into a `PyErr`
pub(super) fn se_err_py_err(error: serde_json::Error) -> PyErr {
let s = error.to_string();
if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER) {
if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER_MARKER) {
if msg.is_empty() {
PydanticSerializationUnexpectedValue::new_err(None)
} else {
PydanticSerializationUnexpectedValue::new_err(Some(msg.to_string()))
}
} else if let Some(msg) = s.strip_prefix(SERIALIZATION_ERR_MARKER) {
PydanticSerializationError::new_err(msg.to_string())
} else {
let msg = format!("Error serializing to JSON: {s}");
PydanticSerializationError::new_err(msg)
Expand Down
69 changes: 62 additions & 7 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,62 @@ use pyo3::{intern, AsPyPointer};
use ahash::AHashSet;
use serde::ser::Error;

use crate::build_tools::py_err;

use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER};
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;

/// Owns the warnings collector, recursion guard and serialization config that an `Extra`
/// borrows from (see the `extra` method on this type).
///
/// This is ugly: it would be much better if the `Extra` could be stored in `SerializationState`,
/// and `SerializationState` then got a `serialize_infer` method, but I couldn't get that to work.
pub(crate) struct SerializationState {
// collects serialization warnings; checked at the end via `final_check`
warnings: CollectWarnings,
// guard used to detect circular references while serializing
rec_guard: SerRecursionGuard,
// built from the `timedelta_mode` / `bytes_mode` arguments passed to `new`
config: SerializationConfig,
}

impl SerializationState {
/// Create a fresh state: a new warnings collector, a default recursion guard and a
/// `SerializationConfig` built from the optional `timedelta_mode` / `bytes_mode` strings.
///
/// # Panics
///
/// Panics if `SerializationConfig::from_args` fails, because its result is unwrapped.
/// NOTE(review): presumably that happens for an unrecognised mode string; if these values
/// can come from user input, consider returning a `PyResult` instead — confirm with callers.
pub fn new(timedelta_mode: Option<&str>, bytes_mode: Option<&str>) -> Self {
// NOTE(review): the meaning of the `None` argument is defined by `CollectWarnings::new`,
// which is not visible here.
let warnings = CollectWarnings::new(None);
let rec_guard = SerRecursionGuard::default();
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode).unwrap();
Self {
warnings,
rec_guard,
config,
}
}

/// Build an `Extra` that borrows this state's warnings collector, config and recursion guard.
///
/// `mode`, `exclude_none`, `round_trip`, `serialize_unknown` and `fallback` are forwarded
/// unchanged to `Extra::new`; every other positional argument is fixed to `&[]` or `None`.
/// The returned `Extra<'py>` borrows from `self`, so the state must outlive it.
pub fn extra<'py>(
&'py self,
py: Python<'py>,
mode: &'py SerMode,
exclude_none: Option<bool>,
round_trip: Option<bool>,
serialize_unknown: Option<bool>,
fallback: Option<&'py PyAny>,
) -> Extra<'py> {
// NOTE(review): the call is positional, so argument order matters; the meaning of the
// hard-coded `&[]` / `None` slots is defined by `Extra::new` — check its signature
// before reordering or adding anything here.
Extra::new(
py,
mode,
&[],
None,
&self.warnings,
None,
None,
exclude_none,
round_trip,
&self.config,
&self.rec_guard,
serialize_unknown,
fallback,
)
}

/// Run the warnings collector's final check and return its result.
///
/// This simply delegates to `CollectWarnings::final_check`.
/// NOTE(review): presumably this surfaces any warnings gathered during serialization — confirm.
pub fn final_check(&self, py: Python) -> PyResult<()> {
self.warnings.final_check(py)
}
}

/// Useful things which are passed around by type_serializers
#[derive(Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
Expand All @@ -38,6 +87,7 @@ pub(crate) struct Extra<'a> {
pub model: Option<&'a PyAny>,
pub field_name: Option<&'a str>,
pub serialize_unknown: bool,
pub fallback: Option<&'a PyAny>,
}

impl<'a> Extra<'a> {
Expand All @@ -55,6 +105,7 @@ impl<'a> Extra<'a> {
config: &'a SerializationConfig,
rec_guard: &'a SerRecursionGuard,
serialize_unknown: Option<bool>,
fallback: Option<&'a PyAny>,
) -> Self {
Self {
mode,
Expand All @@ -72,6 +123,7 @@ impl<'a> Extra<'a> {
model: None,
field_name: None,
serialize_unknown: serialize_unknown.unwrap_or(false),
fallback,
}
}

Expand Down Expand Up @@ -111,9 +163,10 @@ pub(crate) struct ExtraOwned {
config: SerializationConfig,
rec_guard: SerRecursionGuard,
check: SerCheck,
model: Option<Py<PyAny>>,
model: Option<PyObject>,
field_name: Option<String>,
serialize_unknown: bool,
fallback: Option<PyObject>,
}

impl ExtraOwned {
Expand All @@ -133,6 +186,7 @@ impl ExtraOwned {
model: extra.model.map(|v| v.into()),
field_name: extra.field_name.map(|v| v.to_string()),
serialize_unknown: extra.serialize_unknown,
fallback: extra.fallback.map(|v| v.into()),
}
}

Expand All @@ -153,6 +207,7 @@ impl ExtraOwned {
model: self.model.as_ref().map(|m| m.as_ref(py)),
field_name: self.field_name.as_ref().map(|n| n.as_ref()),
serialize_unknown: self.serialize_unknown,
fallback: self.fallback.as_ref().map(|m| m.as_ref(py)),
}
}
}
Expand Down Expand Up @@ -248,7 +303,7 @@ impl CollectWarnings {
// note: I think this should never actually happen since we use `to_python(..., mode='json')` during
// JSON serialisation to "try" union branches, but it's here for completeness/correctness
// in particular, in future we could allow errors instead of warnings on fallback
Err(S::Error::custom(UNEXPECTED_TYPE_SER))
Err(S::Error::custom(UNEXPECTED_TYPE_SER_MARKER))
} else {
self.fallback_warning(field_type, value);
Ok(())
Expand Down Expand Up @@ -315,9 +370,9 @@ impl SerRecursionGuard {
let id = value.as_ptr() as usize;
let mut info = self.info.borrow_mut();
if !info.ids.insert(id) {
py_err!(PyValueError; "Circular reference detected (id repeated)")
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
} else if info.depth > Self::MAX_DEPTH {
py_err!(PyValueError; "Circular reference detected (depth exceeded)")
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
info.depth += 1;
Ok(id)
Expand Down
57 changes: 47 additions & 10 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ use pyo3::types::{
PyTime, PyTuple,
};

use serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer};
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};

use crate::build_tools::{py_err, safe_repr};
use crate::serializers::errors::SERIALIZATION_ERR_MARKER;
use crate::serializers::filter::SchemaFilter;
use crate::url::{PyMultiHostUrl, PyUrl};

Expand Down Expand Up @@ -179,12 +180,18 @@ pub(crate) fn infer_to_python_known(
}
PyList::new(py, items).into_py(py)
}
ObType::Path => value.str()?.into_py(py),
ObType::Unknown => {
return if extra.serialize_unknown {
Ok(serialize_unknown(value).into_py(py))
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id);
return next_result;
} else if extra.serialize_unknown {
serialize_unknown(value).into_py(py)
} else {
Err(unknown_type_error(value))
};
return Err(unknown_type_error(value));
}
}
},
_ => match ob_type {
Expand Down Expand Up @@ -232,6 +239,16 @@ pub(crate) fn infer_to_python_known(
);
iter.into_py(py)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id);
return next_result;
} else {
value.into_py(py)
}
}
_ => value.into_py(py),
},
};
Expand Down Expand Up @@ -432,11 +449,25 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
seq.end()
}
ObType::Path => {
let s = value.str().map_err(py_err_se_err)?.to_str().map_err(py_err_se_err)?;
serializer.serialize_str(s)
}
ObType::Unknown => {
if extra.serialize_unknown {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id);
return next_result;
} else if extra.serialize_unknown {
serializer.serialize_str(&serialize_unknown(value))
} else {
return Err(py_err_se_err(unknown_type_error(value)));
let msg = format!(
"{}Unable to serialize unknown type: {}",
SERIALIZATION_ERR_MARKER,
safe_repr(value)
);
return Err(S::Error::custom(msg));
}
}
};
Expand All @@ -452,9 +483,9 @@ fn serialize_unknown(value: &PyAny) -> Cow<str> {
if let Ok(s) = value.str() {
s.to_string_lossy()
} else if let Ok(name) = value.get_type().name() {
format!("<{name} object cannot be serialized to JSON>").into()
format!("<Unserializable {name} object>").into()
} else {
"<object cannot be serialized to JSON>".into()
"<Unserializable object>".into()
}
}

Expand Down Expand Up @@ -531,8 +562,14 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: &ObType, key: &'py PyAny, extra
let k = key.getattr(intern!(key.py(), "value"))?;
infer_json_key(k, extra)
}
ObType::Path => Ok(key.str()?.to_string_lossy()),
ObType::Unknown => {
if extra.serialize_unknown {
if let Some(fallback) = extra.fallback {
let next_key = fallback.call1((key,))?;
// totally unnecessary step to placate rust's lifetime rules
let next_key = next_key.to_object(key.py()).into_ref(key.py());
infer_json_key(next_key, extra)
} else if extra.serialize_unknown {
Ok(serialize_unknown(key))
} else {
Err(unknown_type_error(key))
Expand Down
Loading

0 comments on commit cd5cd65

Please sign in to comment.