Skip to content

Commit

Permalink
Add SchemaValidator.get_default_value() (#643)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored May 28, 2023
1 parent 068742e commit 616dea9
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 10 deletions.
2 changes: 2 additions & 0 deletions pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SchemaError,
SchemaSerializer,
SchemaValidator,
Some,
Url,
ValidationError,
__version__,
Expand All @@ -37,6 +38,7 @@
'CoreSchemaType',
'SchemaValidator',
'SchemaSerializer',
'Some',
'Url',
'MultiHostUrl',
'ArgsKwargs',
Expand Down
13 changes: 12 additions & 1 deletion pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import decimal
import sys
from typing import Any, Callable
from typing import Any, Callable, Generic

from typing_extensions import TypeVar

from pydantic_core import ErrorDetails, InitErrorDetails
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType
Expand Down Expand Up @@ -35,6 +37,14 @@ __all__ = (
__version__: str
build_profile: str

T = TypeVar('T', default=Any, covariant=True)

class Some(Generic[T]):
__match_args__ = ('value',)

@property
def value(self) -> T: ...

class SchemaValidator:
@property
def title(self) -> str: ...
Expand All @@ -56,6 +66,7 @@ class SchemaValidator:
def validate_assignment(
self, obj: Any, field: str, input: Any, *, strict: 'bool | None' = None, context: Any = None
) -> 'dict[str, Any]': ...
def get_default_value(self, *, strict: 'bool | None' = None, context: Any = None) -> Some | None: ...

IncEx: TypeAlias = 'set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None'

Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub use errors::{list_all_errors, PydanticCustomError, PydanticKnownError, Pydan
pub use serializers::{
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
};
pub use validators::SchemaValidator;
pub use validators::{PySome, SchemaValidator};

pub fn get_version() -> String {
let version = env!("CARGO_PKG_VERSION");
Expand All @@ -46,6 +46,7 @@ pub fn get_version() -> String {
fn _pydantic_core(_py: Python, m: &PyModule) -> PyResult<()> {
m.add("__version__", get_version())?;
m.add("build_profile", env!("PROFILE"))?;
m.add_class::<PySome>()?;
m.add_class::<SchemaValidator>()?;
m.add_class::<ValidationError>()?;
m.add_class::<SchemaError>()?;
Expand Down
60 changes: 59 additions & 1 deletion src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use enum_dispatch::enum_dispatch;
use pyo3::exceptions::PyTypeError;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};
use pyo3::types::{PyAny, PyDict, PyTuple, PyType};
use pyo3::{intern, PyTraverseError, PyVisit};

use crate::build_tools::{py_schema_err, py_schema_error_type, SchemaError};
Expand Down Expand Up @@ -58,6 +58,40 @@ pub use with_default::DefaultType;

use self::definitions::DefinitionRefValidator;

#[pyclass(module = "pydantic_core._pydantic_core", name = "Some")]
pub struct PySome {
#[pyo3(get)]
value: PyObject,
}

impl PySome {
fn new(value: PyObject) -> Self {
Self { value }
}
}

#[pymethods]
impl PySome {
pub fn __repr__(&self, py: Python) -> PyResult<String> {
Ok(format!("Some({})", self.value.as_ref(py).repr()?,))
}

#[new]
pub fn py_new(value: PyObject) -> Self {
Self { value }
}

#[classmethod]
pub fn __class_getitem__(cls: &PyType, _args: &PyAny) -> Py<PyType> {
cls.into_py(cls.py())
}

#[classattr]
fn __match_args__(py: Python) -> &PyTuple {
PyTuple::new(py, vec![intern!(py, "value")])
}
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
pub struct SchemaValidator {
Expand Down Expand Up @@ -182,6 +216,30 @@ impl SchemaValidator {
.map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Python))
}

#[pyo3(signature = (*, strict=None, context=None))]
pub fn get_default_value(&self, py: Python, strict: Option<bool>, context: Option<&PyAny>) -> PyResult<PyObject> {
let extra = Extra {
mode: InputType::Python,
data: None,
strict,
ultra_strict: false,
context,
field_name: None,
self_instance: None,
};
let recursion_guard = &mut RecursionGuard::default();
let r = self
.validator
.default_value(py, None::<i64>, &extra, &self.definitions, recursion_guard);
match r {
Ok(maybe_default) => match maybe_default {
Some(v) => Ok(PySome::new(v).into_py(py)),
None => Ok(py.None().into_py(py)),
},
Err(e) => Err(self.prepare_validation_err(py, e, ErrorMode::Python)),
}
}

pub fn __repr__(&self, py: Python) -> String {
format!(
"SchemaValidator(title={:?}, validator={:#?}, definitions={:#?})",
Expand Down
159 changes: 152 additions & 7 deletions tests/validators/test_with_default.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
from collections import deque
from typing import Dict, List
from typing import Any, Callable, Dict, List, Union, cast

import pytest

from pydantic_core import ArgsKwargs, SchemaError, SchemaValidator, ValidationError, core_schema
from pydantic_core import ArgsKwargs, SchemaError, SchemaValidator, Some, ValidationError, core_schema

from ..conftest import PyAndJson

Expand Down Expand Up @@ -307,18 +308,46 @@ class MyModel:
assert m.__pydantic_fields_set__ == set()


def test_validate_default():
@pytest.mark.parametrize('config_validate_default', [True, False, None])
@pytest.mark.parametrize('schema_validate_default', [True, False, None])
@pytest.mark.parametrize(
'inner_schema',
[
core_schema.no_info_after_validator_function(lambda x: x * 2, core_schema.int_schema()),
core_schema.no_info_before_validator_function(lambda x: str(int(x) * 2), core_schema.int_schema()),
core_schema.no_info_wrap_validator_function(lambda x, h: h(str(int(x) * 2)), core_schema.int_schema()),
core_schema.no_info_wrap_validator_function(lambda x, h: h(x) * 2, core_schema.int_schema()),
],
ids=['after', 'before', 'wrap-before', 'wrap-after'],
)
def test_validate_default(
config_validate_default: Union[bool, None],
schema_validate_default: Union[bool, None],
inner_schema: core_schema.CoreSchema,
):
if config_validate_default is not None:
config = core_schema.CoreConfig(validate_default=config_validate_default)
else:
config = None
v = SchemaValidator(
core_schema.typed_dict_schema(
{
'x': core_schema.typed_dict_field(
core_schema.with_default_schema(core_schema.int_schema(), default='42', validate_default=True)
core_schema.with_default_schema(
inner_schema, default='42', validate_default=schema_validate_default
)
)
}
)
),
config,
)
assert v.validate_python({'x': '2'}) == {'x': 2}
assert v.validate_python({}) == {'x': 42}
assert v.validate_python({'x': '2'}) == {'x': 4}
expected = (
84
if (config_validate_default is True and schema_validate_default is not False or schema_validate_default is True)
else '42'
)
assert v.validate_python({}) == {'x': expected}


def test_validate_default_factory():
Expand Down Expand Up @@ -436,3 +465,119 @@ class Model:

assert m2.int_list_with_default is not m1.int_list_with_default
assert m2.str_dict_with_default is not m1.str_dict_with_default


def test_default_value() -> None:
s = core_schema.with_default_schema(core_schema.list_schema(core_schema.int_schema()), default=[1, 2, 3])

v = SchemaValidator(s)

r = v.get_default_value()
assert r is not None
assert r.value == [1, 2, 3]


def test_default_value_validate_default() -> None:
s = core_schema.with_default_schema(core_schema.list_schema(core_schema.int_schema()), default=['1', '2', '3'])

v = SchemaValidator(s, core_schema.CoreConfig(validate_default=True))

r = v.get_default_value()
assert r is not None
assert r.value == [1, 2, 3]


def test_default_value_validate_default_fail() -> None:
s = core_schema.with_default_schema(core_schema.list_schema(core_schema.int_schema()), default=['a'])

v = SchemaValidator(s, core_schema.CoreConfig(validate_default=True))

with pytest.raises(ValidationError) as exc_info:
v.get_default_value()
assert exc_info.value.errors(include_url=False) == [
{
'type': 'int_parsing',
'loc': (0,),
'msg': 'Input should be a valid integer, unable to parse string as an integer',
'input': 'a',
}
]


def test_default_value_validate_default_strict_pass() -> None:
s = core_schema.with_default_schema(core_schema.list_schema(core_schema.int_schema()), default=[1, 2, 3])

v = SchemaValidator(s, core_schema.CoreConfig(validate_default=True))

r = v.get_default_value(strict=True)
assert r is not None
assert r.value == [1, 2, 3]


def test_default_value_validate_default_strict_fail() -> None:
s = core_schema.with_default_schema(core_schema.list_schema(core_schema.int_schema()), default=['1'])

v = SchemaValidator(s, core_schema.CoreConfig(validate_default=True))

with pytest.raises(ValidationError) as exc_info:
v.get_default_value(strict=True)
assert exc_info.value.errors(include_url=False) == [
{'type': 'int_type', 'loc': (0,), 'msg': 'Input should be a valid integer', 'input': '1'}
]


@pytest.mark.parametrize('validate_default', [True, False])
def test_no_default_value(validate_default: bool) -> None:
s = core_schema.list_schema(core_schema.int_schema())
v = SchemaValidator(s, core_schema.CoreConfig(validate_default=validate_default))

assert v.get_default_value() is None


@pytest.mark.parametrize('validate_default', [True, False])
def test_some(validate_default: bool) -> None:
def get_default() -> Union[Some[int], None]:
s = core_schema.with_default_schema(core_schema.int_schema(), default=42)
return SchemaValidator(s).get_default_value()

res = get_default()
assert res is not None
assert res.value == 42
assert repr(res) == 'Some(42)'


@pytest.mark.skipif(sys.version_info < (3, 10), reason='pattern matching was added in 3.10')
def test_some_pattern_match() -> None:
code = """\
def f(v: Union[Some[Any], None]) -> str:
match v:
case Some(1):
return 'case1'
case Some(value=2):
return 'case2'
case Some(int(value)):
return f'case3: {value}'
case Some(value):
return f'case4: {type(value).__name__}({value})'
case None:
return 'case5'
"""

local_vars = {}
exec(code, globals(), local_vars)
f = cast(Callable[[Union[Some[Any], None]], str], local_vars['f'])

res = f(SchemaValidator(core_schema.with_default_schema(core_schema.int_schema(), default=1)).get_default_value())
assert res == 'case1'

res = f(SchemaValidator(core_schema.with_default_schema(core_schema.int_schema(), default=2)).get_default_value())
assert res == 'case2'

res = f(SchemaValidator(core_schema.with_default_schema(core_schema.int_schema(), default=3)).get_default_value())
assert res == 'case3: 3'

res = f(SchemaValidator(core_schema.with_default_schema(core_schema.int_schema(), default='4')).get_default_value())
assert res == 'case4: str(4)'

res = f(SchemaValidator(core_schema.int_schema()).get_default_value())
assert res == 'case5'

0 comments on commit 616dea9

Please sign in to comment.