diff --git a/pyproject.toml b/pyproject.toml index d91ada6..b6fd29a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,9 @@ messages_control.disable = [ "line-too-long", "missing-module-docstring", "missing-function-docstring", + "no-member", # handled by mypy "no-value-for-parameter", # for plum-dispatch + "too-many-function-args", # handled by testing "unused-argument", # handled by ruff "unused-wildcard-import", # handled by ruff "wildcard-import", # handled by ruff diff --git a/src/dataclassish/__init__.py b/src/dataclassish/__init__.py index f205ee9..b36ef59 100644 --- a/src/dataclassish/__init__.py +++ b/src/dataclassish/__init__.py @@ -19,6 +19,7 @@ "field_items", # Classes "DataclassInstance", + "CanCopyReplace", "F", ] @@ -33,11 +34,16 @@ get_field, replace, ) -from ._src.types import DataclassInstance, F +from ._src.types import CanCopyReplace, DataclassInstance, F from ._version import version as __version__ # Register dispatches by importing the submodules # isort: split -from ._src import register_base, register_dataclass, register_mapping +from ._src import ( + register_base, + register_copyreplace, + register_dataclass, + register_mapping, +) -del register_base, register_dataclass, register_mapping +del register_base, register_dataclass, register_mapping, register_copyreplace diff --git a/src/dataclassish/_src/register_base.py b/src/dataclassish/_src/register_base.py index 78b0cd7..7e5c6b4 100644 --- a/src/dataclassish/_src/register_base.py +++ b/src/dataclassish/_src/register_base.py @@ -2,11 +2,13 @@ __all__: list[str] = [] +from collections.abc import Mapping from typing import Any, TypeVar from plum import dispatch -from .api import fields +from .api import fields, replace +from .types import F K = TypeVar("K") V = TypeVar("V") @@ -37,6 +39,20 @@ def get_field(obj: Any, k: str, /) -> Any: return getattr(obj, k) +# =================================================================== +# Replace + + +def _recursive_replace_helper(obj: object, k: str, v: Any, /) -> Any: + if isinstance(v, F): + out = v.value + elif isinstance(v, Mapping): + out = replace(get_field(obj, k), v) + else: + out = v + return out + + # =================================================================== # Field keys diff --git a/src/dataclassish/_src/register_copyreplace.py b/src/dataclassish/_src/register_copyreplace.py new file mode 100644 index 0000000..ce05545 --- /dev/null +++ b/src/dataclassish/_src/register_copyreplace.py @@ -0,0 +1,179 @@ +"""Register dispatches for CanCopyReplace objects.""" + +__all__: list[str] = [] + +import copy +import sys +from collections.abc import Mapping +from typing import Any + +from plum import dispatch + +from .register_base import _recursive_replace_helper +from .types import CanCopyReplace + +# =================================================================== +# Get field + + +@dispatch # type: ignore[misc] +def get_field(obj: CanCopyReplace, k: str, /) -> Any: + """Get a field of a dataclass instance by name. + + Examples + -------- + >>> from dataclasses import dataclass + >>> from dataclassish import get_field + + % invisible-code-block: python + % + % import sys + + % skip: start if(sys.version_info < (3, 13), reason="py3.13+") + + >>> @dataclass + ... class Point: + ... x: float + ... y: float + + >>> p = Point(1.0, 2.0) + >>> get_field(p, "x") + 1.0 + + % skip: end + + This works for any object that implements the ``__replace__`` method. + + >>> class Point: + ... def __init__(self, x, y): + ... self.x = x + ... self.y = y + ... def __replace__(self, **changes): + ... return Point(**(self.__dict__ | changes)) + ... def __repr__(self): + ... return f"Point(x={self.x}, y={self.y})" + + >>> p = Point(1.0, 2.0) + >>> get_field(p, "x") + 1.0 + + """ + return getattr(obj, k) + + +# =================================================================== +# Replace + + +@dispatch +def replace(obj: CanCopyReplace, /, **kwargs: Any) -> CanCopyReplace: + """Replace the fields of an object. + + Examples + -------- + >>> from dataclassish import replace + + % invisible-code-block: python + % + % import sys + + % skip: start if(sys.version_info < (3, 13), reason="py3.13+") + + As of Python 3.13, dataclasses implement the ``__replace__`` method. + + >>> from dataclasses import dataclass + + >>> @dataclass + ... class Point: + ... x: float + ... y: float + + >>> p = Point(1.0, 2.0) + >>> p + Point(x=1.0, y=2.0) + + >>> replace(p, x=3.0) + Point(x=3.0, y=2.0) + + % skip: end + + >>> class Point: + ... def __init__(self, x, y): + ... self.x = x + ... self.y = y + ... def __replace__(self, **changes): + ... return Point(**(self.__dict__ | changes)) + ... def __repr__(self): + ... return f"Point(x={self.x}, y={self.y})" + + >>> p = Point(1.0, 2.0) + >>> replace(p, x=2.0) + Point(x=2.0, y=2.0) + + The ``__replace__`` method was introduced in Python 3.13 to bring + ``dataclasses.replace``-like functionality to any implementing object. The + method is publicly exposed via the ``copy.replace`` function. + + % invisible-code-block: python + % + % import sys + + % skip: start if(sys.version_info < (3, 13), reason="py3.13+") + + >>> import copy + >>> copy.replace(p, x=3.0) + Point(x=3.0, y=2.0) + + % skip: end + + """ + return ( + obj.__replace__(**kwargs) + if sys.version_info < (3, 13) + else copy.replace(obj, **kwargs) + ) + + +@dispatch # type: ignore[no-redef] +def replace(obj: CanCopyReplace, fs: Mapping[str, Any], /) -> CanCopyReplace: + """Replace the fields of a dataclass instance. + + Examples + -------- + >>> from dataclasses import dataclass + >>> from dataclassish import replace, F + + >>> class Point: + ... def __init__(self, x, y): + ... self.x = x + ... self.y = y + ... def __replace__(self, **changes): + ... return Point(**(self.__dict__ | changes)) + ... def __repr__(self): + ... return f"Point(x={self.x}, y={self.y})" + + >>> @dataclass + ... class TwoPoint: + ... a: Point + ... b: Point + + >>> p = TwoPoint(Point(1.0, 2.0), Point(3.0, 4.0)) + >>> p + TwoPoint(a=Point(x=1.0, y=2.0), b=Point(x=3.0, y=4.0)) + + >>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}}) + TwoPoint(a=Point(x=5.0, y=2.0), b=Point(x=3.0, y=6.0)) + + >>> replace(p, {"a": {"x": F({"thing": 5.0})}}) + TwoPoint(a=Point(x={'thing': 5.0}, y=2.0), + b=Point(x=3.0, y=4.0)) + + This also works on mixed-type structures, e.g. a dictionary of objects. + + >>> p = {"a": Point(1.0, 2.0), "b": Point(3.0, 4.0)} + >>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}}) + {'a': Point(x=5.0, y=2.0), 'b': Point(x=3.0, y=6.0)} + + """ + kwargs = {k: _recursive_replace_helper(obj, k, v) for k, v in fs.items()} + return replace(obj, **kwargs) diff --git a/src/dataclassish/_src/register_dataclass.py b/src/dataclassish/_src/register_dataclass.py index 0d05d8a..a378fe4 100644 --- a/src/dataclassish/_src/register_dataclass.py +++ b/src/dataclassish/_src/register_dataclass.py @@ -2,6 +2,7 @@ __all__: list[str] = [] +import sys from collections.abc import Callable, Mapping from dataclasses import ( Field, @@ -14,124 +15,114 @@ from plum import dispatch -from .types import DataclassInstance, F +from .register_base import _recursive_replace_helper +from .types import DataclassInstance # =================================================================== +if sys.version_info < (3, 13): -@dispatch # type: ignore[misc] -def get_field(obj: DataclassInstance, k: str, /) -> Any: - """Get a field of a dataclass instance by name. + @dispatch # type: ignore[misc] + def get_field(obj: DataclassInstance, k: str, /) -> Any: + """Get a field of a dataclass instance by name. - Examples - -------- - >>> from dataclasses import dataclass - >>> from dataclassish import get_field + Examples + -------- + >>> from dataclasses import dataclass + >>> from dataclassish import get_field - >>> @dataclass - ... class Point: - ... x: float - ... y: float + >>> @dataclass + ... class Point: + ... x: float + ... y: float - >>> p = Point(1.0, 2.0) - >>> get_field(p, "x") - 1.0 + >>> p = Point(1.0, 2.0) + >>> get_field(p, "x") + 1.0 - """ - return getattr(obj, k) + """ + return getattr(obj, k) # =================================================================== # Replace +if sys.version_info < (3, 13): -@dispatch # type: ignore[misc] -def replace(obj: DataclassInstance, /, **kwargs: Any) -> DataclassInstance: - """Replace the fields of a dataclass instance. - - Examples - -------- - >>> from dataclasses import dataclass - >>> from dataclassish import replace - - >>> @dataclass - ... class Point: - ... x: float - ... y: float - - >>> p = Point(1.0, 2.0) - >>> p - Point(x=1.0, y=2.0) + @dispatch + def replace(obj: DataclassInstance, /, **kwargs: Any) -> DataclassInstance: + """Replace the fields of a dataclass instance. - >>> replace(p, x=3.0) - Point(x=3.0, y=2.0) + Examples + -------- + >>> from dataclasses import dataclass + >>> from dataclassish import replace - """ - return _dataclass_replace(obj, **kwargs) + >>> @dataclass + ... class Point: + ... x: float + ... y: float + >>> p = Point(1.0, 2.0) + >>> p + Point(x=1.0, y=2.0) -def _recursive_replace_dataclass_helper( - obj: DataclassInstance, k: str, v: Any, / -) -> Any: - if isinstance(v, F): - out = v.value - elif isinstance(v, Mapping): - out = replace(get_field(obj, k), v) - else: - out = v - return out + >>> replace(p, x=3.0) + Point(x=3.0, y=2.0) + """ + return _dataclass_replace(obj, **kwargs) -@dispatch # type: ignore[misc, no-redef] -def replace(obj: DataclassInstance, fs: Mapping[str, Any], /) -> DataclassInstance: - """Replace the fields of a dataclass instance. + @dispatch # type: ignore[no-redef] + def replace(obj: DataclassInstance, fs: Mapping[str, Any], /) -> DataclassInstance: + """Replace the fields of a dataclass instance. - Examples - -------- - >>> from dataclasses import dataclass - >>> from dataclassish import replace, F + Examples + -------- + >>> from dataclasses import dataclass + >>> from dataclassish import replace, F - >>> @dataclass - ... class Point: - ... x: float | dict - ... y: float + >>> @dataclass + ... class Point: + ... x: float | dict + ... y: float - >>> @dataclass - ... class PointofPoints: - ... a: Point - ... b: Point + >>> @dataclass + ... class TwoPoints: + ... a: Point + ... b: Point - >>> p = PointofPoints(Point(1.0, 2.0), Point(3.0, 4.0)) - >>> p - PointofPoints(a=Point(x=1.0, y=2.0), b=Point(x=3.0, y=4.0)) + >>> p = TwoPoints(Point(1.0, 2.0), Point(3.0, 4.0)) + >>> p + TwoPoints(a=Point(x=1.0, y=2.0), b=Point(x=3.0, y=4.0)) - >>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}}) - PointofPoints(a=Point(x=5.0, y=2.0), b=Point(x=3.0, y=6.0)) + >>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}}) + TwoPoints(a=Point(x=5.0, y=2.0), b=Point(x=3.0, y=6.0)) - >>> replace(p, {"a": {"x": F({"thing": 5.0})}}) - PointofPoints(a=Point(x={'thing': 5.0}, y=2.0), - b=Point(x=3.0, y=4.0)) + >>> replace(p, {"a": {"x": F({"thing": 5.0})}}) + TwoPoints(a=Point(x={'thing': 5.0}, y=2.0), + b=Point(x=3.0, y=4.0)) - This also works on mixed-type structures, e.g. a dictionary of dataclasses. + This also works on mixed-type structures, e.g. a dictionary of dataclasses. - >>> p = {"a": Point(1.0, 2.0), "b": Point(3.0, 4.0)} - >>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}}) - {'a': Point(x=5.0, y=2.0), 'b': Point(x=3.0, y=6.0)} + >>> p = {"a": Point(1.0, 2.0), "b": Point(3.0, 4.0)} + >>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}}) + {'a': Point(x=5.0, y=2.0), 'b': Point(x=3.0, y=6.0)} - Or a dataclass of dictionaries. + Or a dataclass of dictionaries. - >>> @dataclass - ... class Object: - ... a: dict[str, Any] - ... b: dict[str, Any] + >>> @dataclass + ... class Object: + ... a: dict[str, Any] + ... b: dict[str, Any] - >>> p = Object({"a": 1, "b": 2}, {"c": 3, "d": 4}) - >>> replace(p, {"a": {"b": 5}, "b": {"c": 6}}) - Object(a={'a': 1, 'b': 5}, b={'c': 6, 'd': 4}) + >>> p = Object({"a": 1, "b": 2}, {"c": 3, "d": 4}) + >>> replace(p, {"a": {"b": 5}, "b": {"c": 6}}) + Object(a={'a': 1, 'b': 5}, b={'c': 6, 'd': 4}) - """ - kwargs = {k: _recursive_replace_dataclass_helper(obj, k, v) for k, v in fs.items()} - return _dataclass_replace(obj, **kwargs) + """ + kwargs = {k: _recursive_replace_helper(obj, k, v) for k, v in fs.items()} + return replace(obj, **kwargs) # =================================================================== diff --git a/src/dataclassish/_src/register_mapping.py b/src/dataclassish/_src/register_mapping.py index 9f68e51..e7456cc 100644 --- a/src/dataclassish/_src/register_mapping.py +++ b/src/dataclassish/_src/register_mapping.py @@ -8,7 +8,7 @@ from plum import dispatch -from .types import F +from .register_base import _recursive_replace_helper # =================================================================== @@ -33,7 +33,7 @@ def get_field(obj: Mapping[Hashable, Any], k: Hashable, /) -> Any: # Replace -@dispatch # type: ignore[misc] +@dispatch def replace(obj: Mapping[str, Any], /, **kwargs: Any) -> Mapping[str, Any]: """Replace the fields of a mapping. @@ -67,19 +67,7 @@ def replace(obj: Mapping[str, Any], /, **kwargs: Any) -> Mapping[str, Any]: return type(obj)(**{**obj, **kwargs}) -def _recursive_replace_mapping_helper( - obj: Mapping[Hashable, Any], k: str, v: Any, / -) -> Any: - if isinstance(v, F): # Field, stop here. - out = v.value - elif isinstance(v, Mapping): # more to replace, recurse. - out = replace(get_field(obj, k), v) - else: # nothing to replace, keep the value. - out = v - return out - - -@dispatch # type: ignore[misc,no-redef] +@dispatch # type: ignore[no-redef] def replace( obj: Mapping[Hashable, Any], fs: Mapping[str, Any], / ) -> Mapping[Hashable, Any]: @@ -115,8 +103,7 @@ def replace( """ # Recursively replace the fields - kwargs = {k: _recursive_replace_mapping_helper(obj, k, v) for k, v in fs.items()} - + kwargs = {k: _recursive_replace_helper(obj, k, v) for k, v in fs.items()} return type(obj)(**(dict(obj) | kwargs)) diff --git a/src/dataclassish/_src/types.py b/src/dataclassish/_src/types.py index 4895c13..5301c54 100644 --- a/src/dataclassish/_src/types.py +++ b/src/dataclassish/_src/types.py @@ -1,10 +1,12 @@ """Data types for ``dataclassish``.""" -__all__ = ["DataclassInstance", "F"] +__all__ = ["DataclassInstance", "CanCopyReplace", "F"] from dataclasses import dataclass from typing import Any, ClassVar, Generic, Protocol, TypeVar, runtime_checkable +from typing_extensions import Self + @runtime_checkable class DataclassInstance(Protocol): @@ -71,3 +73,59 @@ class F(Generic[V]): """ value: V + + +# =================================================================== + + +@runtime_checkable +class CanCopyReplace(Protocol): + """Protocol for objects that implement the ``__replace__`` method. + + This is used by ``copy.replace`` (Python 3.13+) to replace fields of an + object. This is a generalization of the ``dataclasses.replace`` function. + + Examples + -------- + >>> from dataclassish import CanCopyReplace + + >>> class Point: + ... def __init__(self, x, y): + ... self.x = x + ... self.y = y + ... def __replace__(self, **changes): + ... return Point(**(self.__dict__ | changes)) + ... def __repr__(self): + ... return f"Point(x={self.x}, y={self.y})" + + >>> issubclass(Point, CanCopyReplace) + True + + >>> point = Point(1.0, 2.0) + >>> isinstance(point, CanCopyReplace) + True + + The ``__replace__`` method was introduced in Python 3.13 to bring + ``dataclasses.replace``-like functionality to any implementing object. The + method is publicly exposed via the ``copy.replace`` function. + + % invisible-code-block: python + % + % import sys + + % skip: start if(sys.version_info < (3, 13), reason="py3.13+") + + >>> import copy + >>> copy.replace(point, x=3.0) + Point(x=3.0, y=2.0) + + % skip: end + + """ + + def __replace__(self: Self, /, **changes: Any) -> Self: + """Replace the fields of the object. + + This method should return a new object with the fields replaced. + + """