Skip to content

Commit

Permalink
add consistent tests for dtypes and devices
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 28, 2025
1 parent d2b85ab commit c865f64
Show file tree
Hide file tree
Showing 22 changed files with 339 additions and 216 deletions.
8 changes: 8 additions & 0 deletions docs/src/references/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,17 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
Added
#####

* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.


Fixed
#####

* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``


`Version 0.2.0 <https://github.com/lab-cosmo/torch-pme/releases/tag/v0.2.0>`_ - 2025-01-23
------------------------------------------------------------------------------------------

Expand Down
38 changes: 26 additions & 12 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,23 @@ def _validate_parameters(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: Union[str, torch.device],
) -> None:
device = positions.device
dtype = positions.dtype
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as class "
f"type ({dtype})"
)

if isinstance(device, torch.device):
device = device.type

if positions.device.type != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as class "
f"device ({device})"
)

# check shape, dtype and device of positions
num_atoms = len(positions)
Expand All @@ -29,12 +43,12 @@ def _validate_parameters(
f"{list(cell.shape)}"
)

if cell.dtype != dtype:
raise ValueError(
if cell.dtype != positions.dtype:
raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
)

if cell.device != device:
if cell.device != positions.device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as `positions` ({device})"
)
Expand Down Expand Up @@ -63,12 +77,12 @@ def _validate_parameters(
f"{len(positions)} atoms"
)

if charges.dtype != dtype:
raise ValueError(
if charges.dtype != positions.dtype:
raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})"
)

if charges.device != device:
if charges.device != positions.device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as `positions` "
f"({device})"
Expand All @@ -82,7 +96,7 @@ def _validate_parameters(
"structure"
)

if neighbor_indices.device != device:
if neighbor_indices.device != positions.device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as `positions` ({device})"
Expand All @@ -95,14 +109,14 @@ def _validate_parameters(
f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
)

if neighbor_distances.device != device:
if neighbor_distances.device != positions.device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as `positions` ({device})"
)

if neighbor_distances.dtype != dtype:
raise ValueError(
if neighbor_distances.dtype != positions.dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as `positions` ({dtype})"
)
27 changes: 16 additions & 11 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch
from torch import profiler
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__()

Expand All @@ -48,17 +48,20 @@ def __init__(

self.device = torch.get_default_device() if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.potential = potential

assert self.dtype == self.potential.dtype, (
f"Potential and Calculator must have the same dtype, got {self.dtype} and "
f"{self.potential.dtype}"
)
assert self.device == self.potential.device, (
f"Potential and Calculator must have the same device, got {self.device} and "
f"{self.potential.device}"
)
if self.dtype != potential.dtype:
raise TypeError(
f"dtype of `potential` ({potential.dtype}) must be same as of "
f"`calculator` ({self.dtype})"
)

if self.device != potential.device:
raise ValueError(
f"device of `potential` ({potential.device}) must be same as of "
f"`calculator` ({self.device})"
)

self.potential = potential
self.full_neighbor_list = full_neighbor_list

self.prefactor = prefactor
Expand Down Expand Up @@ -179,6 +182,8 @@ def forward(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
dtype=self.dtype,
device=self.device,
)

# Compute short-range (SR) part using a real space sum
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/calculators/ewald.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/calculators/p3m.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
self.mesh_spacing: float = mesh_spacing

Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/calculators/pme.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch
from torch import profiler
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/potentials/combined.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/potentials/coulomb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -35,7 +35,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)

Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/potentials/inversepowerlaw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch
from torch.special import gammainc
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)

Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__()
self.dtype = torch.get_default_dtype() if dtype is None else dtype
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/potentials/spline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/tuning/ewald.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, Optional
from typing import Any, Optional, Union
from warnings import warn

import torch
Expand All @@ -20,7 +20,7 @@ def tune_ewald(
ns_hi: int = 14,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.EwaldCalculator`.
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/tuning/p3m.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import Any, Optional
from typing import Any, Optional, Union
from warnings import warn

import torch
Expand Down Expand Up @@ -80,7 +80,7 @@ def tune_p3m(
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`.
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/tuning/pme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import Any, Optional
from typing import Any, Optional, Union
from warnings import warn

import torch
Expand All @@ -23,7 +23,7 @@ def tune_pme(
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.PMECalculator`.
Expand Down
Loading

0 comments on commit c865f64

Please sign in to comment.