Skip to content

Commit

Permalink
Implement frozen and extra_behavior for dataclasses (#505)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Mar 31, 2023
1 parent 8b2be89 commit af07999
Show file tree
Hide file tree
Showing 8 changed files with 488 additions and 166 deletions.
18 changes: 15 additions & 3 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def dict_not_none(**kwargs: Any) -> Any:
return {k: v for k, v in kwargs.items() if v is not None}


ExtraBehavior = Literal['allow', 'forbid', 'ignore']


class CoreConfig(TypedDict, total=False):
title: str
strict: bool
Expand All @@ -27,7 +30,7 @@ class CoreConfig(TypedDict, total=False):
# if configs are merged, which should take precedence, default 0, default means child takes precedence
config_merge_priority: int
# settings related to typed_dicts only
typed_dict_extra_behavior: Literal['allow', 'forbid', 'ignore']
extra_fields_behavior: ExtraBehavior
typed_dict_total: bool # default: True
# used on typed-dicts and tagged union keys
from_attributes: bool
Expand Down Expand Up @@ -2494,7 +2497,7 @@ class TypedDictSchema(TypedDict, total=False):
extra_validator: CoreSchema
return_fields_set: bool
# all these values can be set via config, equivalent fields have `typed_dict_` prefix
extra_behavior: Literal['allow', 'forbid', 'ignore']
extra_behavior: ExtraBehavior
total: bool # default: True
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
from_attributes: bool
Expand All @@ -2509,7 +2512,7 @@ def typed_dict_schema(
strict: bool | None = None,
extra_validator: CoreSchema | None = None,
return_fields_set: bool | None = None,
extra_behavior: Literal['allow', 'forbid', 'ignore'] | None = None,
extra_behavior: ExtraBehavior | None = None,
total: bool | None = None,
populate_by_name: bool | None = None,
from_attributes: bool | None = None,
Expand Down Expand Up @@ -2645,6 +2648,7 @@ class DataclassField(TypedDict, total=False):
schema: Required[CoreSchema]
kw_only: bool # default: True
init_only: bool # default: False
frozen: bool # default: False
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
serialization_alias: str
serialization_exclude: bool # default: False
Expand All @@ -2661,6 +2665,7 @@ def dataclass_field(
serialization_alias: str | None = None,
serialization_exclude: bool | None = None,
metadata: Any = None,
frozen: bool | None = None,
) -> DataclassField:
"""
Returns a schema for a dataclass field, e.g.:
Expand Down Expand Up @@ -2696,6 +2701,7 @@ def dataclass_field(
serialization_alias=serialization_alias,
serialization_exclude=serialization_exclude,
metadata=metadata,
frozen=frozen,
)


Expand All @@ -2708,6 +2714,7 @@ class DataclassArgsSchema(TypedDict, total=False):
ref: str
metadata: Any
serialization: SerSchema
extra_behavior: ExtraBehavior


def dataclass_args_schema(
Expand All @@ -2718,6 +2725,7 @@ def dataclass_args_schema(
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
extra_behavior: ExtraBehavior | None = None,
) -> DataclassArgsSchema:
"""
Returns a schema for validating dataclass arguments, e.g.:
Expand Down Expand Up @@ -2754,6 +2762,7 @@ def dataclass_args_schema(
ref=ref,
metadata=metadata,
serialization=serialization,
extra_behavior=extra_behavior,
)


Expand All @@ -2764,6 +2773,7 @@ class DataclassSchema(TypedDict, total=False):
post_init: bool # default: False
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
strict: bool # default: False
frozen: bool # default False
ref: str
metadata: Any
serialization: SerSchema
Expand All @@ -2779,6 +2789,7 @@ def dataclass_schema(
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
frozen: bool | None = None,
) -> DataclassSchema:
"""
Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
Expand All @@ -2805,6 +2816,7 @@ def dataclass_schema(
ref=ref,
metadata=metadata,
serialization=serialization,
frozen=frozen,
)


Expand Down
32 changes: 32 additions & 0 deletions src/build_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,35 @@ pub fn safe_repr(v: &PyAny) -> Cow<str> {
"<unprintable object>".into()
}
}

#[derive(Debug, Clone)]
pub(crate) enum ExtraBehavior {
Allow,
Forbid,
Ignore,
}

impl ExtraBehavior {
pub fn from_schema_or_config(
py: Python,
schema: &PyDict,
config: Option<&PyDict>,
default: Self,
) -> PyResult<Self> {
let extra_behavior = schema_or_config::<Option<&str>>(
schema,
config,
intern!(py, "extra_behavior"),
intern!(py, "extra_fields_behavior"),
)?
.flatten();
let res = match extra_behavior {
Some("allow") => Self::Allow,
Some("ignore") => Self::Ignore,
Some("forbid") => Self::Forbid,
Some(v) => return py_err!("Invalid extra_behavior: `{}`", v),
None => default,
};
Ok(res)
}
}
13 changes: 5 additions & 8 deletions src/serializers/type_serializers/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use ahash::{AHashMap, AHashSet};
use serde::ser::SerializeMap;

use crate::build_context::BuildContext;
use crate::build_tools::{py_error_type, schema_or_config, SchemaDict};
use crate::build_tools::{py_error_type, schema_or_config, ExtraBehavior, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
Expand Down Expand Up @@ -80,16 +80,13 @@ impl BuildSerializer for TypedDictSerializer {
) -> PyResult<CombinedSerializer> {
let py = schema.py();

let extra_behavior = schema_or_config::<&str>(
schema,
config,
intern!(py, "extra_behavior"),
intern!(py, "typed_dict_extra_behavior"),
)?;
let total =
schema_or_config(schema, config, intern!(py, "total"), intern!(py, "typed_dict_total"))?.unwrap_or(true);

let include_extra = extra_behavior == Some("allow");
let include_extra = matches!(
ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?,
ExtraBehavior::Allow
);

let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: AHashMap<String, TypedDictField> = AHashMap::with_capacity(fields_dict.len());
Expand Down
71 changes: 54 additions & 17 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple, PyType};

use ahash::AHashSet;

use crate::build_tools::{is_strict, py_err, schema_or_config_same, SchemaDict};
use crate::build_tools::{is_strict, py_err, schema_or_config_same, ExtraBehavior, SchemaDict};
use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
use crate::input::{GenericArguments, Input};
use crate::lookup_key::LookupKey;
Expand All @@ -24,6 +24,7 @@ struct Field {
init_only: bool,
lookup_key: LookupKey,
validator: CombinedValidator,
frozen: bool,
}

#[derive(Debug, Clone)]
Expand All @@ -33,6 +34,7 @@ pub struct DataclassArgsValidator {
init_only_count: Option<usize>,
dataclass_name: String,
validator_name: String,
extra_behavior: ExtraBehavior,
}

impl BuildValidator for DataclassArgsValidator {
Expand All @@ -47,6 +49,8 @@ impl BuildValidator for DataclassArgsValidator {

let populate_by_name = schema_or_config_same(schema, config, intern!(py, "populate_by_name"))?.unwrap_or(false);

let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?;

let fields_schema: &PyList = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: Vec<Field> = Vec::with_capacity(fields_schema.len());

Expand Down Expand Up @@ -91,6 +95,7 @@ impl BuildValidator for DataclassArgsValidator {
lookup_key,
validator,
init_only: field.get_as(intern!(py, "init_only"))?.unwrap_or(false),
frozen: field.get_as::<bool>(intern!(py, "frozen"))?.unwrap_or(false),
});
}

Expand All @@ -108,6 +113,7 @@ impl BuildValidator for DataclassArgsValidator {
init_only_count,
dataclass_name,
validator_name,
extra_behavior,
}
.into())
}
Expand Down Expand Up @@ -254,11 +260,20 @@ impl Validator for DataclassArgsValidator {
match raw_key.strict_str() {
Ok(either_str) => {
if !used_keys.contains(either_str.as_cow()?.as_ref()) {
errors.push(ValLineError::new_with_loc(
ErrorType::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
// Unknown / extra field
match self.extra_behavior {
ExtraBehavior::Forbid => {
errors.push(ValLineError::new_with_loc(
ErrorType::UnexpectedKeywordArgument,
value,
raw_key.as_loc_item(),
));
}
ExtraBehavior::Ignore => {}
ExtraBehavior::Allow => {
output_dict.set_item(either_str.as_py_string(py), value)?
}
}
}
}
Err(ValError::LineErrors(line_errors)) => {
Expand Down Expand Up @@ -303,7 +318,19 @@ impl Validator for DataclassArgsValidator {
) -> ValResult<'data, PyObject> {
let dict: &PyDict = obj.downcast()?;

let ok = |output: PyObject| {
dict.set_item(field_name, output)?;
Ok(dict.to_object(py))
};

if let Some(field) = self.fields.iter().find(|f| f.name == field_name) {
if field.frozen {
return Err(ValError::new_with_loc(
ErrorType::FrozenField,
field_value,
field.name.to_string(),
));
}
// by using dict but removing the field in question, we match V1 behaviour
let data_dict = dict.copy()?;
if let Err(err) = data_dict.del_item(field_name) {
Expand All @@ -321,10 +348,7 @@ impl Validator for DataclassArgsValidator {
.validator
.validate(py, field_value, &next_extra, slots, recursion_guard)
{
Ok(output) => {
dict.set_item(field_name, output)?;
Ok(dict.to_object(py))
}
Ok(output) => ok(output),
Err(ValError::LineErrors(line_errors)) => {
let errors = line_errors
.into_iter()
Expand All @@ -335,13 +359,21 @@ impl Validator for DataclassArgsValidator {
Err(err) => Err(err),
}
} else {
Err(ValError::new_with_loc(
ErrorType::NoSuchAttribute {
attribute: field_name.to_string(),
},
field_value,
field_name.to_string(),
))
// Handle extra (unknown) field
// We partially use the extra_behavior for initialization / validation
// to determine how to handle assignment
match self.extra_behavior {
// For dataclasses we allow assigning unknown fields
// to match stdlib dataclass behavior
ExtraBehavior::Allow => ok(field_value.to_object(py)),
_ => Err(ValError::new_with_loc(
ErrorType::NoSuchAttribute {
attribute: field_name.to_string(),
},
field_value,
field_name.to_string(),
)),
}
}
}

Expand All @@ -364,6 +396,7 @@ pub struct DataclassValidator {
post_init: Option<Py<PyString>>,
revalidate: Revalidate,
name: String,
frozen: bool,
}

impl BuildValidator for DataclassValidator {
Expand Down Expand Up @@ -399,6 +432,7 @@ impl BuildValidator for DataclassValidator {
// as with model, get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
// which is not what we want here
name: class.getattr(intern!(py, "__name__"))?.extract()?,
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
}
.into())
}
Expand Down Expand Up @@ -455,6 +489,9 @@ impl Validator for DataclassValidator {
slots: &'data [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
if self.frozen {
return Err(ValError::new(ErrorType::FrozenInstance, field_value));
}
let dict_py_str = intern!(py, "__dict__");
let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?;

Expand Down
Loading

0 comments on commit af07999

Please sign in to comment.