diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 897fc0f37..dc8f81a07 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -90,6 +90,8 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { self.strict_str() } + fn as_str_strict(&self) -> Option<&str>; + fn validate_bytes(&'a self, strict: bool) -> ValResult> { if strict { self.strict_bytes() @@ -129,6 +131,8 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { self.strict_int() } + fn as_int_strict(&self) -> Option; + fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult { if ultra_strict { self.ultra_strict_float() diff --git a/src/input/input_json.rs b/src/input/input_json.rs index bb3327296..41b8313bd 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -92,6 +92,13 @@ impl<'a> Input<'a> for JsonInput { } } + fn as_str_strict(&self) -> Option<&str> { + match self { + JsonInput::String(s) => Some(s.as_str()), + _ => None, + } + } + fn validate_bytes(&'a self, _strict: bool) -> ValResult> { match self { JsonInput::String(s) => Ok(s.as_bytes().into()), @@ -141,6 +148,13 @@ impl<'a> Input<'a> for JsonInput { } } + fn as_int_strict(&self) -> Option { + match self { + JsonInput::Int(i) => Some(*i), + _ => None, + } + } + fn ultra_strict_float(&self) -> ValResult { match self { JsonInput::Float(f) => Ok(*f), @@ -349,6 +363,10 @@ impl<'a> Input<'a> for String { self.validate_str(false) } + fn as_str_strict(&self) -> Option<&str> { + Some(self.as_str()) + } + fn validate_bytes(&'a self, _strict: bool) -> ValResult> { Ok(self.as_bytes().into()) } @@ -374,6 +392,10 @@ impl<'a> Input<'a> for String { } } + fn as_int_strict(&self) -> Option { + None + } + #[cfg_attr(has_no_coverage, no_coverage)] fn ultra_strict_float(&self) -> ValResult { self.strict_float() diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 3f1e24967..d4d435b59 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -220,6 +220,14 @@ impl<'a> Input<'a> for PyAny { } } + fn as_str_strict(&self) -> Option<&str> { + if self.get_type().is(get_py_str_type(self.py())) { + self.extract().ok() + } else { + None + } + } + fn strict_bytes(&'a self) -> ValResult> { if let Ok(py_bytes) = self.downcast::() { Ok(py_bytes.into()) @@ -289,6 +297,14 @@ impl<'a> Input<'a> for PyAny { } } + fn as_int_strict(&self) -> Option { + if self.get_type().is(get_py_int_type(self.py())) { + self.extract().ok() + } else { + None + } + } + fn ultra_strict_float(&self) -> ValResult { if matches!(self.is_instance_of::(), Ok(true)) { Err(ValError::new(ErrorType::FloatType, self)) @@ -697,3 +713,15 @@ pub fn list_as_tuple(list: &PyList) -> &PyTuple { }; py_tuple.into_ref(list.py()) } + +static PY_INT_TYPE: GILOnceCell = GILOnceCell::new(); + +fn get_py_int_type(py: Python) -> &PyObject { + PY_INT_TYPE.get_or_init(py, || PyInt::type_object(py).into()) +} + +static PY_STR_TYPE: GILOnceCell = GILOnceCell::new(); + +fn get_py_str_type(py: Python) -> &PyObject { + PY_STR_TYPE.get_or_init(py, || PyString::type_object(py).into()) +} diff --git a/src/validators/literal.rs b/src/validators/literal.rs index ad0dbf756..d829962f7 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -1,367 +1,100 @@ -use pyo3::intern; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyString}; +// Validator for things inside of a typing.Literal[] +// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums) use ahash::AHashSet; +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; use crate::build_tools::{py_err, SchemaDict}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; use crate::recursion_guard::RecursionGuard; -use super::none::NoneValidator; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; -#[derive(Debug)] -pub struct LiteralBuilder; +#[derive(Debug, Clone)] +pub struct LiteralValidator { + // Specialized lookups for ints and strings because they + // (1) are easy to convert between Rust and Python + // (2) hashing them in Rust is very fast + // (3) are the most commonly used things in Literal[...] + expected_int: Option>, + expected_str: Option>, + // Catch all for Enum and bytes (the latter only because it is seldom used) + expected_py: Option>, + expected_repr: String, + name: String, +} -impl BuildValidator for LiteralBuilder { +impl BuildValidator for LiteralValidator { const EXPECTED_TYPE: &'static str = "literal"; fn build( schema: &PyDict, - config: Option<&PyDict>, - build_context: &mut BuildContext, + _config: Option<&PyDict>, + _build_context: &mut BuildContext, ) -> PyResult { let expected: &PyList = schema.get_as_req(intern!(schema.py(), "expected"))?; if expected.is_empty() { return py_err!(r#""expected" should have length > 0"#); - } else if expected.len() == 1 { - let first = expected.get_item(0)?; - if let Ok(py_str) = first.downcast::() { - return Ok(LiteralSingleStringValidator::new(py_str.to_str()?.to_string()).into()); - } else if let Ok(int) = first.extract::() { - return Ok(LiteralSingleIntValidator::new(int).into()); - } else if first.is_none() { - return NoneValidator::build(schema, config, build_context); - } } - - if let Some(v) = LiteralMultipleStringsValidator::new(expected) { - Ok(v.into()) - } else if let Some(v) = LiteralMultipleIntsValidator::new(expected) { - Ok(v.into()) - } else { - Ok(LiteralGeneralValidator::new(expected)?.into()) - } - } -} - -#[derive(Debug, Clone)] -pub struct LiteralSingleStringValidator { - expected: String, - expected_repr: String, - name: String, -} - -impl LiteralSingleStringValidator { - fn new(expected: String) -> Self { - let expected_repr = format!("'{expected}'"); - let name = format!("literal[{expected_repr}]"); - Self { - expected, - expected_repr, - name, - } - } -} - -impl Validator for LiteralSingleStringValidator { - fn validate<'s, 'data>( - &'s self, - py: Python<'data>, - input: &'data impl Input<'data>, - extra: &Extra, - _slots: &'data [CombinedValidator], - _recursion_guard: &'s mut RecursionGuard, - ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(extra.strict.unwrap_or(false))?; - if either_str.as_cow()?.as_ref() == self.expected.as_str() { - Ok(input.to_object(py)) - } else { - Err(ValError::new( - ErrorType::LiteralError { - expected: self.expected_repr.clone(), - }, - input, - )) - } - } - - fn different_strict_behavior( - &self, - _build_context: Option<&BuildContext>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - - fn get_name(&self) -> &str { - &self.name - } - - fn complete(&mut self, _build_context: &BuildContext) -> PyResult<()> { - Ok(()) - } -} - -#[derive(Debug, Clone)] -pub struct LiteralSingleIntValidator { - expected: i64, - name: String, -} - -impl LiteralSingleIntValidator { - fn new(expected: i64) -> Self { - Self { - expected, - name: format!("literal[{expected}]"), - } - } -} - -impl Validator for LiteralSingleIntValidator { - fn validate<'s, 'data>( - &'s self, - py: Python<'data>, - input: &'data impl Input<'data>, - extra: &Extra, - _slots: &'data [CombinedValidator], - _recursion_guard: &'s mut RecursionGuard, - ) -> ValResult<'data, PyObject> { - let int = input.validate_int(extra.strict.unwrap_or(false))?; - if int == self.expected { - Ok(input.to_object(py)) - } else { - Err(ValError::new( - ErrorType::LiteralError { - expected: self.expected.to_string(), - }, - input, - )) - } - } - - fn different_strict_behavior( - &self, - _build_context: Option<&BuildContext>, - _ultra_strict: bool, - ) -> bool { - true - } - - fn get_name(&self) -> &str { - &self.name - } - - fn complete(&mut self, _build_context: &BuildContext) -> PyResult<()> { - Ok(()) - } -} - -#[derive(Debug, Clone)] -pub struct LiteralMultipleStringsValidator { - expected: AHashSet, - expected_repr: String, - name: String, -} - -impl LiteralMultipleStringsValidator { - fn new(expected_list: &PyList) -> Option { - let mut expected: AHashSet = AHashSet::new(); - let mut repr_args = Vec::new(); - for item in expected_list.iter() { - if let Ok(str) = item.extract() { - repr_args.push(format!("'{str}'")); - expected.insert(str); - } else { - return None; - } - } - let (expected_repr, name) = expected_repr_name(repr_args, "literal"); - Some(Self { - expected, - expected_repr, - name, - }) - } -} - -impl Validator for LiteralMultipleStringsValidator { - fn validate<'s, 'data>( - &'s self, - py: Python<'data>, - input: &'data impl Input<'data>, - extra: &Extra, - _slots: &'data [CombinedValidator], - _recursion_guard: &'s mut RecursionGuard, - ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(extra.strict.unwrap_or(false))?; - if self.expected.contains(either_str.as_cow()?.as_ref()) { - Ok(input.to_object(py)) - } else { - Err(ValError::new( - ErrorType::LiteralError { - expected: self.expected_repr.clone(), - }, - input, - )) - } - } - - fn different_strict_behavior( - &self, - _build_context: Option<&BuildContext>, - ultra_strict: bool, - ) -> bool { - !ultra_strict - } - - fn get_name(&self) -> &str { - &self.name - } - - fn complete(&mut self, _build_context: &BuildContext) -> PyResult<()> { - Ok(()) - } -} - -#[derive(Debug, Clone)] -pub struct LiteralMultipleIntsValidator { - expected: AHashSet, - expected_repr: String, - name: String, -} - -impl LiteralMultipleIntsValidator { - fn new(expected_list: &PyList) -> Option { - let mut expected: AHashSet = AHashSet::with_capacity(expected_list.len()); - let mut repr_args = Vec::new(); - for item in expected_list.iter() { - if let Ok(int) = item.extract() { - expected.insert(int); - repr_args.push(int.to_string()); - } else { - return None; - } - } - let (expected_repr, name) = expected_repr_name(repr_args, "literal"); - Some(Self { - expected, - expected_repr, - name, - }) - } -} - -impl Validator for LiteralMultipleIntsValidator { - fn validate<'s, 'data>( - &'s self, - py: Python<'data>, - input: &'data impl Input<'data>, - extra: &Extra, - _slots: &'data [CombinedValidator], - _recursion_guard: &'s mut RecursionGuard, - ) -> ValResult<'data, PyObject> { - let int = input.validate_int(extra.strict.unwrap_or(false))?; - if self.expected.contains(&int) { - Ok(input.to_object(py)) - } else { - Err(ValError::new( - ErrorType::LiteralError { - expected: self.expected_repr.clone(), - }, - input, - )) - } - } - - fn different_strict_behavior( - &self, - _build_context: Option<&BuildContext>, - _ultra_strict: bool, - ) -> bool { - true - } - - fn get_name(&self) -> &str { - &self.name - } - - fn complete(&mut self, _build_context: &BuildContext) -> PyResult<()> { - Ok(()) - } -} - -#[derive(Debug, Clone)] -pub struct LiteralGeneralValidator { - expected_int: AHashSet, - expected_str: AHashSet, - expected_py: Py, - expected_repr: String, - name: String, -} - -impl LiteralGeneralValidator { - fn new(expected: &PyList) -> PyResult { + let py = expected.py(); + // Literal[...] only supports int, str, bytes or enums, all of which can be hashed let mut expected_int = AHashSet::new(); let mut expected_str = AHashSet::new(); - let py = expected.py(); - let expected_py = PyList::empty(py); + let expected_py = PyDict::new(py); let mut repr_args: Vec = Vec::new(); for item in expected.iter() { repr_args.push(item.repr()?.extract()?); - if let Ok(int) = item.extract::() { + if let Some(int) = item.as_int_strict() { expected_int.insert(int); - } else if let Ok(py_str) = item.downcast::() { - expected_str.insert(py_str.to_str()?.to_string()); + } else if let Some(str) = item.as_str_strict() { + expected_str.insert(str.to_string()); } else { - expected_py.append(item)?; + expected_py.set_item(item, item)?; } } let (expected_repr, name) = expected_repr_name(repr_args, "literal"); - Ok(Self { - expected_int, - expected_str, - expected_py: expected_py.into_py(py), + Ok(CombinedValidator::Literal(Self { + expected_int: (!expected_int.is_empty()).then_some(expected_int), + expected_str: (!expected_str.is_empty()).then_some(expected_str), + expected_py: (!expected_py.is_empty()).then_some(expected_py.into()), expected_repr, name, - }) + })) } } -impl Validator for LiteralGeneralValidator { +impl Validator for LiteralValidator { fn validate<'s, 'data>( &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, + _extra: &Extra, _slots: &'data [CombinedValidator], _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let strict = extra.strict.unwrap_or(false); - if !self.expected_int.is_empty() { - if let Ok(int) = input.validate_int(strict) { - if self.expected_int.contains(&int) { + if let Some(expected_ints) = &self.expected_int { + if let Some(int) = input.as_int_strict() { + if expected_ints.contains(&int) { return Ok(input.to_object(py)); } } } - if !self.expected_str.is_empty() { - if let Ok(either_str) = input.validate_str(strict) { - if self.expected_str.contains(either_str.as_cow()?.as_ref()) { + if let Some(expected_strings) = &self.expected_str { + if let Some(str) = input.as_str_strict() { + if expected_strings.contains(str) { return Ok(input.to_object(py)); } } } - - let py_value = input.to_object(py); - - let expected_py = self.expected_py.as_ref(py); - if !expected_py.is_empty() && expected_py.contains(&py_value)? { - return Ok(py_value); - } - + // must be an enum or bytes + if let Some(expected_py) = &self.expected_py { + if let Some(v) = expected_py.as_ref(py).get_item(input) { + return Ok(v.into()); + } + }; Err(ValError::new( ErrorType::LiteralError { expected: self.expected_repr.clone(), diff --git a/src/validators/mod.rs b/src/validators/mod.rs index f42ed8510..2764a633e 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -410,7 +410,7 @@ pub fn build_validator<'a>( // function call - validation around a function call call::CallValidator, // literals - literal::LiteralBuilder, + literal::LiteralValidator, // any any::AnyValidator, // bytes @@ -540,11 +540,7 @@ pub enum CombinedValidator { // function call - validation around a function call FunctionCall(call::CallValidator), // literals - LiteralSingleString(literal::LiteralSingleStringValidator), - LiteralSingleInt(literal::LiteralSingleIntValidator), - LiteralMultipleStrings(literal::LiteralMultipleStringsValidator), - LiteralMultipleInts(literal::LiteralMultipleIntsValidator), - LiteralGeneral(literal::LiteralGeneralValidator), + Literal(literal::LiteralValidator), // any Any(any::AnyValidator), // bytes diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 58c6d259f..1becd31d3 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -1333,10 +1333,20 @@ class SomeStrEnum(str, Enum): 'few_mixed', ], ) -def test_validate_literal(benchmark: Any, allowed_values: List[Any], input: Any, expected_val_res: Any) -> None: +@pytest.mark.parametrize('py_or_json', ['python', 'json']) +def test_validate_literal( + benchmark: Any, allowed_values: List[Any], input: Any, expected_val_res: Any, py_or_json: str +) -> None: validator = SchemaValidator(core_schema.literal_schema(expected=allowed_values)) - res = validator.validate_python(input) - assert res == expected_val_res + if py_or_json == 'python': + res = validator.validate_python(input) + assert res == expected_val_res - benchmark(validator.validate_python, input) + benchmark(validator.validate_python, input) + else: + input_json = json.dumps(input) + res = validator.validate_json(input_json) + assert res == expected_val_res + + benchmark(validator.validate_json, input_json) diff --git a/tests/validators/test_literal.py b/tests/validators/test_literal.py index 4b8ad0f83..175b10a5a 100644 --- a/tests/validators/test_literal.py +++ b/tests/validators/test_literal.py @@ -1,5 +1,6 @@ import re from enum import Enum +from typing import Any, Callable, List import pytest @@ -168,7 +169,8 @@ def test_literal_none(): assert v.isinstance_python(0) is False assert v.isinstance_json('null') is True assert v.isinstance_json('""') is False - assert plain_repr(v) == 'SchemaValidator(title="none",validator=None(NoneValidator),slots=[])' + expected_repr_start = 'SchemaValidator(title="literal[None]"' + assert plain_repr(v)[: len(expected_repr_start)] == expected_repr_start def test_union(): @@ -195,9 +197,10 @@ def test_union(): ] -def test_enum(): +def test_enum_value(): class FooEnum(Enum): foo = 'foo_value' + bar = 'bar_value' v = SchemaValidator(core_schema.literal_schema([FooEnum.foo])) assert v.validate_python(FooEnum.foo) == FooEnum.foo @@ -213,3 +216,167 @@ class FooEnum(Enum): 'ctx': {'expected': ""}, } ] + with pytest.raises(ValidationError) as exc_info: + v.validate_python('unknown') + # insert_assert(exc_info.value.errors()) + assert exc_info.value.errors() == [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': "Input should be ", + 'input': 'unknown', + 'ctx': {'expected': ""}, + } + ] + + with pytest.raises(ValidationError) as exc_info: + v.validate_json('"foo_value"') + assert exc_info.value.errors() == [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': "Input should be ", + 'input': 'foo_value', + 'ctx': {'expected': ""}, + } + ] + + +def test_str_enum_values(): + class Foo(str, Enum): + foo = 'foo_value' + bar = 'bar_value' + + v = SchemaValidator(core_schema.literal_schema([Foo.foo])) + + assert v.validate_python(Foo.foo) == Foo.foo + assert v.validate_python('foo_value') == Foo.foo + assert v.validate_json('"foo_value"') == Foo.foo + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('unknown') + assert exc_info.value.errors() == [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': "Input should be ", + 'input': 'unknown', + 'ctx': {'expected': ""}, + } + ] + + +def test_int_enum_values(): + class Foo(int, Enum): + foo = 2 + bar = 3 + + v = SchemaValidator(core_schema.literal_schema([Foo.foo])) + + assert v.validate_python(Foo.foo) == Foo.foo + assert v.validate_python(2) == Foo.foo + assert v.validate_json('2') == Foo.foo + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(4) + assert exc_info.value.errors() == [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': 'Input should be ', + 'input': 4, + 'ctx': {'expected': ''}, + } + ] + + +@pytest.mark.parametrize( + 'reverse, err', + [ + ( + lambda x: list(reversed(x)), + [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': 'Input should be or 1', + 'input': 2, + 'ctx': {'expected': ' or 1'}, + } + ], + ), + ( + lambda x: x, + [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': 'Input should be 1 or ', + 'input': 2, + 'ctx': {'expected': '1 or '}, + } + ], + ), + ], +) +def test_mix_int_enum_with_int(reverse: Callable[[List[Any]], List[Any]], err: Any): + class Foo(int, Enum): + foo = 1 + + v = SchemaValidator(core_schema.literal_schema(reverse([1, Foo.foo]))) + + assert v.validate_python(Foo.foo) is Foo.foo + val = v.validate_python(1) + assert val == 1 and val is not Foo.foo + val = v.validate_json('1') + assert val == 1 and val is not Foo.foo + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(2) + assert exc_info.value.errors() == err + + +@pytest.mark.parametrize( + 'reverse, err', + [ + ( + lambda x: list(reversed(x)), + [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': "Input should be or 'foo_val'", + 'input': 'bar_val', + 'ctx': {'expected': " or 'foo_val'"}, + } + ], + ), + ( + lambda x: x, + [ + { + 'type': 'literal_error', + 'loc': (), + 'msg': "Input should be 'foo_val' or ", + 'input': 'bar_val', + 'ctx': {'expected': "'foo_val' or "}, + } + ], + ), + ], +) +def test_mix_str_enum_with_str(reverse: Callable[[List[Any]], List[Any]], err: Any): + class Foo(str, Enum): + foo = 'foo_val' + + v = SchemaValidator(core_schema.literal_schema(reverse(['foo_val', Foo.foo]))) + + assert v.validate_python(Foo.foo) is Foo.foo + val = v.validate_python('foo_val') + assert val == 'foo_val' and val is not Foo.foo + val = v.validate_json('"foo_val"') + assert val == 'foo_val' and val is not Foo.foo + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('bar_val') + assert exc_info.value.errors() == err