diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst
index 76d12dd0..ee17fdf4 100644
--- a/docs/src/references/changelog.rst
+++ b/docs/src/references/changelog.rst
@@ -27,10 +27,16 @@ changelog `_ format. This project follows
Added
#####
+* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Better documentation for for ``cell``, ``charges`` and ``positions`` parameters
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.
+Fixed
+#####
+
+* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``
+
`Version 0.2.0 `_ - 2025-01-23
------------------------------------------------------------------------------------------
diff --git a/examples/01-charges-example.py b/examples/01-charges-example.py
index 92d82949..6465a507 100644
--- a/examples/01-charges-example.py
+++ b/examples/01-charges-example.py
@@ -39,15 +39,16 @@
import torchpme
from torchpme.tuning import tune_pme
+dtype = torch.float64
+
# %%
#
# Create the properties CsCl unit cell
-
symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
-charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
-positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64)
-cell = torch.eye(3, dtype=torch.float64)
+charges = torch.tensor([[1.0], [-1.0]], dtype=dtype)
+positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=dtype)
+cell = torch.eye(3, dtype=dtype)
pbc = torch.tensor([True, True, True])
@@ -72,6 +73,7 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
+ dtype=dtype,
)
# %%
@@ -101,7 +103,7 @@
# will be used to *compute* the potential energy of the system.
calculator = torchpme.PMECalculator(
- torchpme.CoulombPotential(smearing=smearing), **pme_params
+ torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
)
# %%
@@ -112,7 +114,7 @@
# As a first application of multiple charge channels, we start simply by using the
# classic definition of one charge channel per atom.
-charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
+charges = torch.tensor([[1.0], [-1.0]], dtype=dtype)
# %%
#
@@ -160,7 +162,7 @@
# species-specific potentials and facilitating the learning process for machine learning
# algorithms.
-charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
+charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=dtype)
# %%
#
@@ -205,7 +207,7 @@
# creating a new calculator with the metatensor interface.
calculator_metatensor = torchpme.metatensor.PMECalculator(
- torchpme.CoulombPotential(smearing=smearing), **pme_params
+ torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
)
# %%
diff --git a/examples/02-neighbor-lists-usage.py b/examples/02-neighbor-lists-usage.py
index 322e6a69..e72689f0 100644
--- a/examples/02-neighbor-lists-usage.py
+++ b/examples/02-neighbor-lists-usage.py
@@ -55,6 +55,7 @@
#
# As a test system, we use a 2x2x2 supercell of an CsCl crystal in a cubic cell.
+dtype = torch.float64
atoms_unitcell = ase.Atoms(
symbols=["Cs", "Cl"],
positions=np.array([(0, 0, 0), (0.5, 0.5, 0.5)]),
@@ -97,7 +98,7 @@
nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
neighbor_indices, neighbor_distances = nl.compute(
points=positions.to(dtype=torch.float64, device="cpu"),
- box=cell.to(dtype=torch.float64, device="cpu"),
+ box=cell.to(dtype=dtype, device="cpu"),
periodic=True,
quantities="Pd",
)
@@ -109,6 +110,7 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
+ dtype=dtype,
)
# %%
@@ -193,7 +195,9 @@ def distances(
# compute the potential.
pme = torchpme.PMECalculator(
- potential=torchpme.CoulombPotential(smearing=smearing), **pme_params
+ potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype),
+ dtype=dtype,
+ **pme_params,
)
potential = pme(
charges=charges,
diff --git a/examples/08-combined-potential.py b/examples/08-combined-potential.py
index df455e98..8d0667bb 100644
--- a/examples/08-combined-potential.py
+++ b/examples/08-combined-potential.py
@@ -28,6 +28,8 @@
from torchpme import CombinedPotential, EwaldCalculator, InversePowerLawPotential
from torchpme.prefactors import eV_A
+dtype = torch.float64
+
# %%
# Combined potentials
# -------------------
@@ -65,10 +67,10 @@
# evaluation, and so one has to set it also for the combined potential, even if it is
# not used explicitly in the evaluation of the combination.
-pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
-pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
+pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
+pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype)
-potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)
+potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype)
# Note also that :class:`CombinedPotential` can be used with any combination of
# potentials, as long they are all either direct or range separated. For instance, one
@@ -80,7 +82,7 @@
# We now plot of the individual and combined ``potential`` functions together with an
# explicit sum of the two potentials.
-dist = torch.logspace(-3, 2, 1000)
+dist = torch.logspace(-3, 2, 1000, dtype=dtype)
fig, ax = plt.subplots()
@@ -115,7 +117,7 @@
# combines all terms in a range-separated potential, including the k-space
# kernel.
-k = torch.logspace(-2, 2, 1000)
+k = torch.logspace(-2, 2, 1000, dtype=dtype)
fig, ax = plt.subplots()
@@ -154,9 +156,8 @@
# much bigger system.
calculator = EwaldCalculator(
- potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A
+ potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype
)
-calculator.to(dtype=torch.float64)
# %%
diff --git a/examples/10-tuning.py b/examples/10-tuning.py
index c2d61881..183b0f45 100644
--- a/examples/10-tuning.py
+++ b/examples/10-tuning.py
@@ -120,7 +120,9 @@
pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4}
pme = torchpme.PMECalculator(
- potential=torchpme.CoulombPotential(smearing=smearing),
+ potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype),
+ device=device,
+ dtype=dtype,
**pme_params, # type: ignore[arg-type]
)
@@ -168,6 +170,8 @@
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
run_backward=True,
+ device=device,
+ dtype=dtype,
)
estimated_timing = timings(pme)
@@ -210,15 +214,19 @@ def filter_neighbors(cutoff, neighbor_indices, neighbor_distances):
return neighbor_indices[filter_idx], neighbor_distances[filter_idx]
-def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
+def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, dtype):
filter_indices, filter_distances = filter_neighbors(
cutoff, neighbor_indices, neighbor_distances
)
pme = torchpme.PMECalculator(
- potential=torchpme.CoulombPotential(smearing=smearing),
+ potential=torchpme.CoulombPotential(
+ smearing=smearing, device=device, dtype=dtype
+ ),
mesh_spacing=mesh_spacing,
interpolation_nodes=interpolation_nodes,
+ device=device,
+ dtype=dtype,
)
potential = pme(
charges=charges,
@@ -239,6 +247,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
run_backward=True,
n_warmup=1,
n_repeat=4,
+ device=device,
+ dtype=dtype,
)
estimated_timing = timings(pme)
return madelung, estimated_timing
@@ -251,7 +261,9 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
bounds = np.zeros((len(smearing_grid), len(spacing_grid)))
for ism, smearing in enumerate(smearing_grid):
for isp, spacing in enumerate(spacing_grid):
- results[ism, isp], timings[ism, isp] = timed_madelung(8.0, smearing, spacing, 4)
+ results[ism, isp], timings[ism, isp] = timed_madelung(
+ 8.0, smearing, spacing, 4, device, dtype
+ )
bounds[ism, isp] = error_bounds(8.0, smearing, spacing, 4)
# %%
@@ -374,7 +386,7 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
for inint, nint in enumerate(nint_grid):
for isp, spacing in enumerate(spacing_grid):
results[inint, isp], timings[inint, isp] = timed_madelung(
- 5.0, 1.0, spacing, nint
+ 5.0, 1.0, spacing, nint, device=device, dtype=dtype
)
bounds[inint, isp] = error_bounds(5.0, 1.0, spacing, nint)
@@ -445,15 +457,19 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
cutoff=5.0,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
+ device=device,
+ dtype=dtype,
)
-print(f"""
+print(
+ f"""
Estimated PME parameters (cutoff={5.0} Å):
Smearing: {smearing} Å
Mesh spacing: {parameters["mesh_spacing"]} Å
Interpolation order: {parameters["interpolation_nodes"]}
Estimated time per step: {timing} s
-""")
+"""
+)
# %%
# What is the best cutoff?
@@ -476,6 +492,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
cutoff=cutoff,
neighbor_indices=filter_indices,
neighbor_distances=filter_distances,
+ device=device,
+ dtype=dtype,
)
timings_grid.append(timing)
diff --git a/examples/basic-usage.py b/examples/basic-usage.py
index d726a18f..47bc1506 100644
--- a/examples/basic-usage.py
+++ b/examples/basic-usage.py
@@ -146,7 +146,7 @@
# contains all the necessary functions (such as those defining the short-range and
# long-range splits) for this potential and makes them useable in the rest of the code.
-potential = CoulombPotential(smearing=smearing)
+potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype)
# %%
#
@@ -193,7 +193,9 @@
# Since our structure is relatively small, we use the :class:`EwaldCalculator`.
# We start by the initialization of the class.
-calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength)
+calculator = EwaldCalculator(
+ potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype
+)
# %%
#
diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py
index 575e9edc..8d42e45c 100644
--- a/src/torchpme/_utils.py
+++ b/src/torchpme/_utils.py
@@ -10,9 +10,26 @@ def _validate_parameters(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
+ dtype: torch.dtype,
+ device: Union[str, torch.device],
) -> None:
- device = positions.device
- dtype = positions.dtype
+ if positions.dtype != dtype:
+ raise TypeError(
+ f"type of `positions` ({positions.dtype}) must be same as the class "
+ f"type ({dtype})"
+ )
+
+ if isinstance(device, torch.device):
+ device = device.type
+
+ if positions.device.type != device:
+ raise ValueError(
+ f"device of `positions` ({positions.device}) must be same as the class "
+ f"device ({device})"
+ )
+
+ # We use `positions.device` because it includes the device type AND index, which the
+ # `device` parameter may lack
# check shape, dtype and device of positions
num_atoms = len(positions)
@@ -29,14 +46,14 @@ def _validate_parameters(
f"{list(cell.shape)}"
)
- if cell.dtype != dtype:
- raise ValueError(
- f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
+ if cell.dtype != positions.dtype:
+ raise TypeError(
+ f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
)
- if cell.device != device:
+ if cell.device != positions.device:
raise ValueError(
- f"device of `cell` ({cell.device}) must be same as `positions` ({device})"
+ f"device of `cell` ({cell.device}) must be same as the class ({device})"
)
if smearing is not None and torch.equal(
@@ -63,14 +80,14 @@ def _validate_parameters(
f"{len(positions)} atoms"
)
- if charges.dtype != dtype:
- raise ValueError(
- f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})"
+ if charges.dtype != positions.dtype:
+ raise TypeError(
+ f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
)
- if charges.device != device:
+ if charges.device != positions.device:
raise ValueError(
- f"device of `charges` ({charges.device}) must be same as `positions` "
+ f"device of `charges` ({charges.device}) must be same as the class "
f"({device})"
)
@@ -82,10 +99,10 @@ def _validate_parameters(
"structure"
)
- if neighbor_indices.device != device:
+ if neighbor_indices.device != positions.device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
- f"same as `positions` ({device})"
+ f"same as the class ({device})"
)
if neighbor_distances.shape != neighbor_indices[:, 0].shape:
@@ -95,14 +112,14 @@ def _validate_parameters(
f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
)
- if neighbor_distances.device != device:
+ if neighbor_distances.device != positions.device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
- f"same as `positions` ({device})"
+ f"same as the class ({device})"
)
- if neighbor_distances.dtype != dtype:
- raise ValueError(
+ if neighbor_distances.dtype != positions.dtype:
+ raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
- f"as `positions` ({dtype})"
+ f"as the class ({dtype})"
)
diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py
index 0959b506..a7681fea 100644
--- a/src/torchpme/calculators/calculator.py
+++ b/src/torchpme/calculators/calculator.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
from torch import profiler
@@ -37,7 +37,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__()
@@ -48,17 +48,20 @@ def __init__(
self.device = torch.get_default_device() if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
- self.potential = potential
- assert self.dtype == self.potential.dtype, (
- f"Potential and Calculator must have the same dtype, got {self.dtype} and "
- f"{self.potential.dtype}"
- )
- assert self.device == self.potential.device, (
- f"Potential and Calculator must have the same device, got {self.device} and "
- f"{self.potential.device}"
- )
+ if self.dtype != potential.dtype:
+ raise TypeError(
+ f"dtype of `potential` ({potential.dtype}) must be same as of "
+ f"`calculator` ({self.dtype})"
+ )
+
+ if self.device != potential.device:
+ raise ValueError(
+ f"device of `potential` ({potential.device}) must be same as of "
+ f"`calculator` ({self.device})"
+ )
+ self.potential = potential
self.full_neighbor_list = full_neighbor_list
self.prefactor = prefactor
@@ -161,6 +164,8 @@ def forward(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
+ dtype=self.dtype,
+ device=self.device,
)
# Compute short-range (SR) part using a real space sum
diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py
index 7d213f2f..83e2dc85 100644
--- a/src/torchpme/calculators/ewald.py
+++ b/src/torchpme/calculators/ewald.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -66,7 +66,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py
index 54b97bcb..f85533db 100644
--- a/src/torchpme/calculators/p3m.py
+++ b/src/torchpme/calculators/p3m.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -56,7 +56,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
self.mesh_spacing: float = mesh_spacing
@@ -73,8 +73,8 @@ def __init__(
)
self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(
- cell=torch.eye(3),
- ns_mesh=torch.ones(3, dtype=int),
+ cell=torch.eye(3, dtype=self.dtype, device=self.device),
+ ns_mesh=torch.ones(3, dtype=int, device=self.device),
interpolation_nodes=self.interpolation_nodes,
kernel=self.potential,
mode=0, # Green's function for point-charge potentials
@@ -84,8 +84,8 @@ def __init__(
)
self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
- cell=torch.eye(3),
- ns_mesh=torch.ones(3, dtype=int),
+ cell=torch.eye(3, dtype=self.dtype, device=self.device),
+ ns_mesh=torch.ones(3, dtype=int, device=self.device),
interpolation_nodes=self.interpolation_nodes,
method="P3M",
)
diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py
index 95f74216..dd389812 100644
--- a/src/torchpme/calculators/pme.py
+++ b/src/torchpme/calculators/pme.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
from torch import profiler
@@ -59,7 +59,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py
index d76a20c0..212f4744 100644
--- a/src/torchpme/potentials/combined.py
+++ b/src/torchpme/potentials/combined.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -39,7 +39,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py
index 4cde5611..1e35897c 100644
--- a/src/torchpme/potentials/coulomb.py
+++ b/src/torchpme/potentials/coulomb.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -35,7 +35,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py
index 35ff7ac7..374ab56e 100644
--- a/src/torchpme/potentials/inversepowerlaw.py
+++ b/src/torchpme/potentials/inversepowerlaw.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
from torch.special import gammainc
@@ -41,7 +41,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py
index b3503a54..b6e7fb36 100644
--- a/src/torchpme/potentials/potential.py
+++ b/src/torchpme/potentials/potential.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -39,7 +39,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__()
self.dtype = torch.get_default_dtype() if dtype is None else dtype
diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py
index e8ffc3c5..b58d31eb 100644
--- a/src/torchpme/potentials/spline.py
+++ b/src/torchpme/potentials/spline.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -58,7 +58,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py
index 9f34f45f..b5bb0ae5 100644
--- a/src/torchpme/tuning/ewald.py
+++ b/src/torchpme/tuning/ewald.py
@@ -1,5 +1,5 @@
import math
-from typing import Any, Optional
+from typing import Any, Optional, Union
from warnings import warn
import torch
@@ -20,7 +20,7 @@ def tune_ewald(
ns_hi: int = 14,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.EwaldCalculator`.
diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py
index 8c43ac16..6a64230a 100644
--- a/src/torchpme/tuning/p3m.py
+++ b/src/torchpme/tuning/p3m.py
@@ -1,6 +1,6 @@
import math
from itertools import product
-from typing import Any, Optional
+from typing import Any, Optional, Union
from warnings import warn
import torch
@@ -80,7 +80,7 @@ def tune_p3m(
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`.
diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py
index b66d1afa..540f3d40 100644
--- a/src/torchpme/tuning/pme.py
+++ b/src/torchpme/tuning/pme.py
@@ -1,6 +1,6 @@
import math
from itertools import product
-from typing import Any, Optional
+from typing import Any, Optional, Union
from warnings import warn
import torch
@@ -23,7 +23,7 @@ def tune_pme(
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.PMECalculator`.
diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py
index d25746cf..7ac463d6 100644
--- a/src/torchpme/tuning/tuner.py
+++ b/src/torchpme/tuning/tuner.py
@@ -1,6 +1,6 @@
import math
import time
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -84,13 +84,16 @@ def __init__(
calculator: type[Calculator],
exponent: int = 1,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
if exponent != 1:
raise NotImplementedError(
f"Only exponent = 1 is supported but got {exponent}."
)
+ self.device = torch.get_default_device() if device is None else device
+ self.dtype = torch.get_default_dtype() if dtype is None else dtype
+
_validate_parameters(
charges=charges,
cell=cell,
@@ -100,16 +103,15 @@ def __init__(
[1.0], device=positions.device, dtype=positions.dtype
),
smearing=1.0, # dummy value because; always have range-seperated potentials
+ dtype=self.dtype,
+ device=self.device,
)
-
self.charges = charges
self.cell = cell
self.positions = positions
self.cutoff = cutoff
self.calculator = calculator
self.exponent = exponent
- self.device = torch.get_default_device() if device is None else device
- self.dtype = torch.get_default_dtype() if dtype is None else dtype
self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions))
@@ -188,7 +190,7 @@ def __init__(
neighbor_distances: torch.Tensor,
exponent: int = 1,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
charges=charges,
@@ -288,10 +290,13 @@ def __init__(
n_warmup: int = 4,
run_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__()
+ self.dtype = torch.get_default_dtype() if dtype is None else dtype
+ self.device = torch.get_default_device() if device is None else device
+
_validate_parameters(
charges=charges,
cell=cell,
@@ -299,13 +304,13 @@ def __init__(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=1.0, # dummy value because; always have range-seperated potentials
+ device=self.device,
+ dtype=self.dtype,
)
self.charges = charges
self.cell = cell
self.positions = positions
- self.dtype = dtype
- self.device = device
self.n_repeat = n_repeat
self.n_warmup = n_warmup
self.run_backward = run_backward
diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py
index a6fc976a..dff72455 100644
--- a/tests/calculators/test_calculator.py
+++ b/tests/calculators/test_calculator.py
@@ -43,6 +43,35 @@ def test_compute_output_shapes():
assert result.shape == charges.shape
+def test_wrong_device_positions():
+ calculator = CalculatorTest()
+ match = r"device of `positions` \(meta\) must be same as the class device \(cpu\)"
+ with pytest.raises(ValueError, match=match):
+ calculator.forward(
+ positions=POSITIONS_1.to(device="meta"),
+ charges=CHARGES_1,
+ cell=CELL_1,
+ neighbor_indices=NEIGHBOR_INDICES,
+ neighbor_distances=NEIGHBOR_DISTANCES,
+ )
+
+
+def test_wrong_dtype_positions():
+ calculator = CalculatorTest()
+ match = (
+ r"type of `positions` \(torch.float64\) must be same as the class type "
+ r"\(torch.float32\)"
+ )
+ with pytest.raises(TypeError, match=match):
+ calculator.forward(
+ positions=POSITIONS_1.to(dtype=torch.float64),
+ charges=CHARGES_1,
+ cell=CELL_1,
+ neighbor_indices=NEIGHBOR_INDICES,
+ neighbor_distances=NEIGHBOR_DISTANCES,
+ )
+
+
# Tests for invalid shape, dtype and device of positions
def test_invalid_shape_positions():
calculator = CalculatorTest()
@@ -79,10 +108,9 @@ def test_invalid_shape_cell():
def test_invalid_dtype_cell():
calculator = CalculatorTest()
match = (
- r"type of `cell` \(torch.float64\) must be same as `positions` "
- r"\(torch.float32\)"
+ r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
charges=CHARGES_1,
@@ -94,7 +122,7 @@ def test_invalid_dtype_cell():
def test_invalid_device_cell():
calculator = CalculatorTest()
- match = r"device of `cell` \(meta\) must be same as `positions` \(cpu\)"
+ match = r"device of `cell` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -160,10 +188,10 @@ def test_invalid_shape_charges():
def test_invalid_dtype_charges():
calculator = CalculatorTest()
match = (
- r"type of `charges` \(torch.float64\) must be same as `positions` "
+ r"type of `charges` \(torch.float64\) must be same as the class "
r"\(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
charges=CHARGES_1.to(dtype=torch.float64),
@@ -175,7 +203,7 @@ def test_invalid_dtype_charges():
def test_invalid_device_charges():
calculator = CalculatorTest()
- match = r"device of `charges` \(meta\) must be same as `positions` \(cpu\)"
+ match = r"device of `charges` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -220,7 +248,7 @@ def test_invalid_shape_neighbor_indices_neighbor_distances():
def test_invalid_device_neighbor_indices():
calculator = CalculatorTest()
- match = r"device of `neighbor_indices` \(meta\) must be same as `positions` \(cpu\)"
+ match = r"device of `neighbor_indices` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -233,9 +261,7 @@ def test_invalid_device_neighbor_indices():
def test_invalid_device_neighbor_distances():
calculator = CalculatorTest()
- match = (
- r"device of `neighbor_distances` \(meta\) must be same as `positions` \(cpu\)"
- )
+ match = r"device of `neighbor_distances` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -250,9 +276,9 @@ def test_invalid_dtype_neighbor_distances():
calculator = CalculatorTest()
match = (
r"type of `neighbor_distances` \(torch.float64\) must be same "
- r"as `positions` \(torch.float32\)"
+ r"as the class \(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
charges=CHARGES_1,
diff --git a/tests/calculators/test_values_direct.py b/tests/calculators/test_values_direct.py
index 47f819af..a5ace407 100644
--- a/tests/calculators/test_values_direct.py
+++ b/tests/calculators/test_values_direct.py
@@ -17,7 +17,11 @@
class CalculatorTest(Calculator):
def __init__(self, **kwargs):
super().__init__(
- potential=CoulombPotential(smearing=None, exclusion_radius=None), **kwargs
+ potential=CoulombPotential(
+ smearing=None, exclusion_radius=None, dtype=DTYPE
+ ),
+ **kwargs,
+ dtype=DTYPE,
)
diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py
index cb3ff708..edcb886e 100644
--- a/tests/calculators/test_values_ewald.py
+++ b/tests/calculators/test_values_ewald.py
@@ -87,7 +87,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
to triclinic, as well as cation-anion ratios of 1:1, 1:2 and 2:1.
"""
# Get input parameters and adjust to account for scaling
- pos, charges, cell, madelung_ref, num_units = define_crystal(crystal_name)
+ pos, charges, cell, madelung_ref, num_units = define_crystal(
+ crystal_name, dtype=DTYPE
+ )
pos *= scaling_factor
cell *= scaling_factor
madelung_ref /= scaling_factor
@@ -99,11 +101,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
smearing = sr_cutoff / 5.0
lr_wavelength = 0.5 * smearing
calc = EwaldCalculator(
- InversePowerLawPotential(
- exponent=1,
- smearing=smearing,
- ),
+ InversePowerLawPotential(exponent=1, smearing=smearing, dtype=DTYPE),
lr_wavelength=lr_wavelength,
+ dtype=DTYPE,
)
rtol = 4e-6
elif calc_name == "pme":
@@ -113,16 +113,19 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
InversePowerLawPotential(
exponent=1,
smearing=smearing,
+ dtype=DTYPE,
),
mesh_spacing=smearing / 8,
+ dtype=DTYPE,
)
rtol = 9e-4
elif calc_name == "p3m":
sr_cutoff = 2 * scaling_factor
smearing = sr_cutoff / 5.0
calc = P3MCalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
mesh_spacing=smearing / 8,
+ dtype=DTYPE,
)
rtol = 9e-4
@@ -132,7 +135,6 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
)
# Compute potential and compare against target value using default hypers
- calc.to(dtype=DTYPE)
potentials = calc.forward(
positions=pos,
charges=charges,
@@ -186,26 +188,22 @@ def test_wigner(crystal_name, scaling_factor):
# The first value of 0.1 corresponds to what would be
# chosen by default for the "wigner_sc" or "wigner_bcc_cubiccell" structure.
- smearings = torch.tensor([0.1, 0.06, 0.019], dtype=torch.float64)
- for smearing in smearings:
+ for smearing in [0.1, 0.06, 0.019]:
# Readjust smearing parameter to match nearest neighbor distance
if crystal_name in ["wigner_fcc", "wigner_fcc_cubiccell"]:
- smeareff = float(smearing) / np.sqrt(2)
+ smeareff = smearing / np.sqrt(2)
elif crystal_name in ["wigner_bcc_cubiccell", "wigner_bcc"]:
- smeareff = float(smearing) * np.sqrt(3) / 2
+ smeareff = smearing * np.sqrt(3) / 2
elif crystal_name == "wigner_sc":
- smeareff = float(smearing)
+ smeareff = smearing
smeareff *= scaling_factor
# Compute potential and compare against reference
calc = EwaldCalculator(
- InversePowerLawPotential(
- exponent=1,
- smearing=smeareff,
- ),
+ InversePowerLawPotential(exponent=1, smearing=smeareff, dtype=DTYPE),
lr_wavelength=smeareff / 2,
+ dtype=DTYPE,
)
- calc.to(dtype=DTYPE)
potentials = calc.forward(
positions=positions,
charges=charges,
@@ -253,25 +251,28 @@ def test_random_structure(
if calc_name == "ewald":
calc = EwaldCalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
lr_wavelength=0.5 * smearing,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
+ dtype=DTYPE,
)
elif calc_name == "pme":
calc = PMECalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
mesh_spacing=smearing / 8.0,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
+ dtype=DTYPE,
)
elif calc_name == "p3m":
calc = P3MCalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
mesh_spacing=smearing / 8.0,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
+ dtype=DTYPE,
)
neighbor_indices, neighbor_shifts = neighbor_list(
@@ -294,7 +295,6 @@ def test_random_structure(
neighbor_shifts=neighbor_shifts,
)
- calc.to(dtype=DTYPE)
potentials = calc.forward(
positions=positions,
charges=charges,
diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py
index 83377006..209355d4 100644
--- a/tests/calculators/test_workflow.py
+++ b/tests/calculators/test_workflow.py
@@ -4,7 +4,6 @@
"""
import io
-import math
import pytest
import torch
@@ -17,15 +16,15 @@
PMECalculator,
)
-AVAILABLE_DEVICES = ["cpu"] + torch.cuda.is_available() * ["cuda"]
-MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3))
-CHARGES_CSCL = torch.tensor([1.0, -1.0])
+DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
+DTYPES = [torch.float32, torch.float64]
SMEARING = 0.1
LR_WAVELENGTH = SMEARING / 4
MESH_SPACING = SMEARING / 4
-@pytest.mark.parametrize("device", AVAILABLE_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
("CalculatorClass", "params"),
[
@@ -59,125 +58,128 @@
],
)
class TestWorkflow:
- def cscl_system(self, device=None):
+ def cscl_system(self, device=None, dtype=None):
"""CsCl crystal. Same as in the madelung test"""
- if device is None:
- device = torch.device("cpu")
-
- positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])
- charges = torch.tensor([1.0, -1.0]).reshape((-1, 1))
- cell = torch.eye(3)
- neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64)
- neighbor_distances = torch.tensor([0.8660])
-
- return (
- charges.to(device=device),
- cell.to(device=device),
- positions.to(device=device),
- neighbor_indices.to(device=device),
- neighbor_distances.to(device=device),
+ device = torch.get_default_device() if device is None else device
+ dtype = torch.get_default_dtype() if dtype is None else dtype
+
+ positions = torch.tensor(
+ [[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device
)
+ charges = torch.tensor([1.0, -1.0], dtype=dtype, device=device).reshape((-1, 1))
+ cell = torch.eye(3, dtype=dtype, device=device)
+ neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64, device=device)
+ neighbor_distances = torch.tensor([0.8660], dtype=dtype, device=device)
+
+ return charges, cell, positions, neighbor_indices, neighbor_distances
- def test_smearing_non_positive(self, CalculatorClass, params, device):
+ def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
match = r"`smearing` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator, PMECalculator]:
params["smearing"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
params["smearing"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
- def test_interpolation_order_error(self, CalculatorClass, params, device):
+ def test_interpolation_order_error(self, CalculatorClass, params, device, dtype):
params = params.copy()
if type(CalculatorClass) in [PMECalculator]:
match = "Only `interpolation_nodes` from 1 to 5"
params["interpolation_nodes"] = 10
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
- def test_lr_wavelength_non_positive(self, CalculatorClass, params, device):
+ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
match = r"`lr_wavelength` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator]:
params["lr_wavelength"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
params["lr_wavelength"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
- def test_dtype_device(self, CalculatorClass, params, device):
+ def test_dtype_device(self, CalculatorClass, params, device, dtype):
"""Test that the output dtype and device are the same as the input."""
- dtype = torch.float64
params = params.copy()
- positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device)
- charges = torch.ones((1, 2), dtype=dtype, device=device)
- cell = torch.eye(3, dtype=dtype, device=device)
- neighbor_indices = torch.tensor([[0, 0]], device=device)
- neighbor_distances = torch.tensor([0.1], dtype=dtype, device=device)
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
- potential = calculator.forward(
- charges=charges,
- cell=cell,
- positions=positions,
- neighbor_indices=neighbor_indices,
- neighbor_distances=neighbor_distances,
- )
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
assert potential.dtype == dtype
- assert potential.device.type == device
- def check_operation(self, calculator, device):
+ if isinstance(device, torch.device):
+ assert potential.device == device
+ else:
+ assert potential.device.type == device
+
+ def check_operation(self, calculator, device, dtype):
"""Make sure computation runs and returns a torch.Tensor."""
- descriptor = calculator.forward(*self.cscl_system(device))
+ descriptor = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
assert type(descriptor) is torch.Tensor
- def test_operation_as_python(self, CalculatorClass, params, device):
+ def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
params = params.copy()
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
- self.check_operation(calculator=calculator, device=device)
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ self.check_operation(calculator=calculator, device=device, dtype=dtype)
- def test_operation_as_torch_script(self, CalculatorClass, params, device):
+ def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
params = params.copy()
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
- self.check_operation(calculator=scripted, device=device)
+ self.check_operation(calculator=scripted, device=device, dtype=dtype)
- def test_save_load(self, CalculatorClass, params, device):
+ def test_save_load(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted, buffer)
buffer.seek(0)
torch.jit.load(buffer)
- def test_prefactor(self, CalculatorClass, params, device):
+ def test_prefactor(self, CalculatorClass, params, device, dtype):
"""Test if the prefactor is applied correctly."""
params = params.copy()
params["potential"].device = device
+ params["potential"].dtype = dtype
+
prefactor = 2.0
- calculator1 = CalculatorClass(**params, device=device)
- calculator2 = CalculatorClass(**params, prefactor=prefactor, device=device)
- potentials1 = calculator1.forward(*self.cscl_system())
- potentials2 = calculator2.forward(*self.cscl_system())
+ calculator1 = CalculatorClass(**params, device=device, dtype=dtype)
+ calculator2 = CalculatorClass(
+ **params, prefactor=prefactor, device=device, dtype=dtype
+ )
+
+ potentials1 = calculator1.forward(*self.cscl_system(device=device, dtype=dtype))
+ potentials2 = calculator2.forward(*self.cscl_system(device=device, dtype=dtype))
+
assert torch.allclose(potentials1 * prefactor, potentials2)
- def test_not_nan(self, CalculatorClass, params, device):
+ def test_not_nan(self, CalculatorClass, params, device, dtype):
"""Make sure derivatives are not NaN."""
params = params.copy()
params["potential"].device = device
+ params["potential"].dtype = dtype
- calculator = CalculatorClass(**params, device=device)
- system = self.cscl_system(device)
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ system = self.cscl_system(device=device, dtype=dtype)
system[0].requires_grad = True
system[1].requires_grad = True
system[2].requires_grad = True
@@ -203,26 +205,45 @@ def test_not_nan(self, CalculatorClass, params, device):
torch.autograd.grad(energy, system[2], retain_graph=True)[0]
).any()
- def test_dtype_and_device_incompatability(self, CalculatorClass, params, device):
- """Test that the calculator raises an error if the dtype and device are incompatible."""
+ def test_dtype_and_device_incompatability(
+ self, CalculatorClass, params, device, dtype
+ ):
+ """Test that the calculator raises an error if the dtype and device are incompatible with potential."""
params = params.copy()
+
+ other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
+
params["potential"].device = device
- params["potential"].dtype = torch.float64
- with pytest.raises(AssertionError, match=".*dtype.*"):
- CalculatorClass(**params, dtype=torch.float32, device=device)
- with pytest.raises(AssertionError, match=".*device.*"):
- CalculatorClass(
- **params, dtype=params["potential"].dtype, device=torch.device("meta")
- )
+ params["potential"].dtype = dtype
+
+ match = (
+ rf"dtype of `potential` \({params['potential'].dtype}\) must be same as "
+ rf"of `calculator` \({other_dtype}\)"
+ )
+ with pytest.raises(TypeError, match=match):
+ CalculatorClass(**params, dtype=other_dtype, device=device)
+
+ match = (
+ rf"device of `potential` \({params['potential'].device}\) must be same as "
+ rf"of `calculator` \(meta\)"
+ )
+ with pytest.raises(ValueError, match=match):
+ CalculatorClass(**params, dtype=dtype, device=torch.device("meta"))
def test_potential_and_calculator_incompatability(
- self, CalculatorClass, params, device
+ self,
+ CalculatorClass,
+ params,
+ device,
+ dtype,
):
"""Test that the calculator raises an error if the potential and calculator are incompatible."""
params = params.copy()
params["potential"].device = device
+ params["potential"].dtype = dtype
+
params["potential"] = torch.jit.script(params["potential"])
with pytest.raises(
TypeError, match="Potential must be an instance of Potential, got.*"
):
- CalculatorClass(**params)
+ CalculatorClass(**params, device=device, dtype=dtype)
diff --git a/tests/helpers.py b/tests/helpers.py
index f970f596..bca5feb7 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -227,14 +227,13 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None):
else:
raise ValueError(f"crystal_name = {crystal_name} is not supported!")
- madelung_ref = torch.tensor(madelung_ref)
charges = charges.reshape((-1, 1))
return (
positions.to(device=device, dtype=dtype),
charges.to(device=device, dtype=dtype),
cell.to(device=device, dtype=dtype),
- madelung_ref,
+ torch.tensor(madelung_ref, device=device, dtype=dtype),
num_formula_units,
)
@@ -257,7 +256,10 @@ def neighbor_list(
nl = NeighborList(cutoff=cutoff, full_list=full_neighbor_list)
neighbor_indices, d, S = nl.compute(
- points=positions, box=box, periodic=periodic, quantities="PdS"
+ points=positions.to(dtype=torch.float64, device="cpu"),
+ box=box.to(dtype=torch.float64, device="cpu"),
+ periodic=periodic,
+ quantities="PdS",
)
neighbor_indices = torch.from_numpy(neighbor_indices.astype(int)).to(
diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py
index 025bd91c..6beed927 100644
--- a/tests/metatensor/test_workflow_metatensor.py
+++ b/tests/metatensor/test_workflow_metatensor.py
@@ -11,14 +11,15 @@
mts_torch = pytest.importorskip("metatensor.torch")
mts_atomistic = pytest.importorskip("metatensor.torch.atomistic")
-AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [
- torch.device("cuda")
-]
+DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
+DTYPES = [torch.float32, torch.float64]
SMEARING = 0.1
LR_WAVELENGTH = SMEARING / 4
MESH_SPACING = SMEARING / 4
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
("CalculatorClass", "params"),
[
@@ -52,7 +53,10 @@
],
)
class TestWorkflow:
- def system(self, device=None):
+ def system(self, device=None, dtype=None):
+ device = torch.get_default_device() if device is None else device
+ dtype = torch.get_default_dtype() if dtype is None else dtype
+
system = mts_atomistic.System(
types=torch.tensor([1, 2, 2]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.2], [0.0, 0.0, 0.5]]),
@@ -92,30 +96,32 @@ def system(self, device=None):
return system.to(device=device), neighbors.to(device=device)
- def check_operation(self, calculator, device):
+ def check_operation(self, calculator, device, dtype):
"""Make sure computation runs and returns a metatensor.TensorMap."""
- system, neighbors = self.system(device)
+ system, neighbors = self.system(device=device, dtype=dtype)
descriptor = calculator.forward(system, neighbors)
assert isinstance(descriptor, torch.ScriptObject)
if version.parse(torch.__version__) >= version.parse("2.1"):
assert descriptor._type().name() == "TensorMap"
- @pytest.mark.parametrize("device", AVAILABLE_DEVICES)
- def test_operation_as_python(self, CalculatorClass, params, device):
+ def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
calculator = CalculatorClass(**params)
- self.check_operation(calculator=calculator, device=device)
+ self.check_operation(calculator=calculator, device=device, dtype=dtype)
- @pytest.mark.parametrize("device", AVAILABLE_DEVICES)
- def test_operation_as_torch_script(self, CalculatorClass, params, device):
+ def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
calculator = CalculatorClass(**params)
scripted = torch.jit.script(calculator)
- self.check_operation(calculator=scripted, device=device)
+ self.check_operation(calculator=scripted, device=device, dtype=dtype)
- def test_save_load(self, CalculatorClass, params):
- calculator = CalculatorClass(**params)
+ def test_save_load(self, CalculatorClass, params, device, dtype):
+ params = params.copy()
+ params["potential"].device = device
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted, buffer)
diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py
index 17a43580..3d1ce03f 100644
--- a/tests/tuning/test_tuning.py
+++ b/tests/tuning/test_tuning.py
@@ -16,31 +16,46 @@
sys.path.append(str(Path(__file__).parents[1]))
from helpers import define_crystal, neighbor_list
-DTYPE = torch.float32
-DEVICE = "cpu"
DEFAULT_CUTOFF = 4.4
-CHARGES_1 = torch.ones((4, 1), dtype=DTYPE, device=DEVICE)
-POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE, device=DEVICE).reshape((4, 3))
-CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE)
+DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
+DTYPES = [torch.float32, torch.float64]
-def test_TunerBase_double():
+def system(device=None, dtype=None):
+ device = torch.get_default_device() if device is None else device
+ dtype = torch.get_default_dtype() if dtype is None else dtype
+
+ charges = torch.ones((4, 1), dtype=dtype, device=device)
+ cell = torch.eye(3, dtype=dtype, device=device)
+ positions = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3))
+
+ return charges, cell, positions
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_TunerBase_init(device, dtype):
"""
- Check that `TunerBase` initilizes with double precisions tensors.
+ Check that `TunerBase` initilizes correctly.
We are using dummy `neighbor_indices` and `neighbor_distances` to verify types. Have
to be sure that these dummy variables are initilized correctly.
"""
+ charges, cell, positions = system(device, dtype)
TunerBase(
- charges=CHARGES_1.to(dtype=torch.float64),
- cell=CELL_1.to(dtype=torch.float64),
- positions=POSITIONS_1.to(dtype=torch.float64),
+ charges=charges,
+ cell=cell,
+ positions=positions,
cutoff=DEFAULT_CUTOFF,
calculator=1.0,
exponent=1,
+ dtype=dtype,
+ device=device,
)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
("calculator", "tune", "param_length"),
[
@@ -50,13 +65,15 @@ def test_TunerBase_double():
],
)
@pytest.mark.parametrize("accuracy", [1e-1, 1e-3, 1e-5])
-def test_parameter_choose(calculator, tune, param_length, accuracy):
+def test_parameter_choose(device, dtype, calculator, tune, param_length, accuracy):
"""
Check that the Madelung constants obtained from the Ewald sum calculator matches
the reference values and that all branches of the from_accuracy method are covered.
"""
# Get input parameters and adjust to account for scaling
- pos, charges, cell, madelung_ref, num_units = define_crystal()
+ pos, charges, cell, madelung_ref, num_units = define_crystal(
+ dtype=dtype, device=device
+ )
# Compute neighbor list
neighbor_indices, neighbor_distances = neighbor_list(
@@ -71,13 +88,17 @@ def test_parameter_choose(calculator, tune, param_length, accuracy):
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
accuracy=accuracy,
+ dtype=dtype,
+ device=device,
)
assert len(params) == param_length
# Compute potential and compare against target value using default hypers
calc = calculator(
- potential=(CoulombPotential(smearing=smearing)),
+ potential=(CoulombPotential(smearing=smearing, dtype=dtype, device=device)),
+ dtype=dtype,
+ device=device,
**params,
)
potentials = calc.forward(
@@ -103,12 +124,12 @@ def test_accuracy_error(tune):
)
with pytest.raises(ValueError, match=match):
tune(
- charges,
- cell,
- pos,
- DEFAULT_CUTOFF,
- neighbor_indices,
- neighbor_distances,
+ charges=charges,
+ cell=cell,
+ positions=pos,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=neighbor_indices,
+ neighbor_distances=neighbor_distances,
accuracy="foo",
)
@@ -116,79 +137,89 @@ def test_accuracy_error(tune):
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_exponent_not_1_error(tune):
pos, charges, cell, _, _ = define_crystal()
-
- match = "Only exponent = 1 is supported but got 2."
neighbor_indices, neighbor_distances = neighbor_list(
positions=pos, box=cell, cutoff=DEFAULT_CUTOFF
)
+
+ match = "Only exponent = 1 is supported but got 2."
with pytest.raises(NotImplementedError, match=match):
tune(
- charges,
- cell,
- pos,
- DEFAULT_CUTOFF,
- neighbor_indices,
- neighbor_distances,
+ charges=charges,
+ cell=cell,
+ positions=pos,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=neighbor_indices,
+ neighbor_distances=neighbor_distances,
exponent=2,
)
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_shape_positions(tune):
+ charges, cell, _ = system()
match = (
r"`positions` must be a tensor with shape \[n_atoms, 3\], got tensor with "
r"shape \[4, 5\]"
)
with pytest.raises(ValueError, match=match):
tune(
- CHARGES_1,
- CELL_1,
- torch.ones((4, 5), dtype=DTYPE, device=DEVICE),
- DEFAULT_CUTOFF,
- None, # dummy neighbor indices
- None, # dummy neighbor distances
+ charges=charges,
+ cell=cell,
+ positions=torch.ones((4, 5)),
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
)
# Tests for invalid shape, dtype and device of cell
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_shape_cell(tune):
+ charges, _, positions = system()
match = (
r"`cell` must be a tensor with shape \[3, 3\], got tensor with shape \[2, 2\]"
)
with pytest.raises(ValueError, match=match):
tune(
- CHARGES_1,
- torch.ones([2, 2], dtype=DTYPE, device=DEVICE),
- POSITIONS_1,
- DEFAULT_CUTOFF,
- None, # dummy neighbor indices
- None, # dummy neighbor distances
+ charges=charges,
+ cell=torch.ones([2, 2]),
+ positions=positions,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
)
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_cell(tune):
+ charges, _, positions = system()
match = (
"provided `cell` has a determinant of 0 and therefore is not valid for "
"periodic calculation"
)
with pytest.raises(ValueError, match=match):
- tune(CHARGES_1, torch.zeros(3, 3), POSITIONS_1, DEFAULT_CUTOFF, None, None)
+ tune(
+ charges=charges,
+ cell=torch.zeros(3, 3),
+ positions=positions,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
+ )
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_dtype_cell(tune):
+ charges, _, positions = system()
match = (
- r"type of `cell` \(torch.float64\) must be same as `positions` "
- r"\(torch.float32\)"
+ r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
tune(
- CHARGES_1,
- torch.eye(3, dtype=torch.float64, device=DEVICE),
- POSITIONS_1,
- DEFAULT_CUTOFF,
- None,
- None,
+ charges=charges,
+ cell=torch.eye(3, dtype=torch.float64),
+ positions=positions,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
)