Skip to content

Commit

Permalink
✨ Implement root model serialization (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
lig authored May 19, 2023
1 parent 4dbc875 commit 6667c07
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 22 deletions.
32 changes: 16 additions & 16 deletions src/serializers/type_serializers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl FunctionPlainSerializer {
self.func.call1(py, (model, value))
}
} else {
Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found"))
Err(PyRuntimeError::new_err("Function plain serializer expected to be run inside the context of a model field but no model was found"))
}
} else if self.info_arg {
let info = SerializationInfo::new(py, include, exclude, extra, self.is_field_serializer)?;
Expand Down Expand Up @@ -368,7 +368,7 @@ impl FunctionWrapSerializer {
self.func.call1(py, (model, value, serialize))
}
} else {
Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found"))
Err(PyRuntimeError::new_err("Function wrap serializer expected to be run inside the context of a model field but no model was found"))
}
} else if self.info_arg {
let info = SerializationInfo::new(py, include, exclude, extra, self.is_field_serializer)?;
Expand Down Expand Up @@ -492,20 +492,20 @@ impl SerializationInfo {
) -> PyResult<Self> {
if is_field_serializer {
match extra.field_name {
Some(field_name) => Ok(
Self {
include: include.map(|i| i.into_py(py)),
exclude: exclude.map(|e| e.into_py(py)),
_mode: extra.mode.clone(),
by_alias: extra.by_alias,
exclude_unset: extra.exclude_unset,
exclude_defaults: extra.exclude_defaults,
exclude_none: extra.exclude_none,
round_trip: extra.round_trip,
field_name: Some(field_name.to_string()),
}
),
_ => Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found")),
Some(field_name) => Ok(Self {
include: include.map(|i| i.into_py(py)),
exclude: exclude.map(|e| e.into_py(py)),
_mode: extra.mode.clone(),
by_alias: extra.by_alias,
exclude_unset: extra.exclude_unset,
exclude_defaults: extra.exclude_defaults,
exclude_none: extra.exclude_none,
round_trip: extra.round_trip,
field_name: Some(field_name.to_string()),
}),
_ => Err(PyRuntimeError::new_err(
"Model field context expected for field serialization info but no model field was found",
)),
}
} else {
Ok(Self {
Expand Down
24 changes: 20 additions & 4 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use super::{
SerField, TypeSerializer,
};

const ROOT_FIELD: &str = "root";

pub struct ModelFieldsBuilder;

impl BuildSerializer for ModelFieldsBuilder {
Expand Down Expand Up @@ -66,6 +68,7 @@ pub struct ModelSerializer {
class: Py<PyType>,
serializer: Box<CombinedSerializer>,
has_extra: bool,
root_model: bool,
name: String,
}

Expand All @@ -85,11 +88,13 @@ impl BuildSerializer for ModelSerializer {
let class: &PyType = schema.get_as_req(intern!(py, "cls"))?;
let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?);
let root_model = schema.get_as(intern!(py, "root_model"))?.unwrap_or(false);

Ok(Self {
class: class.into(),
serializer,
has_extra: has_extra(schema, config)?,
root_model,
name: class.getattr(intern!(py, "__name__"))?.extract()?,
}
.into())
Expand Down Expand Up @@ -139,11 +144,16 @@ impl TypeSerializer for ModelSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let extra = Extra {
let mut extra = Extra {
model: Some(value),
..*extra
};
if self.allow_value(value, &extra)? {
if self.root_model {
extra.field_name = Some(ROOT_FIELD);
let py = value.py();
let root = value.getattr(intern!(py, ROOT_FIELD))?;
self.serializer.to_python(root, include, exclude, &extra)
} else if self.allow_value(value, &extra)? {
let inner_value = self.get_inner_value(value, &extra)?;
self.serializer.to_python(inner_value, include, exclude, &extra)
} else {
Expand All @@ -169,11 +179,17 @@ impl TypeSerializer for ModelSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let extra = Extra {
let mut extra = Extra {
model: Some(value),
..*extra
};
if self.allow_value(value, &extra).map_err(py_err_se_err)? {
if self.root_model {
extra.field_name = Some(ROOT_FIELD);
let py = value.py();
let root = value.getattr(intern!(py, ROOT_FIELD)).map_err(py_err_se_err)?;
self.serializer
.serde_serialize(root, serializer, include, exclude, &extra)
} else if self.allow_value(value, &extra).map_err(py_err_se_err)? {
let inner_value = self.get_inner_value(value, &extra).map_err(py_err_se_err)?;
self.serializer
.serde_serialize(inner_value, serializer, include, exclude, &extra)
Expand Down
4 changes: 2 additions & 2 deletions tests/serializers/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,5 +857,5 @@ class InnerModel:
s = SchemaSerializer(schema)
# debug(s)
s_repr = plain_repr(s)
assert 'has_extra:true,name:"InnerModel"' in s_repr
assert 'has_extra:false,name:"OuterModel"' in s_repr
assert 'has_extra:true,root_model:false,name:"InnerModel"' in s_repr
assert 'has_extra:false,root_model:false,name:"OuterModel"' in s_repr
131 changes: 131 additions & 0 deletions tests/serializers/test_model_root.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import json
import platform
from typing import Any

try:
from functools import cached_property
except ImportError:
cached_property = None


from pydantic_core import SchemaSerializer, core_schema

from ..conftest import plain_repr

on_pypy = platform.python_implementation() == 'PyPy'
# pypy doesn't seem to maintain order of `__dict__`
if on_pypy:
IsStrictDict = dict
else:
pass


class RootModel:
__slots__ = 'root'
root: str

def __init__(self, data):
self.root = data


class RootSubModel(RootModel):
pass


def test_model_root():
s = SchemaSerializer(core_schema.model_schema(RootModel, core_schema.int_schema(), root_model=True))
print(plain_repr(s))
# TODO: assert 'mode:RootModel' in plain_repr(s)
assert 'has_extra:false' in plain_repr(s)
assert s.to_python(RootModel(1)) == 1
assert s.to_python(RootSubModel(1)) == 1

j = s.to_json(RootModel(1))
if on_pypy:
assert json.loads(j) == 1
else:
assert j == b'1'

assert json.loads(s.to_json(RootSubModel(1))) == 1


def test_function_plain_field_serializer_to_python():
class Model(RootModel):
def ser_root(self, v: Any, _) -> str:
assert self.root == 1_000
return f'{v:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.plain_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True
)
),
root_model=True,
)
)
assert s.to_python(Model(1000)) == '1_000'


def test_function_wrap_field_serializer_to_python():
class Model(RootModel):
def ser_root(self, v: Any, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
root = serializer(v)
assert self.root == 1_000
return f'{root:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
)
),
root_model=True,
)
)
assert s.to_python(Model(1000)) == '1_000'


def test_function_plain_field_serializer_to_json():
class Model(RootModel):
def ser_root(self, v: Any, _) -> str:
assert self.root == 1_000
return f'{v:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.plain_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True
)
),
root_model=True,
)
)
assert json.loads(s.to_json(Model(1000))) == '1_000'


def test_function_wrap_field_serializer_to_json():
class Model(RootModel):
def ser_root(self, v: Any, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
assert self.root == 1_000
root = serializer(v)
return f'{root:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
)
),
root_model=True,
)
)
assert json.loads(s.to_json(Model(1000))) == '1_000'
13 changes: 13 additions & 0 deletions tests/validators/test_model_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,16 @@ def f(input_value: str, info):
"ValidationInfo(config=None, context='call 1', field_name='root')",
"ValidationInfo(config=None, context='assignment call', field_name='root')",
]


def test_extra():
class RootModel:
__slots__ = 'root'
root: int

v = SchemaValidator(core_schema.model_schema(RootModel, core_schema.int_schema(), root_model=True))

m = v.validate_python(1)

with pytest.raises(AttributeError):
m.__pydantic_extra__

0 comments on commit 6667c07

Please sign in to comment.