From c31fe514e753b109a3a660b8f27062352faa280b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 25 Apr 2023 12:49:08 -0600 Subject: [PATCH] Support RichComparison, hash and deepcopy for Url and MultiHostUrl (#558) --- pydantic_core/_pydantic_core.pyi | 6 ++- src/url.rs | 55 +++++++++++++++++++++++ tests/validators/test_url.py | 75 +++++++++++++++++++++++++++++++- 3 files changed, 133 insertions(+), 3 deletions(-) diff --git a/pydantic_core/_pydantic_core.pyi b/pydantic_core/_pydantic_core.pyi index fc2909d1a..47abfc7ba 100644 --- a/pydantic_core/_pydantic_core.pyi +++ b/pydantic_core/_pydantic_core.pyi @@ -15,6 +15,8 @@ if sys.version_info < (3, 11): else: from typing import Literal, NotRequired, TypeAlias +from _typeshed import SupportsAllComparisons + __all__ = ( '__version__', 'build_profile', @@ -126,7 +128,7 @@ def to_jsonable_python( fallback: 'Callable[[Any], Any] | None' = None, ) -> Any: ... -class Url: +class Url(SupportsAllComparisons): @property def scheme(self) -> str: ... @property @@ -156,7 +158,7 @@ class MultiHostHost(TypedDict): query: 'str | None' fragment: 'str | None' -class MultiHostUrl: +class MultiHostUrl(SupportsAllComparisons): @property def scheme(self) -> str: ... 
@property diff --git a/src/url.rs b/src/url.rs index bbf4dcf58..68a77ce99 100644 --- a/src/url.rs +++ b/src/url.rs @@ -1,6 +1,10 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + use idna::punycode::decode_to_string; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; +use pyo3::pyclass::CompareOp; use pyo3::types::PyDict; use url::Url; @@ -116,6 +120,31 @@ impl PyUrl { pub fn __repr__(&self) -> String { format!("Url('{}')", self.lib_url) } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> { + match op { + CompareOp::Lt => Ok(self.lib_url < other.lib_url), + CompareOp::Le => Ok(self.lib_url <= other.lib_url), + CompareOp::Eq => Ok(self.lib_url == other.lib_url), + CompareOp::Ne => Ok(self.lib_url != other.lib_url), + CompareOp::Gt => Ok(self.lib_url > other.lib_url), + CompareOp::Ge => Ok(self.lib_url >= other.lib_url), + } + } + + fn __hash__(&self) -> u64 { + let mut s = DefaultHasher::new(); + self.lib_url.to_string().hash(&mut s); + s.finish() + } + + fn __bool__(&self) -> bool { + true // an empty string is not a valid URL + } + + pub fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> Py<PyAny> { + self.clone().into_py(py) + } } #[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core")] @@ -250,6 +279,32 @@ impl PyMultiHostUrl { pub fn __repr__(&self) -> String { format!("Url('{}')", self.__str__()) } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> { + match op { + CompareOp::Lt => Ok(self.unicode_string() < other.unicode_string()), + CompareOp::Le => Ok(self.unicode_string() <= other.unicode_string()), + CompareOp::Eq => Ok(self.unicode_string() == other.unicode_string()), + CompareOp::Ne => Ok(self.unicode_string() != other.unicode_string()), + CompareOp::Gt => Ok(self.unicode_string() > other.unicode_string()), + CompareOp::Ge => Ok(self.unicode_string() >= other.unicode_string()), + } + } + + fn __hash__(&self) -> u64 { + let mut s = DefaultHasher::new(); +
self.ref_url.clone().into_url().to_string().hash(&mut s); + self.extra_urls.hash(&mut s); + s.finish() + } + + fn __bool__(&self) -> bool { + true // an empty string is not a valid URL + } + + pub fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> Py<PyAny> { + self.clone().into_py(py) + } } fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult<&'a PyDict> { diff --git a/tests/validators/test_url.py b/tests/validators/test_url.py index 29a749301..44bc93b0d 100644 --- a/tests/validators/test_url.py +++ b/tests/validators/test_url.py @@ -1,5 +1,6 @@ import re -from typing import Optional, Union +from copy import deepcopy +from typing import Dict, Optional, Union import pytest from dirty_equals import HasRepr, IsInstance @@ -1140,3 +1141,75 @@ def test_url_vulnerabilities(url_validator, url, expected): else: output_parts[key] = getattr(output_url, key) assert output_parts == expected + + +def test_multi_host_url_comparison() -> None: + assert MultiHostUrl('http://example.com,www.example.com') == MultiHostUrl('http://example.com,www.example.com') + assert MultiHostUrl('http://example.com,www.example.com') == MultiHostUrl('http://example.com,www.example.com/') + assert MultiHostUrl('http://example.com,www.example.com') != MultiHostUrl('http://example.com,www.example.com/123') + assert MultiHostUrl('http://example.com,www.example.com/123') > MultiHostUrl('http://example.com,www.example.com') + assert MultiHostUrl('http://example.com,www.example.com/123') >= MultiHostUrl('http://example.com,www.example.com') + assert MultiHostUrl('http://example.com,www.example.com') >= MultiHostUrl('http://example.com,www.example.com') + assert MultiHostUrl('http://example.com,www.example.com') < MultiHostUrl('http://example.com,www.example.com/123') + assert MultiHostUrl('http://example.com,www.example.com') <= MultiHostUrl('http://example.com,www.example.com/123') + assert MultiHostUrl('http://example.com,www.example.com') <= MultiHostUrl('http://example.com') + + +def
test_multi_host_url_bool() -> None: + assert bool(MultiHostUrl('http://example.com,www.example.com')) is True + + +def test_multi_host_url_hash() -> None: + data: Dict[MultiHostUrl, int] = {} + + data[MultiHostUrl('http://example.com,www.example.com')] = 1 + assert data == {MultiHostUrl('http://example.com,www.example.com/'): 1} + + data[MultiHostUrl('http://example.com,www.example.com/123')] = 2 + assert data == { + MultiHostUrl('http://example.com,www.example.com/'): 1, + MultiHostUrl('http://example.com,www.example.com/123'): 2, + } + + data[MultiHostUrl('http://example.com,www.example.com')] = 3 + assert data == { + MultiHostUrl('http://example.com,www.example.com/'): 3, + MultiHostUrl('http://example.com,www.example.com/123'): 2, + } + + +def test_multi_host_url_deepcopy() -> None: + assert deepcopy(MultiHostUrl('http://example.com')) == MultiHostUrl('http://example.com/') + + +def test_url_comparison() -> None: + assert Url('http://example.com') == Url('http://example.com') + assert Url('http://example.com') == Url('http://example.com/') + assert Url('http://example.com') != Url('http://example.com/123') + assert Url('http://example.com/123') > Url('http://example.com') + assert Url('http://example.com/123') >= Url('http://example.com') + assert Url('http://example.com') >= Url('http://example.com') + assert Url('http://example.com') < Url('http://example.com/123') + assert Url('http://example.com') <= Url('http://example.com/123') + assert Url('http://example.com') <= Url('http://example.com') + + +def test_url_bool() -> None: + assert bool(Url('http://example.com')) is True + + +def test_url_hash() -> None: + data: Dict[Url, int] = {} + + data[Url('http://example.com')] = 1 + assert data == {Url('http://example.com/'): 1} + + data[Url('http://example.com/123')] = 2 + assert data == {Url('http://example.com/'): 1, Url('http://example.com/123'): 2} + + data[Url('http://example.com')] = 3 + assert data == {Url('http://example.com/'): 3, 
Url('http://example.com/123'): 2} + + +def test_url_deepcopy() -> None: + assert deepcopy(Url('http://example.com')) == Url('http://example.com/')