Skip to content

Commit

Permalink
Fix arguments to dataclass validation functions (#563)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Apr 29, 2023
1 parent b23b38a commit 93af8cc
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 3 deletions.
12 changes: 10 additions & 2 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,11 @@ impl Validator for DataclassArgsValidator {

let ok = |output: PyObject| {
dict.set_item(field_name, output)?;
Ok(dict.to_object(py))
// The second return value represents `init_only_args`
// which doesn't make much sense in this context but we need to put something there
// so that function validators that sit between DataclassValidator and DataclassArgsValidator
// always get called the same shape of data.
Ok(PyTuple::new(py, vec![dict.to_object(py), py.None()]).into_py(py))
};

if let Some(field) = self.fields.iter().find(|f| f.name == field_name) {
Expand Down Expand Up @@ -511,10 +515,14 @@ impl Validator for DataclassValidator {
let new_dict = dict.copy()?;
new_dict.set_item(field_name, field_value)?;

let dc_dict =
// Discard the second return value, which is `init_only_args` but is always
// None anyway for validate_assignment; see validate_assignment in DataclassArgsValidator
let val_assignment_result =
self.validator
.validate_assignment(py, new_dict, field_name, field_value, extra, slots, recursion_guard)?;

let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?;

force_setattr(py, obj, dict_py_str, dc_dict)?;

Ok(obj.to_object(py))
Expand Down
98 changes: 97 additions & 1 deletion tests/validators/test_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import re
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

import pytest
from dirty_equals import IsListOrTuple, IsStr
Expand Down Expand Up @@ -1039,3 +1039,99 @@ class MyModel:

v.validate_assignment(m, 'not_f', '123')
assert getattr(m, 'not_f') == '123'


def test_function_validator_wrapping_args_schema_after() -> None:
calls: List[Any] = []

def func(*args: Any) -> Any:
calls.append(args)
return args[0]

@dataclasses.dataclass
class Model:
number: int = 1

cs = core_schema.dataclass_schema(
Model,
core_schema.no_info_after_validator_function(
func,
core_schema.dataclass_args_schema(
'Model', [core_schema.dataclass_field('number', core_schema.int_schema())]
),
),
)

v = SchemaValidator(cs)

instance: Model = v.validate_python({'number': 1})
assert instance.number == 1
assert calls == [(({'number': 1}, None),)]
v.validate_assignment(instance, 'number', 2)
assert instance.number == 2
assert calls == [(({'number': 1}, None),), (({'number': 2}, None),)]


def test_function_validator_wrapping_args_schema_before() -> None:
calls: List[Any] = []

def func(*args: Any) -> Any:
calls.append(args)
return args[0]

@dataclasses.dataclass
class Model:
number: int = 1

cs = core_schema.dataclass_schema(
Model,
core_schema.no_info_before_validator_function(
func,
core_schema.dataclass_args_schema(
'Model', [core_schema.dataclass_field('number', core_schema.int_schema())]
),
),
)

v = SchemaValidator(cs)

instance: Model = v.validate_python({'number': 1})
assert instance.number == 1
assert calls == [({'number': 1},)]
v.validate_assignment(instance, 'number', 2)
assert instance.number == 2
assert calls == [({'number': 1},), ({'number': 2},)]


def test_function_validator_wrapping_args_schema_wrap() -> None:
calls: List[Any] = []

def func(*args: Any) -> Any:
assert len(args) == 2
input, handler = args
output = handler(input)
calls.append((input, output))
return output

@dataclasses.dataclass
class Model:
number: int = 1

cs = core_schema.dataclass_schema(
Model,
core_schema.no_info_wrap_validator_function(
func,
core_schema.dataclass_args_schema(
'Model', [core_schema.dataclass_field('number', core_schema.int_schema())]
),
),
)

v = SchemaValidator(cs)

instance: Model = v.validate_python({'number': 1})
assert instance.number == 1
assert calls == [({'number': 1}, ({'number': 1}, None))]
v.validate_assignment(instance, 'number', 2)
assert instance.number == 2
assert calls == [({'number': 1}, ({'number': 1}, None)), ({'number': 2}, ({'number': 2}, None))]

0 comments on commit 93af8cc

Please sign in to comment.