Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: enable copy.replace for py3.13+ #42

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ messages_control.disable = [
"line-too-long",
"missing-module-docstring",
"missing-function-docstring",
"no-member", # handled by mypy
"no-value-for-parameter", # for plum-dispatch
"too-many-function-args", # handled by testing
"unused-argument", # handled by ruff
"unused-wildcard-import", # handled by ruff
"wildcard-import", # handled by ruff
Expand Down
12 changes: 9 additions & 3 deletions src/dataclassish/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"field_items",
# Classes
"DataclassInstance",
"CanCopyReplace",
"F",
]

Expand All @@ -33,11 +34,16 @@
get_field,
replace,
)
from ._src.types import DataclassInstance, F
from ._src.types import CanCopyReplace, DataclassInstance, F
from ._version import version as __version__

# Register dispatches by importing the submodules
# isort: split
from ._src import register_base, register_dataclass, register_mapping
from ._src import (
register_base,
register_copyreplace,
register_dataclass,
register_mapping,
)

del register_base, register_dataclass, register_mapping
del register_base, register_dataclass, register_mapping, register_copyreplace
18 changes: 17 additions & 1 deletion src/dataclassish/_src/register_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

__all__: list[str] = []

from collections.abc import Mapping
from typing import Any, TypeVar

from plum import dispatch

from .api import fields
from .api import fields, replace
from .types import F

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -37,6 +39,20 @@ def get_field(obj: Any, k: str, /) -> Any:
return getattr(obj, k)


# ===================================================================
# Replace


def _recursive_replace_helper(obj: object, k: str, v: Any, /) -> Any:
if isinstance(v, F):
out = v.value
elif isinstance(v, Mapping):
out = replace(get_field(obj, k), v)
else:
out = v
return out


# ===================================================================
# Field keys

Expand Down
179 changes: 179 additions & 0 deletions src/dataclassish/_src/register_copyreplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Register dispatches for CanCopyReplace objects."""

__all__: list[str] = []

import copy
import sys
from collections.abc import Mapping
from typing import Any

from plum import dispatch

from .register_base import _recursive_replace_helper
from .types import CanCopyReplace

# ===================================================================
# Get field


@dispatch # type: ignore[misc]
def get_field(obj: CanCopyReplace, k: str, /) -> Any:
"""Get a field of a dataclass instance by name.

Examples
--------
>>> from dataclasses import dataclass
>>> from dataclassish import get_field

% invisible-code-block: python
%
% import sys

% skip: start if(sys.version_info < (3, 13), reason="py3.13+")

>>> @dataclass
... class Point:
... x: float
... y: float

>>> p = Point(1.0, 2.0)
>>> get_field(p, "x")
1.0

% skip: end

This works for any object that implements the ``__replace__`` method.

>>> class Point:
... def __init__(self, x, y):
... self.x = x
... self.y = y
... def __replace__(self, **changes):
... return Point(**(self.__dict__ | changes))
... def __repr__(self):
... return f"Point(x={self.x}, y={self.y})"

>>> p = Point(1.0, 2.0)
>>> get_field(p, "x")
1.0

"""
return getattr(obj, k)


# ===================================================================
# Replace


@dispatch
def replace(obj: CanCopyReplace, /, **kwargs: Any) -> CanCopyReplace:
"""Replace the fields of an object.

Examples
--------
>>> from dataclassish import replace

% invisible-code-block: python
%
% import sys

% skip: start if(sys.version_info < (3, 13), reason="py3.13+")

As of Python 3.13, dataclasses implement the ``__replace__`` method.

>>> from dataclasses import dataclass

>>> @dataclass
... class Point:
... x: float
... y: float

>>> p = Point(1.0, 2.0)
>>> p
Point(x=1.0, y=2.0)

>>> replace(p, x=3.0)
Point(x=3.0, y=2.0)

% skip: end

>>> class Point:
... def __init__(self, x, y):
... self.x = x
... self.y = y
... def __replace__(self, **changes):
... return Point(**(self.__dict__ | changes))
... def __repr__(self):
... return f"Point(x={self.x}, y={self.y})"

>>> p = Point(1.0, 2.0)
>>> replace(p, x=2.0)
Point(x=2.0, y=2.0)

The ``__replace__`` method was introduced in Python 3.13 to bring
``dataclasses.replace``-like functionality to any implementing object. The
method is publicly exposed via the ``copy.replace`` function.

% invisible-code-block: python
%
% import sys

% skip: start if(sys.version_info < (3, 13), reason="py3.13+")

>>> import copy
>>> copy.replace(p, x=3.0)
Point(x=3.0, y=2.0)

% skip: end

"""
return (
obj.__replace__(**kwargs)
if sys.version_info < (3, 13)
else copy.replace(obj, **kwargs)
)


@dispatch # type: ignore[no-redef]
def replace(obj: CanCopyReplace, fs: Mapping[str, Any], /) -> CanCopyReplace:
"""Replace the fields of a dataclass instance.

Examples
--------
>>> from dataclasses import dataclass
>>> from dataclassish import replace, F

>>> class Point:
... def __init__(self, x, y):
... self.x = x
... self.y = y
... def __replace__(self, **changes):
... return Point(**(self.__dict__ | changes))
... def __repr__(self):
... return f"Point(x={self.x}, y={self.y})"

>>> @dataclass
... class TwoPoint:
... a: Point
... b: Point

>>> p = TwoPoint(Point(1.0, 2.0), Point(3.0, 4.0))
>>> p
TwoPoint(a=Point(x=1.0, y=2.0), b=Point(x=3.0, y=4.0))

>>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}})
TwoPoint(a=Point(x=5.0, y=2.0), b=Point(x=3.0, y=6.0))

>>> replace(p, {"a": {"x": F({"thing": 5.0})}})
TwoPoint(a=Point(x={'thing': 5.0}, y=2.0),
b=Point(x=3.0, y=4.0))

This also works on mixed-type structures, e.g. a dictionary of objects.

>>> p = {"a": Point(1.0, 2.0), "b": Point(3.0, 4.0)}
>>> replace(p, {"a": {"x": 5.0}, "b": {"y": 6.0}})
{'a': Point(x=5.0, y=2.0), 'b': Point(x=3.0, y=6.0)}

"""
kwargs = {k: _recursive_replace_helper(obj, k, v) for k, v in fs.items()}
return replace(obj, **kwargs)
Loading
Loading