diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst index 5631f33c..d927b525 100644 --- a/docs/src/references/changelog.rst +++ b/docs/src/references/changelog.rst @@ -27,9 +27,17 @@ changelog `_ format. This project follows Added ##### +* Enhanced ``device`` and ``dtype`` consistency checks throughout the library * Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in ``Calculator`` classes and tuning functions. + +Fixed +##### + +* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator`` + + `Version 0.2.0 `_ - 2025-01-23 ------------------------------------------------------------------------------------------ diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index 575e9edc..f282a626 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -10,9 +10,23 @@ def _validate_parameters( neighbor_indices: torch.Tensor, neighbor_distances: torch.Tensor, smearing: Union[float, None], + dtype: torch.dtype, + device: Union[str, torch.device], ) -> None: - device = positions.device - dtype = positions.dtype + if positions.dtype != dtype: + raise TypeError( + f"type of `positions` ({positions.dtype}) must be same as class " + f"type ({dtype})" + ) + + if isinstance(device, torch.device): + device = device.type + + if positions.device.type != device: + raise ValueError( + f"device of `positions` ({positions.device}) must be same as class " + f"device ({device})" + ) # check shape, dtype and device of positions num_atoms = len(positions) @@ -29,12 +43,12 @@ def _validate_parameters( f"{list(cell.shape)}" ) - if cell.dtype != dtype: - raise ValueError( + if cell.dtype != positions.dtype: + raise TypeError( f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})" ) - if cell.device != device: + if cell.device != positions.device: raise ValueError( f"device of `cell` ({cell.device}) must be same as `positions` ({device})" ) @@ -63,12 +77,12 @@ def _validate_parameters( f"{len(positions)} atoms" ) - if charges.dtype != dtype: - raise ValueError( + if charges.dtype != positions.dtype: + raise TypeError( f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})" ) - if charges.device != device: + if charges.device != positions.device: raise ValueError( f"device of `charges` ({charges.device}) must be same as `positions` " f"({device})" @@ -82,7 +96,7 @@ def _validate_parameters( "structure" ) - if neighbor_indices.device != device: + if neighbor_indices.device != positions.device: raise ValueError( f"device of `neighbor_indices` ({neighbor_indices.device}) must be " f"same as `positions` ({device})" @@ -95,14 +109,14 @@ def _validate_parameters( f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}" ) - if neighbor_distances.device != device: + if neighbor_distances.device != positions.device: raise ValueError( f"device of `neighbor_distances` ({neighbor_distances.device}) must be " f"same as `positions` ({device})" ) - if neighbor_distances.dtype != dtype: - raise ValueError( + if neighbor_distances.dtype != positions.dtype: + raise TypeError( f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same " f"as `positions` ({dtype})" ) diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 70ede3ef..29eaadff 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from torch import profiler @@ -37,7 +37,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__() @@ -48,17 +48,20 @@ def __init__( self.device = torch.get_default_device() if device is None else device self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.potential = potential - assert self.dtype == self.potential.dtype, ( - f"Potential and Calculator must have the same dtype, got {self.dtype} and " - f"{self.potential.dtype}" - ) - assert self.device == self.potential.device, ( - f"Potential and Calculator must have the same device, got {self.device} and " - f"{self.potential.device}" - ) + if self.dtype != potential.dtype: + raise TypeError( + f"dtype of `potential` ({potential.dtype}) must be same as of " + f"`calculator` ({self.dtype})" + ) + + if self.device != potential.device: + raise ValueError( + f"device of `potential` ({potential.device}) must be same as of " + f"`calculator` ({self.device})" + ) + self.potential = potential self.full_neighbor_list = full_neighbor_list self.prefactor = prefactor @@ -179,6 +182,8 @@ def forward( neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, smearing=self.potential.smearing, + dtype=self.dtype, + device=self.device, ) # Compute short-range (SR) part using a real space sum diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py index 7d213f2f..83e2dc85 100644 --- a/src/torchpme/calculators/ewald.py +++ b/src/torchpme/calculators/ewald.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -66,7 +66,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( potential=potential, diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py index eb23c780..f85533db 100644 --- a/src/torchpme/calculators/p3m.py +++ b/src/torchpme/calculators/p3m.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -56,7 +56,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): self.mesh_spacing: float = mesh_spacing diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py index 95f74216..dd389812 100644 --- a/src/torchpme/calculators/pme.py +++ b/src/torchpme/calculators/pme.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from torch import profiler @@ -59,7 +59,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( potential=potential, diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py index d76a20c0..212f4744 100644 --- a/src/torchpme/potentials/combined.py +++ b/src/torchpme/potentials/combined.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -39,7 +39,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( smearing=smearing, diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py index 4cde5611..1e35897c 100644 --- a/src/torchpme/potentials/coulomb.py +++ b/src/torchpme/potentials/coulomb.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -35,7 +35,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__(smearing, exclusion_radius, dtype, device) diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index 35ff7ac7..374ab56e 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from torch.special import gammainc @@ -41,7 +41,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__(smearing, exclusion_radius, dtype, device) diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index 674a8632..1efa783d 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -39,7 +39,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__() self.dtype = torch.get_default_dtype() if dtype is None else dtype diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py index e8ffc3c5..b58d31eb 100644 --- a/src/torchpme/potentials/spline.py +++ b/src/torchpme/potentials/spline.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -58,7 +58,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( smearing=smearing, diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py index 6452fa7d..459f0a02 100644 --- a/src/torchpme/tuning/ewald.py +++ b/src/torchpme/tuning/ewald.py @@ -1,5 +1,5 @@ import math -from typing import Any, Optional +from typing import Any, Optional, Union from warnings import warn import torch @@ -20,7 +20,7 @@ def tune_ewald( ns_hi: int = 14, accuracy: float = 1e-3, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.EwaldCalculator`. diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py index 5685ffaf..92a6bbe0 100644 --- a/src/torchpme/tuning/p3m.py +++ b/src/torchpme/tuning/p3m.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional +from typing import Any, Optional, Union from warnings import warn import torch @@ -80,7 +80,7 @@ def tune_p3m( mesh_hi: int = 7, accuracy: float = 1e-3, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`. diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py index 55a0556e..f9ddc4d8 100644 --- a/src/torchpme/tuning/pme.py +++ b/src/torchpme/tuning/pme.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional +from typing import Any, Optional, Union from warnings import warn import torch @@ -23,7 +23,7 @@ def tune_pme( mesh_hi: int = 7, accuracy: float = 1e-3, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.PMECalculator`. diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 8461fda5..354d6029 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -1,6 +1,6 @@ import math import time -from typing import Optional +from typing import Optional, Union import torch @@ -80,13 +80,16 @@ def __init__( calculator: type[Calculator], exponent: int = 1, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): if exponent != 1: raise NotImplementedError( f"Only exponent = 1 is supported but got {exponent}." ) + self.device = torch.get_default_device() if device is None else device + self.dtype = torch.get_default_dtype() if dtype is None else dtype + _validate_parameters( charges=charges, cell=cell, @@ -96,16 +99,15 @@ def __init__( [1.0], device=positions.device, dtype=positions.dtype ), smearing=1.0, # dummy value because; always have range-seperated potentials + dtype=self.dtype, + device=self.device, ) - self.charges = charges self.cell = cell self.positions = positions self.cutoff = cutoff self.calculator = calculator self.exponent = exponent - self.device = torch.get_default_device() if device is None else device - self.dtype = torch.get_default_dtype() if dtype is None else dtype self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions)) @@ -182,7 +184,7 @@ def __init__( neighbor_distances: torch.Tensor, exponent: int = 1, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( charges=charges, @@ -280,10 +282,13 @@ def __init__( n_warmup: int = 4, run_backward: Optional[bool] = True, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__() + self.dtype = torch.get_default_dtype() if dtype is None else dtype + self.device = torch.get_default_device() if device is None else device + _validate_parameters( charges=charges, cell=cell, @@ -291,13 +296,13 @@ def __init__( neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, smearing=1.0, # dummy value because; always have range-seperated potentials + device=self.device, + dtype=self.dtype, ) self.charges = charges self.cell = cell self.positions = positions - self.dtype = dtype - self.device = device self.n_repeat = n_repeat self.n_warmup = n_warmup self.run_backward = run_backward diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index a6fc976a..567b0f15 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -43,6 +43,35 @@ def test_compute_output_shapes(): assert result.shape == charges.shape +def test_wrong_device_positions(): + calculator = CalculatorTest() + match = r"device of `positions` \(meta\) must be same as class device \(cpu\)" + with pytest.raises(ValueError, match=match): + calculator.forward( + positions=POSITIONS_1.to(device="meta"), + charges=CHARGES_1, + cell=CELL_1, + neighbor_indices=NEIGHBOR_INDICES, + neighbor_distances=NEIGHBOR_DISTANCES, + ) + + +def test_wrong_dtype_positions(): + calculator = CalculatorTest() + match = ( + r"type of `positions` \(torch.float64\) must be same as class type " + r"\(torch.float32\)" + ) + with pytest.raises(TypeError, match=match): + calculator.forward( + positions=POSITIONS_1.to(dtype=torch.float64), + charges=CHARGES_1, + cell=CELL_1, + neighbor_indices=NEIGHBOR_INDICES, + neighbor_distances=NEIGHBOR_DISTANCES, + ) + + # Tests for invalid shape, dtype and device of positions def test_invalid_shape_positions(): calculator = CalculatorTest() @@ -82,7 +111,7 @@ def test_invalid_dtype_cell(): r"type of `cell` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, charges=CHARGES_1, @@ -163,7 +192,7 @@ def test_invalid_dtype_charges(): r"type of `charges` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, charges=CHARGES_1.to(dtype=torch.float64), @@ -252,7 +281,7 @@ def test_invalid_dtype_neighbor_distances(): r"type of `neighbor_distances` \(torch.float64\) must be same " r"as `positions` \(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, charges=CHARGES_1, diff --git a/tests/calculators/test_values_direct.py b/tests/calculators/test_values_direct.py index 47f819af..a5ace407 100644 --- a/tests/calculators/test_values_direct.py +++ b/tests/calculators/test_values_direct.py @@ -17,7 +17,11 @@ class CalculatorTest(Calculator): def __init__(self, **kwargs): super().__init__( - potential=CoulombPotential(smearing=None, exclusion_radius=None), **kwargs + potential=CoulombPotential( + smearing=None, exclusion_radius=None, dtype=DTYPE + ), + **kwargs, + dtype=DTYPE, ) diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py index cb3ff708..edcb886e 100644 --- a/tests/calculators/test_values_ewald.py +++ b/tests/calculators/test_values_ewald.py @@ -87,7 +87,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name): to triclinic, as well as cation-anion ratios of 1:1, 1:2 and 2:1. """ # Get input parameters and adjust to account for scaling - pos, charges, cell, madelung_ref, num_units = define_crystal(crystal_name) + pos, charges, cell, madelung_ref, num_units = define_crystal( + crystal_name, dtype=DTYPE + ) pos *= scaling_factor cell *= scaling_factor madelung_ref /= scaling_factor @@ -99,11 +101,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name): smearing = sr_cutoff / 5.0 lr_wavelength = 0.5 * smearing calc = EwaldCalculator( - InversePowerLawPotential( - exponent=1, - smearing=smearing, - ), + InversePowerLawPotential(exponent=1, smearing=smearing, dtype=DTYPE), lr_wavelength=lr_wavelength, + dtype=DTYPE, ) rtol = 4e-6 elif calc_name == "pme": @@ -113,16 +113,19 @@ def test_madelung(crystal_name, scaling_factor, calc_name): InversePowerLawPotential( exponent=1, smearing=smearing, + dtype=DTYPE, ), mesh_spacing=smearing / 8, + dtype=DTYPE, ) rtol = 9e-4 elif calc_name == "p3m": sr_cutoff = 2 * scaling_factor smearing = sr_cutoff / 5.0 calc = P3MCalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), mesh_spacing=smearing / 8, + dtype=DTYPE, ) rtol = 9e-4 @@ -132,7 +135,6 @@ def test_madelung(crystal_name, scaling_factor, calc_name): ) # Compute potential and compare against target value using default hypers - calc.to(dtype=DTYPE) potentials = calc.forward( positions=pos, charges=charges, @@ -186,26 +188,22 @@ def test_wigner(crystal_name, scaling_factor): # The first value of 0.1 corresponds to what would be # chosen by default for the "wigner_sc" or "wigner_bcc_cubiccell" structure. - smearings = torch.tensor([0.1, 0.06, 0.019], dtype=torch.float64) - for smearing in smearings: + for smearing in [0.1, 0.06, 0.019]: # Readjust smearing parameter to match nearest neighbor distance if crystal_name in ["wigner_fcc", "wigner_fcc_cubiccell"]: - smeareff = float(smearing) / np.sqrt(2) + smeareff = smearing / np.sqrt(2) elif crystal_name in ["wigner_bcc_cubiccell", "wigner_bcc"]: - smeareff = float(smearing) * np.sqrt(3) / 2 + smeareff = smearing * np.sqrt(3) / 2 elif crystal_name == "wigner_sc": - smeareff = float(smearing) + smeareff = smearing smeareff *= scaling_factor # Compute potential and compare against reference calc = EwaldCalculator( - InversePowerLawPotential( - exponent=1, - smearing=smeareff, - ), + InversePowerLawPotential(exponent=1, smearing=smeareff, dtype=DTYPE), lr_wavelength=smeareff / 2, + dtype=DTYPE, ) - calc.to(dtype=DTYPE) potentials = calc.forward( positions=positions, charges=charges, @@ -253,25 +251,28 @@ def test_random_structure( if calc_name == "ewald": calc = EwaldCalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), lr_wavelength=0.5 * smearing, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, + dtype=DTYPE, ) elif calc_name == "pme": calc = PMECalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), mesh_spacing=smearing / 8.0, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, + dtype=DTYPE, ) elif calc_name == "p3m": calc = P3MCalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), mesh_spacing=smearing / 8.0, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, + dtype=DTYPE, ) neighbor_indices, neighbor_shifts = neighbor_list( @@ -294,7 +295,6 @@ def test_random_structure( neighbor_shifts=neighbor_shifts, ) - calc.to(dtype=DTYPE) potentials = calc.forward( positions=positions, charges=charges, diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 83377006..209355d4 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -4,7 +4,6 @@ """ import io -import math import pytest import torch @@ -17,15 +16,15 @@ PMECalculator, ) -AVAILABLE_DEVICES = ["cpu"] + torch.cuda.is_available() * ["cuda"] -MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) -CHARGES_CSCL = torch.tensor([1.0, -1.0]) +DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] +DTYPES = [torch.float32, torch.float64] SMEARING = 0.1 LR_WAVELENGTH = SMEARING / 4 MESH_SPACING = SMEARING / 4 -@pytest.mark.parametrize("device", AVAILABLE_DEVICES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize( ("CalculatorClass", "params"), [ @@ -59,125 +58,128 @@ ], ) class TestWorkflow: - def cscl_system(self, device=None): + def cscl_system(self, device=None, dtype=None): """CsCl crystal. Same as in the madelung test""" - if device is None: - device = torch.device("cpu") - - positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - charges = torch.tensor([1.0, -1.0]).reshape((-1, 1)) - cell = torch.eye(3) - neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64) - neighbor_distances = torch.tensor([0.8660]) - - return ( - charges.to(device=device), - cell.to(device=device), - positions.to(device=device), - neighbor_indices.to(device=device), - neighbor_distances.to(device=device), + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype + + positions = torch.tensor( + [[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device ) + charges = torch.tensor([1.0, -1.0], dtype=dtype, device=device).reshape((-1, 1)) + cell = torch.eye(3, dtype=dtype, device=device) + neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64, device=device) + neighbor_distances = torch.tensor([0.8660], dtype=dtype, device=device) + + return charges, cell, positions, neighbor_indices, neighbor_distances - def test_smearing_non_positive(self, CalculatorClass, params, device): + def test_smearing_non_positive(self, CalculatorClass, params, device, dtype): params = params.copy() match = r"`smearing` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator, PMECalculator]: params["smearing"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) params["smearing"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) - def test_interpolation_order_error(self, CalculatorClass, params, device): + def test_interpolation_order_error(self, CalculatorClass, params, device, dtype): params = params.copy() if type(CalculatorClass) in [PMECalculator]: match = "Only `interpolation_nodes` from 1 to 5" params["interpolation_nodes"] = 10 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) - def test_lr_wavelength_non_positive(self, CalculatorClass, params, device): + def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype): params = params.copy() match = r"`lr_wavelength` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator]: params["lr_wavelength"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) params["lr_wavelength"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) - def test_dtype_device(self, CalculatorClass, params, device): + def test_dtype_device(self, CalculatorClass, params, device, dtype): """Test that the output dtype and device are the same as the input.""" - dtype = torch.float64 params = params.copy() - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) - charges = torch.ones((1, 2), dtype=dtype, device=device) - cell = torch.eye(3, dtype=dtype, device=device) - neighbor_indices = torch.tensor([[0, 0]], device=device) - neighbor_distances = torch.tensor([0.1], dtype=dtype, device=device) params["potential"].device = device - calculator = CalculatorClass(**params, device=device) - potential = calculator.forward( - charges=charges, - cell=cell, - positions=positions, - neighbor_indices=neighbor_indices, - neighbor_distances=neighbor_distances, - ) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) + potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype)) assert potential.dtype == dtype - assert potential.device.type == device - def check_operation(self, calculator, device): + if isinstance(device, torch.device): + assert potential.device == device + else: + assert potential.device.type == device + + def check_operation(self, calculator, device, dtype): """Make sure computation runs and returns a torch.Tensor.""" - descriptor = calculator.forward(*self.cscl_system(device)) + descriptor = calculator.forward(*self.cscl_system(device=device, dtype=dtype)) assert type(descriptor) is torch.Tensor - def test_operation_as_python(self, CalculatorClass, params, device): + def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" params = params.copy() params["potential"].device = device - calculator = CalculatorClass(**params, device=device) - self.check_operation(calculator=calculator, device=device) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) + self.check_operation(calculator=calculator, device=device, dtype=dtype) - def test_operation_as_torch_script(self, CalculatorClass, params, device): + def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" params = params.copy() params["potential"].device = device - calculator = CalculatorClass(**params, device=device) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) - self.check_operation(calculator=scripted, device=device) + self.check_operation(calculator=scripted, device=device, dtype=dtype) - def test_save_load(self, CalculatorClass, params, device): + def test_save_load(self, CalculatorClass, params, device, dtype): params = params.copy() params["potential"].device = device - calculator = CalculatorClass(**params, device=device) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) buffer.seek(0) torch.jit.load(buffer) - def test_prefactor(self, CalculatorClass, params, device): + def test_prefactor(self, CalculatorClass, params, device, dtype): """Test if the prefactor is applied correctly.""" params = params.copy() params["potential"].device = device + params["potential"].dtype = dtype + prefactor = 2.0 - calculator1 = CalculatorClass(**params, device=device) - calculator2 = CalculatorClass(**params, prefactor=prefactor, device=device) - potentials1 = calculator1.forward(*self.cscl_system()) - potentials2 = calculator2.forward(*self.cscl_system()) + calculator1 = CalculatorClass(**params, device=device, dtype=dtype) + calculator2 = CalculatorClass( + **params, prefactor=prefactor, device=device, dtype=dtype + ) + + potentials1 = calculator1.forward(*self.cscl_system(device=device, dtype=dtype)) + potentials2 = calculator2.forward(*self.cscl_system(device=device, dtype=dtype)) + assert torch.allclose(potentials1 * prefactor, potentials2) - def test_not_nan(self, CalculatorClass, params, device): + def test_not_nan(self, CalculatorClass, params, device, dtype): """Make sure derivatives are not NaN.""" params = params.copy() params["potential"].device = device + params["potential"].dtype = dtype - calculator = CalculatorClass(**params, device=device) - system = self.cscl_system(device) + calculator = CalculatorClass(**params, device=device, dtype=dtype) + system = self.cscl_system(device=device, dtype=dtype) system[0].requires_grad = True system[1].requires_grad = True system[2].requires_grad = True @@ -203,26 +205,45 @@ def test_not_nan(self, CalculatorClass, params, device): torch.autograd.grad(energy, system[2], retain_graph=True)[0] ).any() - def test_dtype_and_device_incompatability(self, CalculatorClass, params, device): - """Test that the calculator raises an error if the dtype and device are incompatible.""" + def test_dtype_and_device_incompatability( + self, CalculatorClass, params, device, dtype + ): + """Test that the calculator raises an error if the dtype and device are incompatible with potential.""" params = params.copy() + + other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 + params["potential"].device = device - params["potential"].dtype = torch.float64 - with pytest.raises(AssertionError, match=".*dtype.*"): - CalculatorClass(**params, dtype=torch.float32, device=device) - with pytest.raises(AssertionError, match=".*device.*"): - CalculatorClass( - **params, dtype=params["potential"].dtype, device=torch.device("meta") - ) + params["potential"].dtype = dtype + + match = ( + rf"dtype of `potential` \({params['potential'].dtype}\) must be same as " + rf"of `calculator` \({other_dtype}\)" + ) + with pytest.raises(TypeError, match=match): + CalculatorClass(**params, dtype=other_dtype, device=device) + + match = ( + rf"device of `potential` \({params['potential'].device}\) must be same as " + rf"of `calculator` \(meta\)" + ) + with pytest.raises(ValueError, match=match): + CalculatorClass(**params, dtype=dtype, device=torch.device("meta")) def test_potential_and_calculator_incompatability( - self, CalculatorClass, params, device + self, + CalculatorClass, + params, + device, + dtype, ): """Test that the calculator raises an error if the potential and calculator are incompatible.""" params = params.copy() params["potential"].device = device + params["potential"].dtype = dtype + params["potential"] = torch.jit.script(params["potential"]) with pytest.raises( TypeError, match="Potential must be an instance of Potential, got.*" ): - CalculatorClass(**params) + CalculatorClass(**params, device=device, dtype=dtype) diff --git a/tests/helpers.py b/tests/helpers.py index f970f596..d38e9f9b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -227,14 +227,13 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None): else: raise ValueError(f"crystal_name = {crystal_name} is not supported!") - madelung_ref = torch.tensor(madelung_ref) charges = charges.reshape((-1, 1)) return ( positions.to(device=device, dtype=dtype), charges.to(device=device, dtype=dtype), cell.to(device=device, dtype=dtype), - madelung_ref, + torch.tensor(madelung_ref, device=device, dtype=dtype), num_formula_units, ) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 025bd91c..6beed927 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -11,14 +11,15 @@ mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") -AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [ - torch.device("cuda") -] +DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] +DTYPES = [torch.float32, torch.float64] SMEARING = 0.1 LR_WAVELENGTH = SMEARING / 4 MESH_SPACING = SMEARING / 4 +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize( ("CalculatorClass", "params"), [ @@ -52,7 +53,10 @@ ], ) class TestWorkflow: - def system(self, device=None): + def system(self, device=None, dtype=None): + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype + system = mts_atomistic.System( types=torch.tensor([1, 2, 2]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.2], [0.0, 0.0, 0.5]]), @@ -92,30 +96,32 @@ def system(self, device=None): return system.to(device=device), neighbors.to(device=device) - def check_operation(self, calculator, device): + def check_operation(self, calculator, device, dtype): """Make sure computation runs and returns a metatensor.TensorMap.""" - system, neighbors = self.system(device) + system, neighbors = self.system(device=device, dtype=dtype) descriptor = calculator.forward(system, neighbors) assert isinstance(descriptor, torch.ScriptObject) if version.parse(torch.__version__) >= version.parse("2.1"): assert descriptor._type().name() == "TensorMap" - @pytest.mark.parametrize("device", AVAILABLE_DEVICES) - def test_operation_as_python(self, CalculatorClass, params, device): + def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" calculator = CalculatorClass(**params) - self.check_operation(calculator=calculator, device=device) + self.check_operation(calculator=calculator, device=device, dtype=dtype) - @pytest.mark.parametrize("device", AVAILABLE_DEVICES) - def test_operation_as_torch_script(self, CalculatorClass, params, device): + def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" calculator = CalculatorClass(**params) scripted = torch.jit.script(calculator) - self.check_operation(calculator=scripted, device=device) + self.check_operation(calculator=scripted, device=device, dtype=dtype) - def test_save_load(self, CalculatorClass, params): - calculator = CalculatorClass(**params) + def test_save_load(self, CalculatorClass, params, device, dtype): + params = params.copy() + params["potential"].device = device + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py index 17a43580..6477115b 100644 --- a/tests/tuning/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -16,31 +16,46 @@ sys.path.append(str(Path(__file__).parents[1])) from helpers import define_crystal, neighbor_list -DTYPE = torch.float32 -DEVICE = "cpu" DEFAULT_CUTOFF = 4.4 -CHARGES_1 = torch.ones((4, 1), dtype=DTYPE, device=DEVICE) -POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE, device=DEVICE).reshape((4, 3)) -CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE) +DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] +DTYPES = [torch.float32, torch.float64] -def test_TunerBase_double(): +def system(device=None, dtype=None): + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype + + charges = torch.ones((4, 1), dtype=dtype, device=device) + cell = torch.eye(3, dtype=dtype, device=device) + positions = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3)) + + return charges, cell, positions + + +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) +def test_TunerBase_init(device, dtype): """ - Check that `TunerBase` initilizes with double precisions tensors. + Check that `TunerBase` initilizes correctly. We are using dummy `neighbor_indices` and `neighbor_distances` to verify types. Have to be sure that these dummy variables are initilized correctly. """ + charges, cell, positions = system(device, dtype) TunerBase( - charges=CHARGES_1.to(dtype=torch.float64), - cell=CELL_1.to(dtype=torch.float64), - positions=POSITIONS_1.to(dtype=torch.float64), + charges=charges, + cell=cell, + positions=positions, cutoff=DEFAULT_CUTOFF, calculator=1.0, exponent=1, + dtype=dtype, + device=device, ) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize( ("calculator", "tune", "param_length"), [ @@ -50,13 +65,15 @@ def test_TunerBase_double(): ], ) @pytest.mark.parametrize("accuracy", [1e-1, 1e-3, 1e-5]) -def test_parameter_choose(calculator, tune, param_length, accuracy): +def test_parameter_choose(device, dtype, calculator, tune, param_length, accuracy): """ Check that the Madelung constants obtained from the Ewald sum calculator matches the reference values and that all branches of the from_accuracy method are covered. """ # Get input parameters and adjust to account for scaling - pos, charges, cell, madelung_ref, num_units = define_crystal() + pos, charges, cell, madelung_ref, num_units = define_crystal( + dtype=dtype, device=device + ) # Compute neighbor list neighbor_indices, neighbor_distances = neighbor_list( @@ -71,13 +88,17 @@ def test_parameter_choose(calculator, tune, param_length, accuracy): neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, accuracy=accuracy, + dtype=dtype, + device=device, ) assert len(params) == param_length # Compute potential and compare against target value using default hypers calc = calculator( - potential=(CoulombPotential(smearing=smearing)), + potential=(CoulombPotential(smearing=smearing, dtype=dtype, device=device)), + dtype=dtype, + device=device, **params, ) potentials = calc.forward( @@ -103,12 +124,12 @@ def test_accuracy_error(tune): ) with pytest.raises(ValueError, match=match): tune( - charges, - cell, - pos, - DEFAULT_CUTOFF, - neighbor_indices, - neighbor_distances, + charges=charges, + cell=cell, + positions=pos, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, accuracy="foo", ) @@ -116,79 +137,90 @@ def test_accuracy_error(tune): @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_exponent_not_1_error(tune): pos, charges, cell, _, _ = define_crystal() - - match = "Only exponent = 1 is supported but got 2." neighbor_indices, neighbor_distances = neighbor_list( positions=pos, box=cell, cutoff=DEFAULT_CUTOFF ) + + match = "Only exponent = 1 is supported but got 2." with pytest.raises(NotImplementedError, match=match): tune( - charges, - cell, - pos, - DEFAULT_CUTOFF, - neighbor_indices, - neighbor_distances, + charges=charges, + cell=cell, + positions=pos, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, exponent=2, ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_shape_positions(tune): + charges, cell, _ = system() match = ( r"`positions` must be a tensor with shape \[n_atoms, 3\], got tensor with " r"shape \[4, 5\]" ) with pytest.raises(ValueError, match=match): tune( - CHARGES_1, - CELL_1, - torch.ones((4, 5), dtype=DTYPE, device=DEVICE), - DEFAULT_CUTOFF, - None, # dummy neighbor indices - None, # dummy neighbor distances + charges=charges, + cell=cell, + positions=torch.ones((4, 5)), + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, ) # Tests for invalid shape, dtype and device of cell @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_shape_cell(tune): + charges, _, positions = system() match = ( r"`cell` must be a tensor with shape \[3, 3\], got tensor with shape \[2, 2\]" ) with pytest.raises(ValueError, match=match): tune( - CHARGES_1, - torch.ones([2, 2], dtype=DTYPE, device=DEVICE), - POSITIONS_1, - DEFAULT_CUTOFF, - None, # dummy neighbor indices - None, # dummy neighbor distances + charges=charges, + cell=torch.ones([2, 2]), + positions=positions, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_cell(tune): + charges, _, positions = system() match = ( "provided `cell` has a determinant of 0 and therefore is not valid for " "periodic calculation" ) with pytest.raises(ValueError, match=match): - tune(CHARGES_1, torch.zeros(3, 3), POSITIONS_1, DEFAULT_CUTOFF, None, None) + tune( + charges=charges, + cell=torch.zeros(3, 3), + positions=positions, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, + ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_dtype_cell(tune): + charges, _, positions = system() match = ( r"type of `cell` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): tune( - CHARGES_1, - torch.eye(3, dtype=torch.float64, device=DEVICE), - POSITIONS_1, - DEFAULT_CUTOFF, - None, - None, + charges=charges, + cell=torch.eye(3, dtype=torch.float64), + positions=positions, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, )