diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst
index 5631f33c..d927b525 100644
--- a/docs/src/references/changelog.rst
+++ b/docs/src/references/changelog.rst
@@ -27,9 +27,17 @@ changelog `_ format. This project follows
Added
#####
+* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.
+
+Fixed
+#####
+
+* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``
+
+
`Version 0.2.0 `_ - 2025-01-23
------------------------------------------------------------------------------------------
diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py
index 575e9edc..f282a626 100644
--- a/src/torchpme/_utils.py
+++ b/src/torchpme/_utils.py
@@ -10,9 +10,23 @@ def _validate_parameters(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
+ dtype: torch.dtype,
+ device: Union[str, torch.device],
) -> None:
- device = positions.device
- dtype = positions.dtype
+ if positions.dtype != dtype:
+ raise TypeError(
+ f"type of `positions` ({positions.dtype}) must be same as class "
+ f"type ({dtype})"
+ )
+
+ if isinstance(device, torch.device):
+ device = device.type
+
+ if positions.device.type != device:
+ raise ValueError(
+ f"device of `positions` ({positions.device}) must be same as class "
+ f"device ({device})"
+ )
# check shape, dtype and device of positions
num_atoms = len(positions)
@@ -29,12 +43,12 @@ def _validate_parameters(
f"{list(cell.shape)}"
)
- if cell.dtype != dtype:
- raise ValueError(
+ if cell.dtype != positions.dtype:
+ raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
)
- if cell.device != device:
+ if cell.device != positions.device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as `positions` ({device})"
)
@@ -63,12 +77,12 @@ def _validate_parameters(
f"{len(positions)} atoms"
)
- if charges.dtype != dtype:
- raise ValueError(
+ if charges.dtype != positions.dtype:
+ raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})"
)
- if charges.device != device:
+ if charges.device != positions.device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as `positions` "
f"({device})"
@@ -82,7 +96,7 @@ def _validate_parameters(
"structure"
)
- if neighbor_indices.device != device:
+ if neighbor_indices.device != positions.device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as `positions` ({device})"
@@ -95,14 +109,14 @@ def _validate_parameters(
f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
)
- if neighbor_distances.device != device:
+ if neighbor_distances.device != positions.device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as `positions` ({device})"
)
- if neighbor_distances.dtype != dtype:
- raise ValueError(
+ if neighbor_distances.dtype != positions.dtype:
+ raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as `positions` ({dtype})"
)
diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py
index 70ede3ef..29eaadff 100644
--- a/src/torchpme/calculators/calculator.py
+++ b/src/torchpme/calculators/calculator.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
from torch import profiler
@@ -37,7 +37,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__()
@@ -48,17 +48,20 @@ def __init__(
self.device = torch.get_default_device() if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
- self.potential = potential
- assert self.dtype == self.potential.dtype, (
- f"Potential and Calculator must have the same dtype, got {self.dtype} and "
- f"{self.potential.dtype}"
- )
- assert self.device == self.potential.device, (
- f"Potential and Calculator must have the same device, got {self.device} and "
- f"{self.potential.device}"
- )
+ if self.dtype != potential.dtype:
+ raise TypeError(
+ f"dtype of `potential` ({potential.dtype}) must be same as of "
+ f"`calculator` ({self.dtype})"
+ )
+
+ if self.device != potential.device:
+ raise ValueError(
+ f"device of `potential` ({potential.device}) must be same as of "
+ f"`calculator` ({self.device})"
+ )
+ self.potential = potential
self.full_neighbor_list = full_neighbor_list
self.prefactor = prefactor
@@ -179,6 +182,8 @@ def forward(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
+ dtype=self.dtype,
+ device=self.device,
)
# Compute short-range (SR) part using a real space sum
diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py
index 7d213f2f..83e2dc85 100644
--- a/src/torchpme/calculators/ewald.py
+++ b/src/torchpme/calculators/ewald.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -66,7 +66,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py
index eb23c780..f85533db 100644
--- a/src/torchpme/calculators/p3m.py
+++ b/src/torchpme/calculators/p3m.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -56,7 +56,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
self.mesh_spacing: float = mesh_spacing
diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py
index 95f74216..dd389812 100644
--- a/src/torchpme/calculators/pme.py
+++ b/src/torchpme/calculators/pme.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
from torch import profiler
@@ -59,7 +59,7 @@ def __init__(
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py
index d76a20c0..212f4744 100644
--- a/src/torchpme/potentials/combined.py
+++ b/src/torchpme/potentials/combined.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -39,7 +39,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py
index 4cde5611..1e35897c 100644
--- a/src/torchpme/potentials/coulomb.py
+++ b/src/torchpme/potentials/coulomb.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -35,7 +35,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py
index 35ff7ac7..374ab56e 100644
--- a/src/torchpme/potentials/inversepowerlaw.py
+++ b/src/torchpme/potentials/inversepowerlaw.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
from torch.special import gammainc
@@ -41,7 +41,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py
index 674a8632..1efa783d 100644
--- a/src/torchpme/potentials/potential.py
+++ b/src/torchpme/potentials/potential.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -39,7 +39,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__()
self.dtype = torch.get_default_dtype() if dtype is None else dtype
diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py
index e8ffc3c5..b58d31eb 100644
--- a/src/torchpme/potentials/spline.py
+++ b/src/torchpme/potentials/spline.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -58,7 +58,7 @@ def __init__(
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py
index 6452fa7d..459f0a02 100644
--- a/src/torchpme/tuning/ewald.py
+++ b/src/torchpme/tuning/ewald.py
@@ -1,5 +1,5 @@
import math
-from typing import Any, Optional
+from typing import Any, Optional, Union
from warnings import warn
import torch
@@ -20,7 +20,7 @@ def tune_ewald(
ns_hi: int = 14,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.EwaldCalculator`.
diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py
index 5685ffaf..92a6bbe0 100644
--- a/src/torchpme/tuning/p3m.py
+++ b/src/torchpme/tuning/p3m.py
@@ -1,6 +1,6 @@
import math
from itertools import product
-from typing import Any, Optional
+from typing import Any, Optional, Union
from warnings import warn
import torch
@@ -80,7 +80,7 @@ def tune_p3m(
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`.
diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py
index 55a0556e..f9ddc4d8 100644
--- a/src/torchpme/tuning/pme.py
+++ b/src/torchpme/tuning/pme.py
@@ -1,6 +1,6 @@
import math
from itertools import product
-from typing import Any, Optional
+from typing import Any, Optional, Union
from warnings import warn
import torch
@@ -23,7 +23,7 @@ def tune_pme(
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.PMECalculator`.
diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py
index 8461fda5..354d6029 100644
--- a/src/torchpme/tuning/tuner.py
+++ b/src/torchpme/tuning/tuner.py
@@ -1,6 +1,6 @@
import math
import time
-from typing import Optional
+from typing import Optional, Union
import torch
@@ -80,13 +80,16 @@ def __init__(
calculator: type[Calculator],
exponent: int = 1,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
if exponent != 1:
raise NotImplementedError(
f"Only exponent = 1 is supported but got {exponent}."
)
+ self.device = torch.get_default_device() if device is None else device
+ self.dtype = torch.get_default_dtype() if dtype is None else dtype
+
_validate_parameters(
charges=charges,
cell=cell,
@@ -96,16 +99,15 @@ def __init__(
[1.0], device=positions.device, dtype=positions.dtype
),
smearing=1.0, # dummy value because; always have range-seperated potentials
+ dtype=self.dtype,
+ device=self.device,
)
-
self.charges = charges
self.cell = cell
self.positions = positions
self.cutoff = cutoff
self.calculator = calculator
self.exponent = exponent
- self.device = torch.get_default_device() if device is None else device
- self.dtype = torch.get_default_dtype() if dtype is None else dtype
self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions))
@@ -182,7 +184,7 @@ def __init__(
neighbor_distances: torch.Tensor,
exponent: int = 1,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__(
charges=charges,
@@ -280,10 +282,13 @@ def __init__(
n_warmup: int = 4,
run_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
):
super().__init__()
+ self.dtype = torch.get_default_dtype() if dtype is None else dtype
+ self.device = torch.get_default_device() if device is None else device
+
_validate_parameters(
charges=charges,
cell=cell,
@@ -291,13 +296,13 @@ def __init__(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=1.0, # dummy value because; always have range-seperated potentials
+ device=self.device,
+ dtype=self.dtype,
)
self.charges = charges
self.cell = cell
self.positions = positions
- self.dtype = dtype
- self.device = device
self.n_repeat = n_repeat
self.n_warmup = n_warmup
self.run_backward = run_backward
diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py
index a6fc976a..567b0f15 100644
--- a/tests/calculators/test_calculator.py
+++ b/tests/calculators/test_calculator.py
@@ -43,6 +43,35 @@ def test_compute_output_shapes():
assert result.shape == charges.shape
+def test_wrong_device_positions():
+ calculator = CalculatorTest()
+ match = r"device of `positions` \(meta\) must be same as class device \(cpu\)"
+ with pytest.raises(ValueError, match=match):
+ calculator.forward(
+ positions=POSITIONS_1.to(device="meta"),
+ charges=CHARGES_1,
+ cell=CELL_1,
+ neighbor_indices=NEIGHBOR_INDICES,
+ neighbor_distances=NEIGHBOR_DISTANCES,
+ )
+
+
+def test_wrong_dtype_positions():
+ calculator = CalculatorTest()
+ match = (
+ r"type of `positions` \(torch.float64\) must be same as class type "
+ r"\(torch.float32\)"
+ )
+ with pytest.raises(TypeError, match=match):
+ calculator.forward(
+ positions=POSITIONS_1.to(dtype=torch.float64),
+ charges=CHARGES_1,
+ cell=CELL_1,
+ neighbor_indices=NEIGHBOR_INDICES,
+ neighbor_distances=NEIGHBOR_DISTANCES,
+ )
+
+
# Tests for invalid shape, dtype and device of positions
def test_invalid_shape_positions():
calculator = CalculatorTest()
@@ -82,7 +111,7 @@ def test_invalid_dtype_cell():
r"type of `cell` \(torch.float64\) must be same as `positions` "
r"\(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
charges=CHARGES_1,
@@ -163,7 +192,7 @@ def test_invalid_dtype_charges():
r"type of `charges` \(torch.float64\) must be same as `positions` "
r"\(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
charges=CHARGES_1.to(dtype=torch.float64),
@@ -252,7 +281,7 @@ def test_invalid_dtype_neighbor_distances():
r"type of `neighbor_distances` \(torch.float64\) must be same "
r"as `positions` \(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
charges=CHARGES_1,
diff --git a/tests/calculators/test_values_direct.py b/tests/calculators/test_values_direct.py
index 47f819af..a5ace407 100644
--- a/tests/calculators/test_values_direct.py
+++ b/tests/calculators/test_values_direct.py
@@ -17,7 +17,11 @@
class CalculatorTest(Calculator):
def __init__(self, **kwargs):
super().__init__(
- potential=CoulombPotential(smearing=None, exclusion_radius=None), **kwargs
+ potential=CoulombPotential(
+ smearing=None, exclusion_radius=None, dtype=DTYPE
+ ),
+ **kwargs,
+ dtype=DTYPE,
)
diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py
index cb3ff708..edcb886e 100644
--- a/tests/calculators/test_values_ewald.py
+++ b/tests/calculators/test_values_ewald.py
@@ -87,7 +87,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
to triclinic, as well as cation-anion ratios of 1:1, 1:2 and 2:1.
"""
# Get input parameters and adjust to account for scaling
- pos, charges, cell, madelung_ref, num_units = define_crystal(crystal_name)
+ pos, charges, cell, madelung_ref, num_units = define_crystal(
+ crystal_name, dtype=DTYPE
+ )
pos *= scaling_factor
cell *= scaling_factor
madelung_ref /= scaling_factor
@@ -99,11 +101,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
smearing = sr_cutoff / 5.0
lr_wavelength = 0.5 * smearing
calc = EwaldCalculator(
- InversePowerLawPotential(
- exponent=1,
- smearing=smearing,
- ),
+ InversePowerLawPotential(exponent=1, smearing=smearing, dtype=DTYPE),
lr_wavelength=lr_wavelength,
+ dtype=DTYPE,
)
rtol = 4e-6
elif calc_name == "pme":
@@ -113,16 +113,19 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
InversePowerLawPotential(
exponent=1,
smearing=smearing,
+ dtype=DTYPE,
),
mesh_spacing=smearing / 8,
+ dtype=DTYPE,
)
rtol = 9e-4
elif calc_name == "p3m":
sr_cutoff = 2 * scaling_factor
smearing = sr_cutoff / 5.0
calc = P3MCalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
mesh_spacing=smearing / 8,
+ dtype=DTYPE,
)
rtol = 9e-4
@@ -132,7 +135,6 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
)
# Compute potential and compare against target value using default hypers
- calc.to(dtype=DTYPE)
potentials = calc.forward(
positions=pos,
charges=charges,
@@ -186,26 +188,22 @@ def test_wigner(crystal_name, scaling_factor):
# The first value of 0.1 corresponds to what would be
# chosen by default for the "wigner_sc" or "wigner_bcc_cubiccell" structure.
- smearings = torch.tensor([0.1, 0.06, 0.019], dtype=torch.float64)
- for smearing in smearings:
+ for smearing in [0.1, 0.06, 0.019]:
# Readjust smearing parameter to match nearest neighbor distance
if crystal_name in ["wigner_fcc", "wigner_fcc_cubiccell"]:
- smeareff = float(smearing) / np.sqrt(2)
+ smeareff = smearing / np.sqrt(2)
elif crystal_name in ["wigner_bcc_cubiccell", "wigner_bcc"]:
- smeareff = float(smearing) * np.sqrt(3) / 2
+ smeareff = smearing * np.sqrt(3) / 2
elif crystal_name == "wigner_sc":
- smeareff = float(smearing)
+ smeareff = smearing
smeareff *= scaling_factor
# Compute potential and compare against reference
calc = EwaldCalculator(
- InversePowerLawPotential(
- exponent=1,
- smearing=smeareff,
- ),
+ InversePowerLawPotential(exponent=1, smearing=smeareff, dtype=DTYPE),
lr_wavelength=smeareff / 2,
+ dtype=DTYPE,
)
- calc.to(dtype=DTYPE)
potentials = calc.forward(
positions=positions,
charges=charges,
@@ -253,25 +251,28 @@ def test_random_structure(
if calc_name == "ewald":
calc = EwaldCalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
lr_wavelength=0.5 * smearing,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
+ dtype=DTYPE,
)
elif calc_name == "pme":
calc = PMECalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
mesh_spacing=smearing / 8.0,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
+ dtype=DTYPE,
)
elif calc_name == "p3m":
calc = P3MCalculator(
- CoulombPotential(smearing=smearing),
+ CoulombPotential(smearing=smearing, dtype=DTYPE),
mesh_spacing=smearing / 8.0,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
+ dtype=DTYPE,
)
neighbor_indices, neighbor_shifts = neighbor_list(
@@ -294,7 +295,6 @@ def test_random_structure(
neighbor_shifts=neighbor_shifts,
)
- calc.to(dtype=DTYPE)
potentials = calc.forward(
positions=positions,
charges=charges,
diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py
index 83377006..209355d4 100644
--- a/tests/calculators/test_workflow.py
+++ b/tests/calculators/test_workflow.py
@@ -4,7 +4,6 @@
"""
import io
-import math
import pytest
import torch
@@ -17,15 +16,15 @@
PMECalculator,
)
-AVAILABLE_DEVICES = ["cpu"] + torch.cuda.is_available() * ["cuda"]
-MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3))
-CHARGES_CSCL = torch.tensor([1.0, -1.0])
+DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
+DTYPES = [torch.float32, torch.float64]
SMEARING = 0.1
LR_WAVELENGTH = SMEARING / 4
MESH_SPACING = SMEARING / 4
-@pytest.mark.parametrize("device", AVAILABLE_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
("CalculatorClass", "params"),
[
@@ -59,125 +58,128 @@
],
)
class TestWorkflow:
- def cscl_system(self, device=None):
+ def cscl_system(self, device=None, dtype=None):
"""CsCl crystal. Same as in the madelung test"""
- if device is None:
- device = torch.device("cpu")
-
- positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])
- charges = torch.tensor([1.0, -1.0]).reshape((-1, 1))
- cell = torch.eye(3)
- neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64)
- neighbor_distances = torch.tensor([0.8660])
-
- return (
- charges.to(device=device),
- cell.to(device=device),
- positions.to(device=device),
- neighbor_indices.to(device=device),
- neighbor_distances.to(device=device),
+ device = torch.get_default_device() if device is None else device
+ dtype = torch.get_default_dtype() if dtype is None else dtype
+
+ positions = torch.tensor(
+ [[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device
)
+ charges = torch.tensor([1.0, -1.0], dtype=dtype, device=device).reshape((-1, 1))
+ cell = torch.eye(3, dtype=dtype, device=device)
+ neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64, device=device)
+ neighbor_distances = torch.tensor([0.8660], dtype=dtype, device=device)
+
+ return charges, cell, positions, neighbor_indices, neighbor_distances
- def test_smearing_non_positive(self, CalculatorClass, params, device):
+ def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
match = r"`smearing` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator, PMECalculator]:
params["smearing"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
params["smearing"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
- def test_interpolation_order_error(self, CalculatorClass, params, device):
+ def test_interpolation_order_error(self, CalculatorClass, params, device, dtype):
params = params.copy()
if type(CalculatorClass) in [PMECalculator]:
match = "Only `interpolation_nodes` from 1 to 5"
params["interpolation_nodes"] = 10
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
- def test_lr_wavelength_non_positive(self, CalculatorClass, params, device):
+ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
match = r"`lr_wavelength` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator]:
params["lr_wavelength"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
params["lr_wavelength"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device)
+ CalculatorClass(**params, device=device, dtype=dtype)
- def test_dtype_device(self, CalculatorClass, params, device):
+ def test_dtype_device(self, CalculatorClass, params, device, dtype):
"""Test that the output dtype and device are the same as the input."""
- dtype = torch.float64
params = params.copy()
- positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device)
- charges = torch.ones((1, 2), dtype=dtype, device=device)
- cell = torch.eye(3, dtype=dtype, device=device)
- neighbor_indices = torch.tensor([[0, 0]], device=device)
- neighbor_distances = torch.tensor([0.1], dtype=dtype, device=device)
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
- potential = calculator.forward(
- charges=charges,
- cell=cell,
- positions=positions,
- neighbor_indices=neighbor_indices,
- neighbor_distances=neighbor_distances,
- )
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
assert potential.dtype == dtype
- assert potential.device.type == device
- def check_operation(self, calculator, device):
+ if isinstance(device, torch.device):
+ assert potential.device == device
+ else:
+ assert potential.device.type == device
+
+ def check_operation(self, calculator, device, dtype):
"""Make sure computation runs and returns a torch.Tensor."""
- descriptor = calculator.forward(*self.cscl_system(device))
+ descriptor = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
assert type(descriptor) is torch.Tensor
- def test_operation_as_python(self, CalculatorClass, params, device):
+ def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
params = params.copy()
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
- self.check_operation(calculator=calculator, device=device)
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ self.check_operation(calculator=calculator, device=device, dtype=dtype)
- def test_operation_as_torch_script(self, CalculatorClass, params, device):
+ def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
params = params.copy()
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
- self.check_operation(calculator=scripted, device=device)
+ self.check_operation(calculator=scripted, device=device, dtype=dtype)
- def test_save_load(self, CalculatorClass, params, device):
+ def test_save_load(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"].device = device
- calculator = CalculatorClass(**params, device=device)
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted, buffer)
buffer.seek(0)
torch.jit.load(buffer)
- def test_prefactor(self, CalculatorClass, params, device):
+ def test_prefactor(self, CalculatorClass, params, device, dtype):
"""Test if the prefactor is applied correctly."""
params = params.copy()
params["potential"].device = device
+ params["potential"].dtype = dtype
+
prefactor = 2.0
- calculator1 = CalculatorClass(**params, device=device)
- calculator2 = CalculatorClass(**params, prefactor=prefactor, device=device)
- potentials1 = calculator1.forward(*self.cscl_system())
- potentials2 = calculator2.forward(*self.cscl_system())
+ calculator1 = CalculatorClass(**params, device=device, dtype=dtype)
+ calculator2 = CalculatorClass(
+ **params, prefactor=prefactor, device=device, dtype=dtype
+ )
+
+ potentials1 = calculator1.forward(*self.cscl_system(device=device, dtype=dtype))
+ potentials2 = calculator2.forward(*self.cscl_system(device=device, dtype=dtype))
+
assert torch.allclose(potentials1 * prefactor, potentials2)
- def test_not_nan(self, CalculatorClass, params, device):
+ def test_not_nan(self, CalculatorClass, params, device, dtype):
"""Make sure derivatives are not NaN."""
params = params.copy()
params["potential"].device = device
+ params["potential"].dtype = dtype
- calculator = CalculatorClass(**params, device=device)
- system = self.cscl_system(device)
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ system = self.cscl_system(device=device, dtype=dtype)
system[0].requires_grad = True
system[1].requires_grad = True
system[2].requires_grad = True
@@ -203,26 +205,45 @@ def test_not_nan(self, CalculatorClass, params, device):
torch.autograd.grad(energy, system[2], retain_graph=True)[0]
).any()
- def test_dtype_and_device_incompatability(self, CalculatorClass, params, device):
- """Test that the calculator raises an error if the dtype and device are incompatible."""
+ def test_dtype_and_device_incompatability(
+ self, CalculatorClass, params, device, dtype
+ ):
+ """Test that the calculator raises an error if the dtype and device are incompatible with potential."""
params = params.copy()
+
+ other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
+
params["potential"].device = device
- params["potential"].dtype = torch.float64
- with pytest.raises(AssertionError, match=".*dtype.*"):
- CalculatorClass(**params, dtype=torch.float32, device=device)
- with pytest.raises(AssertionError, match=".*device.*"):
- CalculatorClass(
- **params, dtype=params["potential"].dtype, device=torch.device("meta")
- )
+ params["potential"].dtype = dtype
+
+ match = (
+ rf"dtype of `potential` \({params['potential'].dtype}\) must be same as "
+ rf"of `calculator` \({other_dtype}\)"
+ )
+ with pytest.raises(TypeError, match=match):
+ CalculatorClass(**params, dtype=other_dtype, device=device)
+
+ match = (
+ rf"device of `potential` \({params['potential'].device}\) must be same as "
+ rf"of `calculator` \(meta\)"
+ )
+ with pytest.raises(ValueError, match=match):
+ CalculatorClass(**params, dtype=dtype, device=torch.device("meta"))
def test_potential_and_calculator_incompatability(
- self, CalculatorClass, params, device
+ self,
+ CalculatorClass,
+ params,
+ device,
+ dtype,
):
"""Test that the calculator raises an error if the potential and calculator are incompatible."""
params = params.copy()
params["potential"].device = device
+ params["potential"].dtype = dtype
+
params["potential"] = torch.jit.script(params["potential"])
with pytest.raises(
TypeError, match="Potential must be an instance of Potential, got.*"
):
- CalculatorClass(**params)
+ CalculatorClass(**params, device=device, dtype=dtype)
diff --git a/tests/helpers.py b/tests/helpers.py
index f970f596..d38e9f9b 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -227,14 +227,13 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None):
else:
raise ValueError(f"crystal_name = {crystal_name} is not supported!")
- madelung_ref = torch.tensor(madelung_ref)
charges = charges.reshape((-1, 1))
return (
positions.to(device=device, dtype=dtype),
charges.to(device=device, dtype=dtype),
cell.to(device=device, dtype=dtype),
- madelung_ref,
+ torch.tensor(madelung_ref, device=device, dtype=dtype),
num_formula_units,
)
diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py
index 025bd91c..6beed927 100644
--- a/tests/metatensor/test_workflow_metatensor.py
+++ b/tests/metatensor/test_workflow_metatensor.py
@@ -11,14 +11,15 @@
mts_torch = pytest.importorskip("metatensor.torch")
mts_atomistic = pytest.importorskip("metatensor.torch.atomistic")
-AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [
- torch.device("cuda")
-]
+DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
+DTYPES = [torch.float32, torch.float64]
SMEARING = 0.1
LR_WAVELENGTH = SMEARING / 4
MESH_SPACING = SMEARING / 4
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
("CalculatorClass", "params"),
[
@@ -52,7 +53,10 @@
],
)
class TestWorkflow:
- def system(self, device=None):
+ def system(self, device=None, dtype=None):
+ device = torch.get_default_device() if device is None else device
+ dtype = torch.get_default_dtype() if dtype is None else dtype
+
system = mts_atomistic.System(
types=torch.tensor([1, 2, 2]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.2], [0.0, 0.0, 0.5]]),
@@ -92,30 +96,32 @@ def system(self, device=None):
return system.to(device=device), neighbors.to(device=device)
- def check_operation(self, calculator, device):
+ def check_operation(self, calculator, device, dtype):
"""Make sure computation runs and returns a metatensor.TensorMap."""
- system, neighbors = self.system(device)
+ system, neighbors = self.system(device=device, dtype=dtype)
descriptor = calculator.forward(system, neighbors)
assert isinstance(descriptor, torch.ScriptObject)
if version.parse(torch.__version__) >= version.parse("2.1"):
assert descriptor._type().name() == "TensorMap"
- @pytest.mark.parametrize("device", AVAILABLE_DEVICES)
- def test_operation_as_python(self, CalculatorClass, params, device):
+ def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
calculator = CalculatorClass(**params)
- self.check_operation(calculator=calculator, device=device)
+ self.check_operation(calculator=calculator, device=device, dtype=dtype)
- @pytest.mark.parametrize("device", AVAILABLE_DEVICES)
- def test_operation_as_torch_script(self, CalculatorClass, params, device):
+ def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
calculator = CalculatorClass(**params)
scripted = torch.jit.script(calculator)
- self.check_operation(calculator=scripted, device=device)
+ self.check_operation(calculator=scripted, device=device, dtype=dtype)
- def test_save_load(self, CalculatorClass, params):
- calculator = CalculatorClass(**params)
+ def test_save_load(self, CalculatorClass, params, device, dtype):
+ params = params.copy()
+ params["potential"].device = device
+ params["potential"].dtype = dtype
+
+ calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted, buffer)
diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py
index 17a43580..6477115b 100644
--- a/tests/tuning/test_tuning.py
+++ b/tests/tuning/test_tuning.py
@@ -16,31 +16,46 @@
sys.path.append(str(Path(__file__).parents[1]))
from helpers import define_crystal, neighbor_list
-DTYPE = torch.float32
-DEVICE = "cpu"
DEFAULT_CUTOFF = 4.4
-CHARGES_1 = torch.ones((4, 1), dtype=DTYPE, device=DEVICE)
-POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE, device=DEVICE).reshape((4, 3))
-CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE)
+DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
+DTYPES = [torch.float32, torch.float64]
-def test_TunerBase_double():
+def system(device=None, dtype=None):
+ device = torch.get_default_device() if device is None else device
+ dtype = torch.get_default_dtype() if dtype is None else dtype
+
+ charges = torch.ones((4, 1), dtype=dtype, device=device)
+ cell = torch.eye(3, dtype=dtype, device=device)
+ positions = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3))
+
+ return charges, cell, positions
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_TunerBase_init(device, dtype):
"""
- Check that `TunerBase` initilizes with double precisions tensors.
+ Check that `TunerBase` initilizes correctly.
We are using dummy `neighbor_indices` and `neighbor_distances` to verify types. Have
to be sure that these dummy variables are initilized correctly.
"""
+ charges, cell, positions = system(device, dtype)
TunerBase(
- charges=CHARGES_1.to(dtype=torch.float64),
- cell=CELL_1.to(dtype=torch.float64),
- positions=POSITIONS_1.to(dtype=torch.float64),
+ charges=charges,
+ cell=cell,
+ positions=positions,
cutoff=DEFAULT_CUTOFF,
calculator=1.0,
exponent=1,
+ dtype=dtype,
+ device=device,
)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
("calculator", "tune", "param_length"),
[
@@ -50,13 +65,15 @@ def test_TunerBase_double():
],
)
@pytest.mark.parametrize("accuracy", [1e-1, 1e-3, 1e-5])
-def test_parameter_choose(calculator, tune, param_length, accuracy):
+def test_parameter_choose(device, dtype, calculator, tune, param_length, accuracy):
"""
Check that the Madelung constants obtained from the Ewald sum calculator matches
the reference values and that all branches of the from_accuracy method are covered.
"""
# Get input parameters and adjust to account for scaling
- pos, charges, cell, madelung_ref, num_units = define_crystal()
+ pos, charges, cell, madelung_ref, num_units = define_crystal(
+ dtype=dtype, device=device
+ )
# Compute neighbor list
neighbor_indices, neighbor_distances = neighbor_list(
@@ -71,13 +88,17 @@ def test_parameter_choose(calculator, tune, param_length, accuracy):
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
accuracy=accuracy,
+ dtype=dtype,
+ device=device,
)
assert len(params) == param_length
# Compute potential and compare against target value using default hypers
calc = calculator(
- potential=(CoulombPotential(smearing=smearing)),
+ potential=(CoulombPotential(smearing=smearing, dtype=dtype, device=device)),
+ dtype=dtype,
+ device=device,
**params,
)
potentials = calc.forward(
@@ -103,12 +124,12 @@ def test_accuracy_error(tune):
)
with pytest.raises(ValueError, match=match):
tune(
- charges,
- cell,
- pos,
- DEFAULT_CUTOFF,
- neighbor_indices,
- neighbor_distances,
+ charges=charges,
+ cell=cell,
+ positions=pos,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=neighbor_indices,
+ neighbor_distances=neighbor_distances,
accuracy="foo",
)
@@ -116,79 +137,90 @@ def test_accuracy_error(tune):
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_exponent_not_1_error(tune):
pos, charges, cell, _, _ = define_crystal()
-
- match = "Only exponent = 1 is supported but got 2."
neighbor_indices, neighbor_distances = neighbor_list(
positions=pos, box=cell, cutoff=DEFAULT_CUTOFF
)
+
+ match = "Only exponent = 1 is supported but got 2."
with pytest.raises(NotImplementedError, match=match):
tune(
- charges,
- cell,
- pos,
- DEFAULT_CUTOFF,
- neighbor_indices,
- neighbor_distances,
+ charges=charges,
+ cell=cell,
+ positions=pos,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=neighbor_indices,
+ neighbor_distances=neighbor_distances,
exponent=2,
)
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_shape_positions(tune):
+ charges, cell, _ = system()
match = (
r"`positions` must be a tensor with shape \[n_atoms, 3\], got tensor with "
r"shape \[4, 5\]"
)
with pytest.raises(ValueError, match=match):
tune(
- CHARGES_1,
- CELL_1,
- torch.ones((4, 5), dtype=DTYPE, device=DEVICE),
- DEFAULT_CUTOFF,
- None, # dummy neighbor indices
- None, # dummy neighbor distances
+ charges=charges,
+ cell=cell,
+ positions=torch.ones((4, 5)),
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
)
# Tests for invalid shape, dtype and device of cell
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_shape_cell(tune):
+ charges, _, positions = system()
match = (
r"`cell` must be a tensor with shape \[3, 3\], got tensor with shape \[2, 2\]"
)
with pytest.raises(ValueError, match=match):
tune(
- CHARGES_1,
- torch.ones([2, 2], dtype=DTYPE, device=DEVICE),
- POSITIONS_1,
- DEFAULT_CUTOFF,
- None, # dummy neighbor indices
- None, # dummy neighbor distances
+ charges=charges,
+ cell=torch.ones([2, 2]),
+ positions=positions,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
)
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_cell(tune):
+ charges, _, positions = system()
match = (
"provided `cell` has a determinant of 0 and therefore is not valid for "
"periodic calculation"
)
with pytest.raises(ValueError, match=match):
- tune(CHARGES_1, torch.zeros(3, 3), POSITIONS_1, DEFAULT_CUTOFF, None, None)
+ tune(
+ charges=charges,
+ cell=torch.zeros(3, 3),
+ positions=positions,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
+ )
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_dtype_cell(tune):
+ charges, _, positions = system()
match = (
r"type of `cell` \(torch.float64\) must be same as `positions` "
r"\(torch.float32\)"
)
- with pytest.raises(ValueError, match=match):
+ with pytest.raises(TypeError, match=match):
tune(
- CHARGES_1,
- torch.eye(3, dtype=torch.float64, device=DEVICE),
- POSITIONS_1,
- DEFAULT_CUTOFF,
- None,
- None,
+ charges=charges,
+ cell=torch.eye(3, dtype=torch.float64),
+ positions=positions,
+ cutoff=DEFAULT_CUTOFF,
+ neighbor_indices=None,
+ neighbor_distances=None,
)