Skip to content

Commit

Permalink
✨ feat: enable copy.replace for py3.13+
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 4, 2024
1 parent b0075cf commit 6bad785
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 108 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ messages_control.disable = [
"line-too-long",
"missing-module-docstring",
"missing-function-docstring",
"no-member", # handled by mypy
"no-value-for-parameter", # for plum-dispatch
"unused-argument", # handled by ruff
"unused-wildcard-import", # handled by ruff
Expand Down
12 changes: 9 additions & 3 deletions src/dataclassish/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"field_items",
# Classes
"DataclassInstance",
"CanCopyReplace",
"F",
]

Expand All @@ -33,11 +34,16 @@
get_field,
replace,
)
from ._src.types import DataclassInstance, F
from ._src.types import CanCopyReplace, DataclassInstance, F
from ._version import version as __version__

# Register dispatches by importing the submodules
# isort: split
from ._src import register_base, register_dataclass, register_mapping
from ._src import (
register_base,
register_copyreplace,
register_dataclass,
register_mapping,
)

del register_base, register_dataclass, register_mapping
del register_base, register_dataclass, register_mapping, register_copyreplace
18 changes: 17 additions & 1 deletion src/dataclassish/_src/register_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

__all__: list[str] = []

from collections.abc import Mapping
from typing import Any, TypeVar

from plum import dispatch

from .api import fields
from .api import fields, replace
from .types import F

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -37,6 +39,20 @@ def get_field(obj: Any, k: str, /) -> Any:
return getattr(obj, k)


# ===================================================================
# Replace


def _recursive_replace_helper(obj: object, k: str, v: Any, /) -> Any:
if isinstance(v, F):
out = v.value
elif isinstance(v, Mapping):
out = replace(get_field(obj, k), v)
else:
out = v
return out


# ===================================================================
# Field keys

Expand Down
184 changes: 184 additions & 0 deletions src/dataclassish/_src/register_copyreplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Register dispatches for CanCopyReplace objects."""

__all__: list[str] = []

import copy
import sys
from collections.abc import Mapping
from typing import Any

from plum import dispatch

from .register_base import _recursive_replace_helper
from .types import CanCopyReplace

# ===================================================================
# Get field


@dispatch # type: ignore[misc]
def get_field(obj: CanCopyReplace, k: str, /) -> Any:
"""Get a field of a dataclass instance by name.
Examples
--------
>>> from dataclasses import dataclass
>>> from dataclassish import get_field
% invisible-code-block: python
%
% import sys
% skip: start if(sys.version_info < (3, 13), reason="py3.13+")
>>> @dataclass
... class Point:
... x: float
... y: float
>>> p = Point(1.0, 2.0)
>>> get_field(p, "x")
1.0
% skip: end
This works for any object that implements the ``__replace__`` method.
>>> class Point:
... def __init__(self, x, y):
... self.x = x
... self.y = y
... def __replace__(self, **changes):
... return Point(**(self.__dict__ | changes))
... def __repr__(self):
... return f"Point(x={self.x}, y={self.y})"
>>> p = Point(1.0, 2.0)
>>> get_field(p, "x")
1.0
"""
return getattr(obj, k)


# ===================================================================
# Replace


@dispatch
def replace(obj: CanCopyReplace, /, **kwargs: Any) -> CanCopyReplace:
"""Replace the fields of an object.
Examples
--------
>>> from dataclassish import replace
% invisible-code-block: python
%
% import sys
% skip: start if(sys.version_info < (3, 13), reason="py3.13+")
As of Python 3.13, dataclasses implement the ``__replace__`` method.
>>> from dataclasses import dataclass
>>> @dataclass
... class Point:
... x: float
... y: float
>>> p = Point(1.0, 2.0)
>>> p
Point(x=1.0, y=2.0)
>>> replace(p, x=3.0)
Point(x=3.0, y=2.0)
% skip: end
>>> class Point:
... def __init__(self, x, y):
... self.x = x
... self.y = y
... def __replace__(self, **changes):
... return Point(**(self.__dict__ | changes))
... def __repr__(self):
... return f"Point(x={self.x}, y={self.y})"
>>> p = Point(1.0, 2.0)
>>> replace(p, x=2.0)
Point(x=2.0, y=2.0)
The ``__replace__`` method was introduced in Python 3.13 to bring
``dataclasses.replace``-like functionality to any implementing object. The
method is publicly exposed via the ``copy.replace`` function.
% invisible-code-block: python
%
% import sys
% skip: start if(sys.version_info < (3, 13), reason="py3.13+")
>>> import copy
>>> copy.replace(p, x=3.0)
Point(x=3.0, y=2.0)
% skip: end
"""
if sys.version_info < (3, 13):
return obj.__replace__(**kwargs)
return copy.replace(obj, **kwargs)


@dispatch # type: ignore[no-redef]
def replace(obj: CanCopyReplace, fs: Mapping[str, Any], /) -> CanCopyReplace:
"""Replace the fields of a dataclass instance.
Examples
--------
>>> from dataclasses import dataclass
>>> from dataclassish import replace, F
>>> @dataclass
... class Point:
... x: float | dict
... y: float
>>> @dataclass
... class PointofPoints:
... a: Point
... b: Point
>>> p = PointofPoints(Point(1.0, 2.0), Point(3.0, 4.0))
>>> p
PointofPoints(a=Point(x=1.0, y=2.0), b=Point(x=3.0, y=4.0))
>>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}})
PointofPoints(a=Point(x=5.0, y=2.0), b=Point(x=3.0, y=6.0))
>>> replace(p, {"a": {"x": F({"thing": 5.0})}})
PointofPoints(a=Point(x={'thing': 5.0}, y=2.0),
b=Point(x=3.0, y=4.0))
This also works on mixed-type structures, e.g. a dictionary of dataclasses.
>>> p = {"a": Point(1.0, 2.0), "b": Point(3.0, 4.0)}
>>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}})
{'a': Point(x=5.0, y=2.0), 'b': Point(x=3.0, y=6.0)}
Or a dataclass of dictionaries.
>>> @dataclass
... class Object:
... a: dict[str, Any]
... b: dict[str, Any]
>>> p = Object({"a": 1, "b": 2}, {"c": 3, "d": 4})
>>> replace(p, {"a": {"b": 5}, "b": {"c": 6}})
Object(a={'a': 1, 'b': 5}, b={'c': 6, 'd': 4})
"""
kwargs = {k: _recursive_replace_helper(obj, k, v) for k, v in fs.items()}
return replace(obj, **kwargs)
Loading

0 comments on commit 6bad785

Please sign in to comment.