Skip to content

Commit

Permalink
Support RichComparison, hash and deepcopy for Url and MultiHostUrl (#558)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Apr 25, 2023
1 parent ae4cb28 commit c31fe51
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 3 deletions.
6 changes: 4 additions & 2 deletions pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ if sys.version_info < (3, 11):
else:
from typing import Literal, NotRequired, TypeAlias

from _typeshed import SupportsAllComparisons

__all__ = (
'__version__',
'build_profile',
Expand Down Expand Up @@ -126,7 +128,7 @@ def to_jsonable_python(
fallback: 'Callable[[Any], Any] | None' = None,
) -> Any: ...

class Url:
class Url(SupportsAllComparisons):
@property
def scheme(self) -> str: ...
@property
Expand Down Expand Up @@ -156,7 +158,7 @@ class MultiHostHost(TypedDict):
query: 'str | None'
fragment: 'str | None'

class MultiHostUrl:
class MultiHostUrl(SupportsAllComparisons):
@property
def scheme(self) -> str: ...
@property
Expand Down
55 changes: 55 additions & 0 deletions src/url.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use idna::punycode::decode_to_string;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::PyDict;
use url::Url;

Expand Down Expand Up @@ -116,6 +120,31 @@ impl PyUrl {
/// Builds the Python `repr()` text: the wrapped URL rendered inside `Url('...')`.
pub fn __repr__(&self) -> String {
    let url = &self.lib_url;
    format!("Url('{}')", url)
}

/// Implements Python's rich comparison operators (`<`, `<=`, `==`, `!=`, `>`, `>=`)
/// by comparing the wrapped `lib_url` of both operands.
///
/// Every operator is handled, so the result is always `Ok`.
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
    let (lhs, rhs) = (&self.lib_url, &other.lib_url);
    let outcome = match op {
        CompareOp::Eq => lhs == rhs,
        CompareOp::Ne => lhs != rhs,
        CompareOp::Lt => lhs < rhs,
        CompareOp::Le => lhs <= rhs,
        CompareOp::Gt => lhs > rhs,
        CompareOp::Ge => lhs >= rhs,
    };
    Ok(outcome)
}

/// Computes the value returned by Python's `hash()` for this URL.
///
/// The hash is derived from the URL's serialized string form using a
/// freshly created `DefaultHasher`.
fn __hash__(&self) -> u64 {
    let mut hasher = DefaultHasher::new();
    // Hash the borrowed serialization directly: `String` and `str` feed the
    // hasher identically, so this produces the same value as hashing
    // `to_string()` while skipping the temporary allocation.
    self.lib_url.as_str().hash(&mut hasher);
    hasher.finish()
}

/// Truthiness hook for Python's `bool()`: unconditionally returns `true`.
fn __bool__(&self) -> bool {
true // an empty string is not a valid URL
}

/// Supports Python's `copy.deepcopy`: clones this value and converts the
/// clone into a new Python object.
///
/// The `_memo` dictionary is accepted to satisfy the protocol but is not used.
pub fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> Py<PyAny> {
    let duplicate = self.clone();
    duplicate.into_py(py)
}
}

#[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core")]
Expand Down Expand Up @@ -250,6 +279,32 @@ impl PyMultiHostUrl {
/// Builds the Python `repr()` text from this object's `__str__()` output,
/// wrapped as `Url('...')`.
///
/// NOTE(review): the prefix is `Url` even though the pyclass is registered as
/// `MultiHostUrl` — confirm whether that is intentional before changing it,
/// since callers may depend on the current text.
pub fn __repr__(&self) -> String {
    let rendered = self.__str__();
    format!("Url('{}')", rendered)
}

/// Implements Python's rich comparison operators (`<`, `<=`, `==`, `!=`, `>`, `>=`)
/// by comparing the `unicode_string()` form of both operands.
///
/// Every operator is handled, so the result is always `Ok`.
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
    let lhs = self.unicode_string();
    let rhs = other.unicode_string();
    let outcome = match op {
        CompareOp::Eq => lhs == rhs,
        CompareOp::Ne => lhs != rhs,
        CompareOp::Lt => lhs < rhs,
        CompareOp::Le => lhs <= rhs,
        CompareOp::Gt => lhs > rhs,
        CompareOp::Ge => lhs >= rhs,
    };
    Ok(outcome)
}

/// Computes the value returned by Python's `hash()` for this multi-host URL.
///
/// Feeds the string form of `ref_url` and then `extra_urls` into a freshly
/// created `DefaultHasher`.
///
/// NOTE(review): equality in `__richcmp__` compares `unicode_string()`, while
/// this hashes `ref_url` and `extra_urls`; presumably equal objects always
/// hash equally — confirm.
fn __hash__(&self) -> u64 {
    let mut hasher = DefaultHasher::new();
    let primary = self.ref_url.clone().into_url().to_string();
    primary.hash(&mut hasher);
    self.extra_urls.hash(&mut hasher);
    hasher.finish()
}

/// Truthiness hook for Python's `bool()`: unconditionally returns `true`.
fn __bool__(&self) -> bool {
true // an empty string is not a valid URL
}

/// Supports Python's `copy.deepcopy`: clones this value and converts the
/// clone into a new Python object.
///
/// The `_memo` dictionary is accepted to satisfy the protocol but is not used.
pub fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> Py<PyAny> {
    let duplicate = self.clone();
    duplicate.into_py(py)
}
}

fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult<&'a PyDict> {
Expand Down
75 changes: 74 additions & 1 deletion tests/validators/test_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from typing import Optional, Union
from copy import deepcopy
from typing import Dict, Optional, Union

import pytest
from dirty_equals import HasRepr, IsInstance
Expand Down Expand Up @@ -1140,3 +1141,75 @@ def test_url_vulnerabilities(url_validator, url, expected):
else:
output_parts[key] = getattr(output_url, key)
assert output_parts == expected


def test_multi_host_url_comparison() -> None:
    """Exercise every rich comparison operator on MultiHostUrl."""
    # Equality and inequality; a missing trailing slash does not affect equality.
    assert MultiHostUrl('http://example.com,www.example.com') == MultiHostUrl('http://example.com,www.example.com')
    assert MultiHostUrl('http://example.com,www.example.com') == MultiHostUrl('http://example.com,www.example.com/')
    assert MultiHostUrl('http://example.com,www.example.com') != MultiHostUrl('http://example.com,www.example.com/123')
    # Greater-than family, including `>=` on equal values.
    assert MultiHostUrl('http://example.com,www.example.com/123') > MultiHostUrl('http://example.com,www.example.com')
    assert MultiHostUrl('http://example.com,www.example.com/123') >= MultiHostUrl('http://example.com,www.example.com')
    assert MultiHostUrl('http://example.com,www.example.com') >= MultiHostUrl('http://example.com,www.example.com')
    # Less-than family.
    assert MultiHostUrl('http://example.com,www.example.com') < MultiHostUrl('http://example.com,www.example.com/123')
    assert MultiHostUrl('http://example.com,www.example.com') <= MultiHostUrl('http://example.com,www.example.com/123')
    assert MultiHostUrl('http://example.com,www.example.com') <= MultiHostUrl('http://example.com')
    # `<=` on equal values, mirroring the `>=` equal-operand check above.
    assert MultiHostUrl('http://example.com,www.example.com') <= MultiHostUrl('http://example.com,www.example.com')


def test_multi_host_url_bool() -> None:
    """A MultiHostUrl instance converts to the boolean True."""
    url = MultiHostUrl('http://example.com,www.example.com')
    assert bool(url) is True


def test_multi_host_url_hash() -> None:
    """MultiHostUrl works as a dict key, and equal URLs share one entry."""
    bare = 'http://example.com,www.example.com'
    slashed = 'http://example.com,www.example.com/'
    with_path = 'http://example.com,www.example.com/123'

    data: Dict[MultiHostUrl, int] = {}

    # First insertion: the key is found again via its trailing-slash spelling.
    data[MultiHostUrl(bare)] = 1
    assert data == {MultiHostUrl(slashed): 1}

    # A URL with a different path gets its own entry.
    data[MultiHostUrl(with_path)] = 2
    assert data == {
        MultiHostUrl(slashed): 1,
        MultiHostUrl(with_path): 2,
    }

    # Re-assigning through an equal key overwrites the existing value.
    data[MultiHostUrl(bare)] = 3
    assert data == {
        MultiHostUrl(slashed): 3,
        MultiHostUrl(with_path): 2,
    }


def test_multi_host_url_deepcopy() -> None:
    """A deep copy of a MultiHostUrl compares equal to the same URL."""
    original = MultiHostUrl('http://example.com')
    duplicate = deepcopy(original)
    assert duplicate == MultiHostUrl('http://example.com/')


def test_url_comparison() -> None:
    """Exercise every rich comparison operator on Url."""
    base = 'http://example.com'
    base_slash = 'http://example.com/'
    longer = 'http://example.com/123'

    # Equality and inequality; a missing trailing slash does not affect equality.
    assert Url(base) == Url(base)
    assert Url(base) == Url(base_slash)
    assert Url(base) != Url(longer)

    # Greater-than family, including `>=` on equal values.
    assert Url(longer) > Url(base)
    assert Url(longer) >= Url(base)
    assert Url(base) >= Url(base)

    # Less-than family, including `<=` on equal values.
    assert Url(base) < Url(longer)
    assert Url(base) <= Url(longer)
    assert Url(base) <= Url(base)


def test_url_bool() -> None:
    """A Url instance converts to the boolean True."""
    url = Url('http://example.com')
    assert bool(url) is True


def test_url_hash() -> None:
    """Url works as a dict key, and equal URLs share one entry."""
    bare = 'http://example.com'
    slashed = 'http://example.com/'
    with_path = 'http://example.com/123'

    data: Dict[Url, int] = {}

    # First insertion: the key is found again via its trailing-slash spelling.
    data[Url(bare)] = 1
    assert data == {Url(slashed): 1}

    # A URL with a different path gets its own entry.
    data[Url(with_path)] = 2
    assert data == {Url(slashed): 1, Url(with_path): 2}

    # Re-assigning through an equal key overwrites the existing value.
    data[Url(bare)] = 3
    assert data == {Url(slashed): 3, Url(with_path): 2}


def test_url_deepcopy() -> None:
    """A deep copy of a Url compares equal to the same URL."""
    original = Url('http://example.com')
    duplicate = deepcopy(original)
    assert duplicate == Url('http://example.com/')

0 comments on commit c31fe51

Please sign in to comment.