From d2b85ab39b4209f50bd0a34e986252a8fea9195c Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 27 Jan 2025 01:45:37 +0100 Subject: [PATCH 1/4] Fix `device` and `dtype` not being specified in the `__init__` of `P3MCalculator` --- src/torchpme/calculators/p3m.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py index 54b97bcb..eb23c780 100644 --- a/src/torchpme/calculators/p3m.py +++ b/src/torchpme/calculators/p3m.py @@ -73,8 +73,8 @@ def __init__( ) self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter( - cell=torch.eye(3), - ns_mesh=torch.ones(3, dtype=int), + cell=torch.eye(3, dtype=self.dtype, device=self.device), + ns_mesh=torch.ones(3, dtype=int, device=self.device), interpolation_nodes=self.interpolation_nodes, kernel=self.potential, mode=0, # Green's function for point-charge potentials @@ -84,8 +84,8 @@ def __init__( ) self.mesh_interpolator: MeshInterpolator = MeshInterpolator( - cell=torch.eye(3), - ns_mesh=torch.ones(3, dtype=int), + cell=torch.eye(3, dtype=self.dtype, device=self.device), + ns_mesh=torch.ones(3, dtype=int, device=self.device), interpolation_nodes=self.interpolation_nodes, method="P3M", ) From c865f64f5d8043023ebd1446e07a1cdae15c5cf6 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Tue, 28 Jan 2025 16:17:29 +0100 Subject: [PATCH 2/4] add consistent tests for dtypes and devices --- docs/src/references/changelog.rst | 8 + src/torchpme/_utils.py | 38 +++-- src/torchpme/calculators/calculator.py | 27 +-- src/torchpme/calculators/ewald.py | 4 +- src/torchpme/calculators/p3m.py | 4 +- src/torchpme/calculators/pme.py | 4 +- src/torchpme/potentials/combined.py | 4 +- src/torchpme/potentials/coulomb.py | 4 +- src/torchpme/potentials/inversepowerlaw.py | 4 +- src/torchpme/potentials/potential.py | 4 +- src/torchpme/potentials/spline.py | 4 +- src/torchpme/tuning/ewald.py | 4 +- src/torchpme/tuning/p3m.py | 4 +- src/torchpme/tuning/pme.py | 4 +- src/torchpme/tuning/tuner.py | 23 ++- tests/calculators/test_calculator.py | 35 +++- tests/calculators/test_values_direct.py | 6 +- tests/calculators/test_values_ewald.py | 42 ++--- tests/calculators/test_workflow.py | 169 +++++++++++-------- tests/helpers.py | 3 +- tests/metatensor/test_workflow_metatensor.py | 34 ++-- tests/tuning/test_tuning.py | 126 ++++++++------ 22 files changed, 339 insertions(+), 216 deletions(-) diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst index 5631f33c..d927b525 100644 --- a/docs/src/references/changelog.rst +++ b/docs/src/references/changelog.rst @@ -27,9 +27,17 @@ changelog `_ format. This project follows Added ##### +* Enhanced ``device`` and ``dtype`` consistency checks throughout the library * Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in ``Calculator`` classes and tuning functions. + +Fixed +##### + +* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator`` + + `Version 0.2.0 `_ - 2025-01-23 ------------------------------------------------------------------------------------------ diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index 575e9edc..f282a626 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -10,9 +10,23 @@ def _validate_parameters( neighbor_indices: torch.Tensor, neighbor_distances: torch.Tensor, smearing: Union[float, None], + dtype: torch.dtype, + device: Union[str, torch.device], ) -> None: - device = positions.device - dtype = positions.dtype + if positions.dtype != dtype: + raise TypeError( + f"type of `positions` ({positions.dtype}) must be same as class " + f"type ({dtype})" + ) + + if isinstance(device, torch.device): + device = device.type + + if positions.device.type != device: + raise ValueError( + f"device of `positions` ({positions.device}) must be same as class " + f"device ({device})" + ) # check shape, dtype and device of positions num_atoms = len(positions) @@ -29,12 +43,12 @@ def _validate_parameters( f"{list(cell.shape)}" ) - if cell.dtype != dtype: - raise ValueError( + if cell.dtype != positions.dtype: + raise TypeError( f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})" ) - if cell.device != device: + if cell.device != positions.device: raise ValueError( f"device of `cell` ({cell.device}) must be same as `positions` ({device})" ) @@ -63,12 +77,12 @@ def _validate_parameters( f"{len(positions)} atoms" ) - if charges.dtype != dtype: - raise ValueError( + if charges.dtype != positions.dtype: + raise TypeError( f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})" ) - if charges.device != device: + if charges.device != positions.device: raise ValueError( f"device of `charges` ({charges.device}) must be same as `positions` " f"({device})" @@ -82,7 +96,7 @@ def _validate_parameters( "structure" ) - if neighbor_indices.device != device: + if neighbor_indices.device != positions.device: raise ValueError( f"device of `neighbor_indices` ({neighbor_indices.device}) must be " f"same as `positions` ({device})" @@ -95,14 +109,14 @@ def _validate_parameters( f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}" ) - if neighbor_distances.device != device: + if neighbor_distances.device != positions.device: raise ValueError( f"device of `neighbor_distances` ({neighbor_distances.device}) must be " f"same as `positions` ({device})" ) - if neighbor_distances.dtype != dtype: - raise ValueError( + if neighbor_distances.dtype != positions.dtype: + raise TypeError( f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same " f"as `positions` ({dtype})" ) diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 70ede3ef..29eaadff 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from torch import profiler @@ -37,7 +37,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__() @@ -48,17 +48,20 @@ def __init__( self.device = torch.get_default_device() if device is None else device self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.potential = potential - assert self.dtype == self.potential.dtype, ( - f"Potential and Calculator must have the same dtype, got {self.dtype} and " - f"{self.potential.dtype}" - ) - assert self.device == self.potential.device, ( - f"Potential and Calculator must have the same device, got {self.device} and " - f"{self.potential.device}" - ) + if self.dtype != potential.dtype: + raise TypeError( + f"dtype of `potential` ({potential.dtype}) must be same as of " + f"`calculator` ({self.dtype})" + ) + + if self.device != potential.device: + raise ValueError( + f"device of `potential` ({potential.device}) must be same as of " + f"`calculator` ({self.device})" + ) + self.potential = potential self.full_neighbor_list = full_neighbor_list self.prefactor = prefactor @@ -179,6 +182,8 @@ def forward( neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, smearing=self.potential.smearing, + dtype=self.dtype, + device=self.device, ) # Compute short-range (SR) part using a real space sum diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py index 7d213f2f..83e2dc85 100644 --- a/src/torchpme/calculators/ewald.py +++ b/src/torchpme/calculators/ewald.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -66,7 +66,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( potential=potential, diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py index eb23c780..f85533db 100644 --- a/src/torchpme/calculators/p3m.py +++ b/src/torchpme/calculators/p3m.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -56,7 +56,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): self.mesh_spacing: float = mesh_spacing diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py index 95f74216..dd389812 100644 --- a/src/torchpme/calculators/pme.py +++ b/src/torchpme/calculators/pme.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from torch import profiler @@ -59,7 +59,7 @@ def __init__( full_neighbor_list: bool = False, prefactor: float = 1.0, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( potential=potential, diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py index d76a20c0..212f4744 100644 --- a/src/torchpme/potentials/combined.py +++ b/src/torchpme/potentials/combined.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -39,7 +39,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( smearing=smearing, diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py index 4cde5611..1e35897c 100644 --- a/src/torchpme/potentials/coulomb.py +++ b/src/torchpme/potentials/coulomb.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -35,7 +35,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__(smearing, exclusion_radius, dtype, device) diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index 35ff7ac7..374ab56e 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from torch.special import gammainc @@ -41,7 +41,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__(smearing, exclusion_radius, dtype, device) diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index 674a8632..1efa783d 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -39,7 +39,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__() self.dtype = torch.get_default_dtype() if dtype is None else dtype diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py index e8ffc3c5..b58d31eb 100644 --- a/src/torchpme/potentials/spline.py +++ b/src/torchpme/potentials/spline.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -58,7 +58,7 @@ def __init__( smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( smearing=smearing, diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py index 6452fa7d..459f0a02 100644 --- a/src/torchpme/tuning/ewald.py +++ b/src/torchpme/tuning/ewald.py @@ -1,5 +1,5 @@ import math -from typing import Any, Optional +from typing import Any, Optional, Union from warnings import warn import torch @@ -20,7 +20,7 @@ def tune_ewald( ns_hi: int = 14, accuracy: float = 1e-3, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.EwaldCalculator`. diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py index 5685ffaf..92a6bbe0 100644 --- a/src/torchpme/tuning/p3m.py +++ b/src/torchpme/tuning/p3m.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional +from typing import Any, Optional, Union from warnings import warn import torch @@ -80,7 +80,7 @@ def tune_p3m( mesh_hi: int = 7, accuracy: float = 1e-3, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`. diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py index 55a0556e..f9ddc4d8 100644 --- a/src/torchpme/tuning/pme.py +++ b/src/torchpme/tuning/pme.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional +from typing import Any, Optional, Union from warnings import warn import torch @@ -23,7 +23,7 @@ def tune_pme( mesh_hi: int = 7, accuracy: float = 1e-3, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.PMECalculator`. diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 8461fda5..354d6029 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -1,6 +1,6 @@ import math import time -from typing import Optional +from typing import Optional, Union import torch @@ -80,13 +80,16 @@ def __init__( calculator: type[Calculator], exponent: int = 1, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): if exponent != 1: raise NotImplementedError( f"Only exponent = 1 is supported but got {exponent}." ) + self.device = torch.get_default_device() if device is None else device + self.dtype = torch.get_default_dtype() if dtype is None else dtype + _validate_parameters( charges=charges, cell=cell, @@ -96,16 +99,15 @@ def __init__( [1.0], device=positions.device, dtype=positions.dtype ), smearing=1.0, # dummy value because; always have range-seperated potentials + dtype=self.dtype, + device=self.device, ) - self.charges = charges self.cell = cell self.positions = positions self.cutoff = cutoff self.calculator = calculator self.exponent = exponent - self.device = torch.get_default_device() if device is None else device - self.dtype = torch.get_default_dtype() if dtype is None else dtype self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions)) @@ -182,7 +184,7 @@ def __init__( neighbor_distances: torch.Tensor, exponent: int = 1, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__( charges=charges, @@ -280,10 +282,13 @@ def __init__( n_warmup: int = 4, run_backward: Optional[bool] = True, dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: Union[None, str, torch.device] = None, ): super().__init__() + self.dtype = torch.get_default_dtype() if dtype is None else dtype + self.device = torch.get_default_device() if device is None else device + _validate_parameters( charges=charges, cell=cell, @@ -291,13 +296,13 @@ def __init__( neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, smearing=1.0, # dummy value because; always have range-seperated potentials + device=self.device, + dtype=self.dtype, ) self.charges = charges self.cell = cell self.positions = positions - self.dtype = dtype - self.device = device self.n_repeat = n_repeat self.n_warmup = n_warmup self.run_backward = run_backward diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index a6fc976a..567b0f15 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -43,6 +43,35 @@ def test_compute_output_shapes(): assert result.shape == charges.shape +def test_wrong_device_positions(): + calculator = CalculatorTest() + match = r"device of `positions` \(meta\) must be same as class device \(cpu\)" + with pytest.raises(ValueError, match=match): + calculator.forward( + positions=POSITIONS_1.to(device="meta"), + charges=CHARGES_1, + cell=CELL_1, + neighbor_indices=NEIGHBOR_INDICES, + neighbor_distances=NEIGHBOR_DISTANCES, + ) + + +def test_wrong_dtype_positions(): + calculator = CalculatorTest() + match = ( + r"type of `positions` \(torch.float64\) must be same as class type " + r"\(torch.float32\)" + ) + with pytest.raises(TypeError, match=match): + calculator.forward( + positions=POSITIONS_1.to(dtype=torch.float64), + charges=CHARGES_1, + cell=CELL_1, + neighbor_indices=NEIGHBOR_INDICES, + neighbor_distances=NEIGHBOR_DISTANCES, + ) + + # Tests for invalid shape, dtype and device of positions def test_invalid_shape_positions(): calculator = CalculatorTest() @@ -82,7 +111,7 @@ def test_invalid_dtype_cell(): r"type of `cell` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, charges=CHARGES_1, @@ -163,7 +192,7 @@ def test_invalid_dtype_charges(): r"type of `charges` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, charges=CHARGES_1.to(dtype=torch.float64), @@ -252,7 +281,7 @@ def test_invalid_dtype_neighbor_distances(): r"type of `neighbor_distances` \(torch.float64\) must be same " r"as `positions` \(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, charges=CHARGES_1, diff --git a/tests/calculators/test_values_direct.py b/tests/calculators/test_values_direct.py index 47f819af..a5ace407 100644 --- a/tests/calculators/test_values_direct.py +++ b/tests/calculators/test_values_direct.py @@ -17,7 +17,11 @@ class CalculatorTest(Calculator): def __init__(self, **kwargs): super().__init__( - potential=CoulombPotential(smearing=None, exclusion_radius=None), **kwargs + potential=CoulombPotential( + smearing=None, exclusion_radius=None, dtype=DTYPE + ), + **kwargs, + dtype=DTYPE, ) diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py index cb3ff708..edcb886e 100644 --- a/tests/calculators/test_values_ewald.py +++ b/tests/calculators/test_values_ewald.py @@ -87,7 +87,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name): to triclinic, as well as cation-anion ratios of 1:1, 1:2 and 2:1. """ # Get input parameters and adjust to account for scaling - pos, charges, cell, madelung_ref, num_units = define_crystal(crystal_name) + pos, charges, cell, madelung_ref, num_units = define_crystal( + crystal_name, dtype=DTYPE + ) pos *= scaling_factor cell *= scaling_factor madelung_ref /= scaling_factor @@ -99,11 +101,9 @@ def test_madelung(crystal_name, scaling_factor, calc_name): smearing = sr_cutoff / 5.0 lr_wavelength = 0.5 * smearing calc = EwaldCalculator( - InversePowerLawPotential( - exponent=1, - smearing=smearing, - ), + InversePowerLawPotential(exponent=1, smearing=smearing, dtype=DTYPE), lr_wavelength=lr_wavelength, + dtype=DTYPE, ) rtol = 4e-6 elif calc_name == "pme": @@ -113,16 +113,19 @@ def test_madelung(crystal_name, scaling_factor, calc_name): InversePowerLawPotential( exponent=1, smearing=smearing, + dtype=DTYPE, ), mesh_spacing=smearing / 8, + dtype=DTYPE, ) rtol = 9e-4 elif calc_name == "p3m": sr_cutoff = 2 * scaling_factor smearing = sr_cutoff / 5.0 calc = P3MCalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), mesh_spacing=smearing / 8, + dtype=DTYPE, ) rtol = 9e-4 @@ -132,7 +135,6 @@ def test_madelung(crystal_name, scaling_factor, calc_name): ) # Compute potential and compare against target value using default hypers - calc.to(dtype=DTYPE) potentials = calc.forward( positions=pos, charges=charges, @@ -186,26 +188,22 @@ def test_wigner(crystal_name, scaling_factor): # The first value of 0.1 corresponds to what would be # chosen by default for the "wigner_sc" or "wigner_bcc_cubiccell" structure. - smearings = torch.tensor([0.1, 0.06, 0.019], dtype=torch.float64) - for smearing in smearings: + for smearing in [0.1, 0.06, 0.019]: # Readjust smearing parameter to match nearest neighbor distance if crystal_name in ["wigner_fcc", "wigner_fcc_cubiccell"]: - smeareff = float(smearing) / np.sqrt(2) + smeareff = smearing / np.sqrt(2) elif crystal_name in ["wigner_bcc_cubiccell", "wigner_bcc"]: - smeareff = float(smearing) * np.sqrt(3) / 2 + smeareff = smearing * np.sqrt(3) / 2 elif crystal_name == "wigner_sc": - smeareff = float(smearing) + smeareff = smearing smeareff *= scaling_factor # Compute potential and compare against reference calc = EwaldCalculator( - InversePowerLawPotential( - exponent=1, - smearing=smeareff, - ), + InversePowerLawPotential(exponent=1, smearing=smeareff, dtype=DTYPE), lr_wavelength=smeareff / 2, + dtype=DTYPE, ) - calc.to(dtype=DTYPE) potentials = calc.forward( positions=positions, charges=charges, @@ -253,25 +251,28 @@ def test_random_structure( if calc_name == "ewald": calc = EwaldCalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), lr_wavelength=0.5 * smearing, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, + dtype=DTYPE, ) elif calc_name == "pme": calc = PMECalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), mesh_spacing=smearing / 8.0, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, + dtype=DTYPE, ) elif calc_name == "p3m": calc = P3MCalculator( - CoulombPotential(smearing=smearing), + CoulombPotential(smearing=smearing, dtype=DTYPE), mesh_spacing=smearing / 8.0, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, + dtype=DTYPE, ) neighbor_indices, neighbor_shifts = neighbor_list( @@ -294,7 +295,6 @@ def test_random_structure( neighbor_shifts=neighbor_shifts, ) - calc.to(dtype=DTYPE) potentials = calc.forward( positions=positions, charges=charges, diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 83377006..209355d4 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -4,7 +4,6 @@ """ import io -import math import pytest import torch @@ -17,15 +16,15 @@ PMECalculator, ) -AVAILABLE_DEVICES = ["cpu"] + torch.cuda.is_available() * ["cuda"] -MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) -CHARGES_CSCL = torch.tensor([1.0, -1.0]) +DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] +DTYPES = [torch.float32, torch.float64] SMEARING = 0.1 LR_WAVELENGTH = SMEARING / 4 MESH_SPACING = SMEARING / 4 -@pytest.mark.parametrize("device", AVAILABLE_DEVICES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize( ("CalculatorClass", "params"), [ @@ -59,125 +58,128 @@ ], ) class TestWorkflow: - def cscl_system(self, device=None): + def cscl_system(self, device=None, dtype=None): """CsCl crystal. Same as in the madelung test""" - if device is None: - device = torch.device("cpu") - - positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - charges = torch.tensor([1.0, -1.0]).reshape((-1, 1)) - cell = torch.eye(3) - neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64) - neighbor_distances = torch.tensor([0.8660]) - - return ( - charges.to(device=device), - cell.to(device=device), - positions.to(device=device), - neighbor_indices.to(device=device), - neighbor_distances.to(device=device), + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype + + positions = torch.tensor( + [[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device ) + charges = torch.tensor([1.0, -1.0], dtype=dtype, device=device).reshape((-1, 1)) + cell = torch.eye(3, dtype=dtype, device=device) + neighbor_indices = torch.tensor([[0, 1]], dtype=torch.int64, device=device) + neighbor_distances = torch.tensor([0.8660], dtype=dtype, device=device) + + return charges, cell, positions, neighbor_indices, neighbor_distances - def test_smearing_non_positive(self, CalculatorClass, params, device): + def test_smearing_non_positive(self, CalculatorClass, params, device, dtype): params = params.copy() match = r"`smearing` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator, PMECalculator]: params["smearing"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) params["smearing"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) - def test_interpolation_order_error(self, CalculatorClass, params, device): + def test_interpolation_order_error(self, CalculatorClass, params, device, dtype): params = params.copy() if type(CalculatorClass) in [PMECalculator]: match = "Only `interpolation_nodes` from 1 to 5" params["interpolation_nodes"] = 10 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) - def test_lr_wavelength_non_positive(self, CalculatorClass, params, device): + def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype): params = params.copy() match = r"`lr_wavelength` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator]: params["lr_wavelength"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) params["lr_wavelength"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device) + CalculatorClass(**params, device=device, dtype=dtype) - def test_dtype_device(self, CalculatorClass, params, device): + def test_dtype_device(self, CalculatorClass, params, device, dtype): """Test that the output dtype and device are the same as the input.""" - dtype = torch.float64 params = params.copy() - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) - charges = torch.ones((1, 2), dtype=dtype, device=device) - cell = torch.eye(3, dtype=dtype, device=device) - neighbor_indices = torch.tensor([[0, 0]], device=device) - neighbor_distances = torch.tensor([0.1], dtype=dtype, device=device) params["potential"].device = device - calculator = CalculatorClass(**params, device=device) - potential = calculator.forward( - charges=charges, - cell=cell, - positions=positions, - neighbor_indices=neighbor_indices, - neighbor_distances=neighbor_distances, - ) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) + potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype)) assert potential.dtype == dtype - assert potential.device.type == device - def check_operation(self, calculator, device): + if isinstance(device, torch.device): + assert potential.device == device + else: + assert potential.device.type == device + + def check_operation(self, calculator, device, dtype): """Make sure computation runs and returns a torch.Tensor.""" - descriptor = calculator.forward(*self.cscl_system(device)) + descriptor = calculator.forward(*self.cscl_system(device=device, dtype=dtype)) assert type(descriptor) is torch.Tensor - def test_operation_as_python(self, CalculatorClass, params, device): + def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" params = params.copy() params["potential"].device = device - calculator = CalculatorClass(**params, device=device) - self.check_operation(calculator=calculator, device=device) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) + self.check_operation(calculator=calculator, device=device, dtype=dtype) - def test_operation_as_torch_script(self, CalculatorClass, params, device): + def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" params = params.copy() params["potential"].device = device - calculator = CalculatorClass(**params, device=device) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) - self.check_operation(calculator=scripted, device=device) + self.check_operation(calculator=scripted, device=device, dtype=dtype) - def test_save_load(self, CalculatorClass, params, device): + def test_save_load(self, CalculatorClass, params, device, dtype): params = params.copy() params["potential"].device = device - calculator = CalculatorClass(**params, device=device) + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) buffer.seek(0) torch.jit.load(buffer) - def test_prefactor(self, CalculatorClass, params, device): + def test_prefactor(self, CalculatorClass, params, device, dtype): """Test if the prefactor is applied correctly.""" params = params.copy() params["potential"].device = device + params["potential"].dtype = dtype + prefactor = 2.0 - calculator1 = CalculatorClass(**params, device=device) - calculator2 = CalculatorClass(**params, prefactor=prefactor, device=device) - potentials1 = calculator1.forward(*self.cscl_system()) - potentials2 = calculator2.forward(*self.cscl_system()) + calculator1 = CalculatorClass(**params, device=device, dtype=dtype) + calculator2 = CalculatorClass( + **params, prefactor=prefactor, device=device, dtype=dtype + ) + + potentials1 = calculator1.forward(*self.cscl_system(device=device, dtype=dtype)) + potentials2 = calculator2.forward(*self.cscl_system(device=device, dtype=dtype)) + assert torch.allclose(potentials1 * prefactor, potentials2) - def test_not_nan(self, CalculatorClass, params, device): + def test_not_nan(self, CalculatorClass, params, device, dtype): """Make sure derivatives are not NaN.""" params = params.copy() params["potential"].device = device + params["potential"].dtype = dtype - calculator = CalculatorClass(**params, device=device) - system = self.cscl_system(device) + calculator = CalculatorClass(**params, device=device, dtype=dtype) + system = self.cscl_system(device=device, dtype=dtype) system[0].requires_grad = True system[1].requires_grad = True system[2].requires_grad = True @@ -203,26 +205,45 @@ def test_not_nan(self, CalculatorClass, params, device): torch.autograd.grad(energy, system[2], retain_graph=True)[0] ).any() - def test_dtype_and_device_incompatability(self, CalculatorClass, params, device): - """Test that the calculator raises an error if the dtype and device are incompatible.""" + def test_dtype_and_device_incompatability( + self, CalculatorClass, params, device, dtype + ): + """Test that the calculator raises an error if the dtype and device are incompatible with potential.""" params = params.copy() + + other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 + params["potential"].device = device - params["potential"].dtype = torch.float64 - with pytest.raises(AssertionError, match=".*dtype.*"): - CalculatorClass(**params, dtype=torch.float32, device=device) - with pytest.raises(AssertionError, match=".*device.*"): - CalculatorClass( - **params, dtype=params["potential"].dtype, device=torch.device("meta") - ) + params["potential"].dtype = dtype + + match = ( + rf"dtype of `potential` \({params['potential'].dtype}\) must be same as " + rf"of `calculator` \({other_dtype}\)" + ) + with pytest.raises(TypeError, match=match): + CalculatorClass(**params, dtype=other_dtype, device=device) + + match = ( + rf"device of `potential` \({params['potential'].device}\) must be same as " + rf"of `calculator` \(meta\)" + ) + with pytest.raises(ValueError, match=match): + CalculatorClass(**params, dtype=dtype, device=torch.device("meta")) def test_potential_and_calculator_incompatability( - self, CalculatorClass, params, device + self, + CalculatorClass, + params, + device, + dtype, ): """Test that the calculator raises an error if the potential and calculator are incompatible.""" params = params.copy() params["potential"].device = device + params["potential"].dtype = dtype + params["potential"] = torch.jit.script(params["potential"]) with pytest.raises( TypeError, match="Potential must be an instance of Potential, got.*" ): - CalculatorClass(**params) + CalculatorClass(**params, device=device, dtype=dtype) diff --git a/tests/helpers.py b/tests/helpers.py index f970f596..d38e9f9b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -227,14 +227,13 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None): else: raise ValueError(f"crystal_name = {crystal_name} is not supported!") - madelung_ref = torch.tensor(madelung_ref) charges = charges.reshape((-1, 1)) return ( positions.to(device=device, dtype=dtype), charges.to(device=device, dtype=dtype), cell.to(device=device, dtype=dtype), - madelung_ref, + torch.tensor(madelung_ref, device=device, dtype=dtype), num_formula_units, ) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 025bd91c..6beed927 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -11,14 +11,15 @@ mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") -AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [ - torch.device("cuda") -] +DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] +DTYPES = [torch.float32, torch.float64] SMEARING = 0.1 LR_WAVELENGTH = SMEARING / 4 MESH_SPACING = SMEARING / 4 +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize( ("CalculatorClass", "params"), [ @@ -52,7 +53,10 @@ ], ) class TestWorkflow: - def system(self, device=None): + def system(self, device=None, dtype=None): + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype + system = mts_atomistic.System( types=torch.tensor([1, 2, 2]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.2], [0.0, 0.0, 0.5]]), @@ -92,30 +96,32 @@ def system(self, device=None): return system.to(device=device), neighbors.to(device=device) - def check_operation(self, calculator, device): + def check_operation(self, calculator, device, dtype): """Make sure computation runs and returns a metatensor.TensorMap.""" - system, neighbors = self.system(device) + system, neighbors = self.system(device=device, dtype=dtype) descriptor = calculator.forward(system, neighbors) assert isinstance(descriptor, torch.ScriptObject) if version.parse(torch.__version__) >= version.parse("2.1"): assert descriptor._type().name() == "TensorMap" - @pytest.mark.parametrize("device", AVAILABLE_DEVICES) - def test_operation_as_python(self, CalculatorClass, params, device): + def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" calculator = CalculatorClass(**params) - self.check_operation(calculator=calculator, device=device) + self.check_operation(calculator=calculator, device=device, dtype=dtype) - @pytest.mark.parametrize("device", AVAILABLE_DEVICES) - def test_operation_as_torch_script(self, CalculatorClass, params, device): + def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" calculator = CalculatorClass(**params) scripted = torch.jit.script(calculator) - self.check_operation(calculator=scripted, device=device) + self.check_operation(calculator=scripted, device=device, dtype=dtype) - def test_save_load(self, CalculatorClass, params): - calculator = CalculatorClass(**params) + def test_save_load(self, CalculatorClass, params, device, dtype): + params = params.copy() + params["potential"].device = device + params["potential"].dtype = dtype + + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py index 17a43580..6477115b 100644 --- a/tests/tuning/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -16,31 +16,46 @@ sys.path.append(str(Path(__file__).parents[1])) from helpers import define_crystal, neighbor_list -DTYPE = torch.float32 -DEVICE = "cpu" DEFAULT_CUTOFF = 4.4 -CHARGES_1 = torch.ones((4, 1), dtype=DTYPE, device=DEVICE) -POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE, device=DEVICE).reshape((4, 3)) -CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE) +DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] +DTYPES = [torch.float32, torch.float64] -def test_TunerBase_double(): +def system(device=None, dtype=None): + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype + + charges = torch.ones((4, 1), dtype=dtype, device=device) + cell = torch.eye(3, dtype=dtype, device=device) + positions = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3)) + + return charges, cell, positions + + +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) +def test_TunerBase_init(device, dtype): """ - Check that `TunerBase` initilizes with double precisions tensors. + Check that `TunerBase` initilizes correctly. We are using dummy `neighbor_indices` and `neighbor_distances` to verify types. Have to be sure that these dummy variables are initilized correctly. """ + charges, cell, positions = system(device, dtype) TunerBase( - charges=CHARGES_1.to(dtype=torch.float64), - cell=CELL_1.to(dtype=torch.float64), - positions=POSITIONS_1.to(dtype=torch.float64), + charges=charges, + cell=cell, + positions=positions, cutoff=DEFAULT_CUTOFF, calculator=1.0, exponent=1, + dtype=dtype, + device=device, ) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize( ("calculator", "tune", "param_length"), [ @@ -50,13 +65,15 @@ def test_TunerBase_double(): ], ) @pytest.mark.parametrize("accuracy", [1e-1, 1e-3, 1e-5]) -def test_parameter_choose(calculator, tune, param_length, accuracy): +def test_parameter_choose(device, dtype, calculator, tune, param_length, accuracy): """ Check that the Madelung constants obtained from the Ewald sum calculator matches the reference values and that all branches of the from_accuracy method are covered. """ # Get input parameters and adjust to account for scaling - pos, charges, cell, madelung_ref, num_units = define_crystal() + pos, charges, cell, madelung_ref, num_units = define_crystal( + dtype=dtype, device=device + ) # Compute neighbor list neighbor_indices, neighbor_distances = neighbor_list( @@ -71,13 +88,17 @@ def test_parameter_choose(calculator, tune, param_length, accuracy): neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, accuracy=accuracy, + dtype=dtype, + device=device, ) assert len(params) == param_length # Compute potential and compare against target value using default hypers calc = calculator( - potential=(CoulombPotential(smearing=smearing)), + potential=(CoulombPotential(smearing=smearing, dtype=dtype, device=device)), + dtype=dtype, + device=device, **params, ) potentials = calc.forward( @@ -103,12 +124,12 @@ def test_accuracy_error(tune): ) with pytest.raises(ValueError, match=match): tune( - charges, - cell, - pos, - DEFAULT_CUTOFF, - neighbor_indices, - neighbor_distances, + charges=charges, + cell=cell, + positions=pos, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, accuracy="foo", ) @@ -116,79 +137,90 @@ def test_accuracy_error(tune): @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_exponent_not_1_error(tune): pos, charges, cell, _, _ = define_crystal() - - match = "Only exponent = 1 is supported but got 2." neighbor_indices, neighbor_distances = neighbor_list( positions=pos, box=cell, cutoff=DEFAULT_CUTOFF ) + + match = "Only exponent = 1 is supported but got 2." with pytest.raises(NotImplementedError, match=match): tune( - charges, - cell, - pos, - DEFAULT_CUTOFF, - neighbor_indices, - neighbor_distances, + charges=charges, + cell=cell, + positions=pos, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, exponent=2, ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_shape_positions(tune): + charges, cell, _ = system() match = ( r"`positions` must be a tensor with shape \[n_atoms, 3\], got tensor with " r"shape \[4, 5\]" ) with pytest.raises(ValueError, match=match): tune( - CHARGES_1, - CELL_1, - torch.ones((4, 5), dtype=DTYPE, device=DEVICE), - DEFAULT_CUTOFF, - None, # dummy neighbor indices - None, # dummy neighbor distances + charges=charges, + cell=cell, + positions=torch.ones((4, 5)), + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, ) # Tests for invalid shape, dtype and device of cell @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_shape_cell(tune): + charges, _, positions = system() match = ( r"`cell` must be a tensor with shape \[3, 3\], got tensor with shape \[2, 2\]" ) with pytest.raises(ValueError, match=match): tune( - CHARGES_1, - torch.ones([2, 2], dtype=DTYPE, device=DEVICE), - POSITIONS_1, - DEFAULT_CUTOFF, - None, # dummy neighbor indices - None, # dummy neighbor distances + charges=charges, + cell=torch.ones([2, 2]), + positions=positions, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_cell(tune): + charges, _, positions = system() match = ( "provided `cell` has a determinant of 0 and therefore is not valid for " "periodic calculation" ) with pytest.raises(ValueError, match=match): - tune(CHARGES_1, torch.zeros(3, 3), POSITIONS_1, DEFAULT_CUTOFF, None, None) + tune( + charges=charges, + cell=torch.zeros(3, 3), + positions=positions, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, + ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_dtype_cell(tune): + charges, _, positions = system() match = ( r"type of `cell` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) - with pytest.raises(ValueError, match=match): + with pytest.raises(TypeError, match=match): tune( - CHARGES_1, - torch.eye(3, dtype=torch.float64, device=DEVICE), - POSITIONS_1, - DEFAULT_CUTOFF, - None, - None, + charges=charges, + cell=torch.eye(3, dtype=torch.float64), + positions=positions, + cutoff=DEFAULT_CUTOFF, + neighbor_indices=None, + neighbor_distances=None, ) From a9504a12a7109cb23c177bdb4d70d57a918ff360 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 28 Jan 2025 17:17:47 +0100 Subject: [PATCH 3/4] Fix examples and tests Fix again --- examples/01-charges-example.py | 17 +++++++-------- examples/02-neighbor-lists-usage.py | 8 ++++++-- examples/08-combined-potential.py | 15 +++++++------- examples/10-tuning.py | 32 ++++++++++++++++++++++------- examples/basic-usage.py | 6 ++++-- tests/helpers.py | 5 ++++- 6 files changed, 56 insertions(+), 27 deletions(-) diff --git a/examples/01-charges-example.py b/examples/01-charges-example.py index 92d82949..53d6b964 100644 --- a/examples/01-charges-example.py +++ b/examples/01-charges-example.py @@ -42,12 +42,12 @@ # %% # # Create the properties CsCl unit cell - +dtype = torch.float64 symbols = ("Cs", "Cl") types = torch.tensor([55, 17]) -charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) -positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64) -cell = torch.eye(3, dtype=torch.float64) +charges = torch.tensor([[1.0], [-1.0]], dtype=dtype) +positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=dtype) +cell = torch.eye(3, dtype=dtype) pbc = torch.tensor([True, True, True]) @@ -72,6 +72,7 @@ cutoff=cutoff, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, + dtype=dtype, ) # %% @@ -101,7 +102,7 @@ # will be used to *compute* the potential energy of the system. calculator = torchpme.PMECalculator( - torchpme.CoulombPotential(smearing=smearing), **pme_params + torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params ) # %% @@ -112,7 +113,7 @@ # As a first application of multiple charge channels, we start simply by using the # classic definition of one charge channel per atom. -charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) +charges = torch.tensor([[1.0], [-1.0]], dtype=dtype) # %% # @@ -160,7 +161,7 @@ # species-specific potentials and facilitating the learning process for machine learning # algorithms. -charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64) +charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=dtype) # %% # @@ -205,7 +206,7 @@ # creating a new calculator with the metatensor interface. calculator_metatensor = torchpme.metatensor.PMECalculator( - torchpme.CoulombPotential(smearing=smearing), **pme_params + torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params ) # %% diff --git a/examples/02-neighbor-lists-usage.py b/examples/02-neighbor-lists-usage.py index 322e6a69..e72689f0 100644 --- a/examples/02-neighbor-lists-usage.py +++ b/examples/02-neighbor-lists-usage.py @@ -55,6 +55,7 @@ # # As a test system, we use a 2x2x2 supercell of an CsCl crystal in a cubic cell. +dtype = torch.float64 atoms_unitcell = ase.Atoms( symbols=["Cs", "Cl"], positions=np.array([(0, 0, 0), (0.5, 0.5, 0.5)]), @@ -97,7 +98,7 @@ nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False) neighbor_indices, neighbor_distances = nl.compute( points=positions.to(dtype=torch.float64, device="cpu"), - box=cell.to(dtype=torch.float64, device="cpu"), + box=cell.to(dtype=dtype, device="cpu"), periodic=True, quantities="Pd", ) @@ -109,6 +110,7 @@ cutoff=cutoff, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, + dtype=dtype, ) # %% @@ -193,7 +195,9 @@ def distances( # compute the potential. pme = torchpme.PMECalculator( - potential=torchpme.CoulombPotential(smearing=smearing), **pme_params + potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype), + dtype=dtype, + **pme_params, ) potential = pme( charges=charges, diff --git a/examples/08-combined-potential.py b/examples/08-combined-potential.py index df455e98..8d0667bb 100644 --- a/examples/08-combined-potential.py +++ b/examples/08-combined-potential.py @@ -28,6 +28,8 @@ from torchpme import CombinedPotential, EwaldCalculator, InversePowerLawPotential from torchpme.prefactors import eV_A +dtype = torch.float64 + # %% # Combined potentials # ------------------- @@ -65,10 +67,10 @@ # evaluation, and so one has to set it also for the combined potential, even if it is # not used explicitly in the evaluation of the combination. -pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing) -pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing) +pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) +pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype) -potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing) +potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype) # Note also that :class:`CombinedPotential` can be used with any combination of # potentials, as long they are all either direct or range separated. For instance, one @@ -80,7 +82,7 @@ # We now plot of the individual and combined ``potential`` functions together with an # explicit sum of the two potentials. -dist = torch.logspace(-3, 2, 1000) +dist = torch.logspace(-3, 2, 1000, dtype=dtype) fig, ax = plt.subplots() @@ -115,7 +117,7 @@ # combines all terms in a range-separated potential, including the k-space # kernel. -k = torch.logspace(-2, 2, 1000) +k = torch.logspace(-2, 2, 1000, dtype=dtype) fig, ax = plt.subplots() @@ -154,9 +156,8 @@ # much bigger system. calculator = EwaldCalculator( - potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A + potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype ) -calculator.to(dtype=torch.float64) # %% diff --git a/examples/10-tuning.py b/examples/10-tuning.py index c2d61881..183b0f45 100644 --- a/examples/10-tuning.py +++ b/examples/10-tuning.py @@ -120,7 +120,9 @@ pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4} pme = torchpme.PMECalculator( - potential=torchpme.CoulombPotential(smearing=smearing), + potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype), + device=device, + dtype=dtype, **pme_params, # type: ignore[arg-type] ) @@ -168,6 +170,8 @@ neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, run_backward=True, + device=device, + dtype=dtype, ) estimated_timing = timings(pme) @@ -210,15 +214,19 @@ def filter_neighbors(cutoff, neighbor_indices, neighbor_distances): return neighbor_indices[filter_idx], neighbor_distances[filter_idx] -def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): +def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, dtype): filter_indices, filter_distances = filter_neighbors( cutoff, neighbor_indices, neighbor_distances ) pme = torchpme.PMECalculator( - potential=torchpme.CoulombPotential(smearing=smearing), + potential=torchpme.CoulombPotential( + smearing=smearing, device=device, dtype=dtype + ), mesh_spacing=mesh_spacing, interpolation_nodes=interpolation_nodes, + device=device, + dtype=dtype, ) potential = pme( charges=charges, @@ -239,6 +247,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): run_backward=True, n_warmup=1, n_repeat=4, + device=device, + dtype=dtype, ) estimated_timing = timings(pme) return madelung, estimated_timing @@ -251,7 +261,9 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): bounds = np.zeros((len(smearing_grid), len(spacing_grid))) for ism, smearing in enumerate(smearing_grid): for isp, spacing in enumerate(spacing_grid): - results[ism, isp], timings[ism, isp] = timed_madelung(8.0, smearing, spacing, 4) + results[ism, isp], timings[ism, isp] = timed_madelung( + 8.0, smearing, spacing, 4, device, dtype + ) bounds[ism, isp] = error_bounds(8.0, smearing, spacing, 4) # %% @@ -374,7 +386,7 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): for inint, nint in enumerate(nint_grid): for isp, spacing in enumerate(spacing_grid): results[inint, isp], timings[inint, isp] = timed_madelung( - 5.0, 1.0, spacing, nint + 5.0, 1.0, spacing, nint, device=device, dtype=dtype ) bounds[inint, isp] = error_bounds(5.0, 1.0, spacing, nint) @@ -445,15 +457,19 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): cutoff=5.0, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, + device=device, + dtype=dtype, ) -print(f""" +print( + f""" Estimated PME parameters (cutoff={5.0} Å): Smearing: {smearing} Å Mesh spacing: {parameters["mesh_spacing"]} Å Interpolation order: {parameters["interpolation_nodes"]} Estimated time per step: {timing} s -""") +""" +) # %% # What is the best cutoff? @@ -476,6 +492,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): cutoff=cutoff, neighbor_indices=filter_indices, neighbor_distances=filter_distances, + device=device, + dtype=dtype, ) timings_grid.append(timing) diff --git a/examples/basic-usage.py b/examples/basic-usage.py index d726a18f..47bc1506 100644 --- a/examples/basic-usage.py +++ b/examples/basic-usage.py @@ -146,7 +146,7 @@ # contains all the necessary functions (such as those defining the short-range and # long-range splits) for this potential and makes them useable in the rest of the code. -potential = CoulombPotential(smearing=smearing) +potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype) # %% # @@ -193,7 +193,9 @@ # Since our structure is relatively small, we use the :class:`EwaldCalculator`. # We start by the initialization of the class. -calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength) +calculator = EwaldCalculator( + potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype +) # %% # diff --git a/tests/helpers.py b/tests/helpers.py index d38e9f9b..bca5feb7 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -256,7 +256,10 @@ def neighbor_list( nl = NeighborList(cutoff=cutoff, full_list=full_neighbor_list) neighbor_indices, d, S = nl.compute( - points=positions, box=box, periodic=periodic, quantities="PdS" + points=positions.to(dtype=torch.float64, device="cpu"), + box=box.to(dtype=torch.float64, device="cpu"), + periodic=periodic, + quantities="PdS", ) neighbor_indices = torch.from_numpy(neighbor_indices.astype(int)).to( From 4eae184561ede4e0e5b25f69681427e2e9485d05 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 29 Jan 2025 08:17:56 +0100 Subject: [PATCH 4/4] clearer error messages --- examples/01-charges-example.py | 3 ++- src/torchpme/_utils.py | 21 ++++++++++++--------- tests/calculators/test_calculator.py | 21 +++++++++------------ tests/tuning/test_tuning.py | 3 +-- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/examples/01-charges-example.py b/examples/01-charges-example.py index 53d6b964..6465a507 100644 --- a/examples/01-charges-example.py +++ b/examples/01-charges-example.py @@ -39,10 +39,11 @@ import torchpme from torchpme.tuning import tune_pme +dtype = torch.float64 + # %% # # Create the properties CsCl unit cell -dtype = torch.float64 symbols = ("Cs", "Cl") types = torch.tensor([55, 17]) charges = torch.tensor([[1.0], [-1.0]], dtype=dtype) diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index f282a626..8d42e45c 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -15,7 +15,7 @@ def _validate_parameters( ) -> None: if positions.dtype != dtype: raise TypeError( - f"type of `positions` ({positions.dtype}) must be same as class " + f"type of `positions` ({positions.dtype}) must be same as the class " f"type ({dtype})" ) @@ -24,10 +24,13 @@ def _validate_parameters( if positions.device.type != device: raise ValueError( - f"device of `positions` ({positions.device}) must be same as class " + f"device of `positions` ({positions.device}) must be same as the class " f"device ({device})" ) + # We use `positions.device` because it includes the device type AND index, which the + # `device` parameter may lack + # check shape, dtype and device of positions num_atoms = len(positions) if list(positions.shape) != [len(positions), 3]: @@ -45,12 +48,12 @@ def _validate_parameters( if cell.dtype != positions.dtype: raise TypeError( - f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})" + f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})" ) if cell.device != positions.device: raise ValueError( - f"device of `cell` ({cell.device}) must be same as `positions` ({device})" + f"device of `cell` ({cell.device}) must be same as the class ({device})" ) if smearing is not None and torch.equal( @@ -79,12 +82,12 @@ def _validate_parameters( if charges.dtype != positions.dtype: raise TypeError( - f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})" + f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})" ) if charges.device != positions.device: raise ValueError( - f"device of `charges` ({charges.device}) must be same as `positions` " + f"device of `charges` ({charges.device}) must be same as the class " f"({device})" ) @@ -99,7 +102,7 @@ def _validate_parameters( if neighbor_indices.device != positions.device: raise ValueError( f"device of `neighbor_indices` ({neighbor_indices.device}) must be " - f"same as `positions` ({device})" + f"same as the class ({device})" ) if neighbor_distances.shape != neighbor_indices[:, 0].shape: @@ -112,11 +115,11 @@ def _validate_parameters( if neighbor_distances.device != positions.device: raise ValueError( f"device of `neighbor_distances` ({neighbor_distances.device}) must be " - f"same as `positions` ({device})" + f"same as the class ({device})" ) if neighbor_distances.dtype != positions.dtype: raise TypeError( f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same " - f"as `positions` ({dtype})" + f"as the class ({dtype})" ) diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index 567b0f15..dff72455 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -45,7 +45,7 @@ def test_compute_output_shapes(): def test_wrong_device_positions(): calculator = CalculatorTest() - match = r"device of `positions` \(meta\) must be same as class device \(cpu\)" + match = r"device of `positions` \(meta\) must be same as the class device \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1.to(device="meta"), @@ -59,7 +59,7 @@ def test_wrong_device_positions(): def test_wrong_dtype_positions(): calculator = CalculatorTest() match = ( - r"type of `positions` \(torch.float64\) must be same as class type " + r"type of `positions` \(torch.float64\) must be same as the class type " r"\(torch.float32\)" ) with pytest.raises(TypeError, match=match): @@ -108,8 +108,7 @@ def test_invalid_shape_cell(): def test_invalid_dtype_cell(): calculator = CalculatorTest() match = ( - r"type of `cell` \(torch.float64\) must be same as `positions` " - r"\(torch.float32\)" + r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)" ) with pytest.raises(TypeError, match=match): calculator.forward( @@ -123,7 +122,7 @@ def test_invalid_dtype_cell(): def test_invalid_device_cell(): calculator = CalculatorTest() - match = r"device of `cell` \(meta\) must be same as `positions` \(cpu\)" + match = r"device of `cell` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -189,7 +188,7 @@ def test_invalid_shape_charges(): def test_invalid_dtype_charges(): calculator = CalculatorTest() match = ( - r"type of `charges` \(torch.float64\) must be same as `positions` " + r"type of `charges` \(torch.float64\) must be same as the class " r"\(torch.float32\)" ) with pytest.raises(TypeError, match=match): @@ -204,7 +203,7 @@ def test_invalid_dtype_charges(): def test_invalid_device_charges(): calculator = CalculatorTest() - match = r"device of `charges` \(meta\) must be same as `positions` \(cpu\)" + match = r"device of `charges` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -249,7 +248,7 @@ def test_invalid_shape_neighbor_indices_neighbor_distances(): def test_invalid_device_neighbor_indices(): calculator = CalculatorTest() - match = r"device of `neighbor_indices` \(meta\) must be same as `positions` \(cpu\)" + match = r"device of `neighbor_indices` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -262,9 +261,7 @@ def test_invalid_device_neighbor_indices(): def test_invalid_device_neighbor_distances(): calculator = CalculatorTest() - match = ( - r"device of `neighbor_distances` \(meta\) must be same as `positions` \(cpu\)" - ) + match = r"device of `neighbor_distances` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -279,7 +276,7 @@ def test_invalid_dtype_neighbor_distances(): calculator = CalculatorTest() match = ( r"type of `neighbor_distances` \(torch.float64\) must be same " - r"as `positions` \(torch.float32\)" + r"as the class \(torch.float32\)" ) with pytest.raises(TypeError, match=match): calculator.forward( diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py index 6477115b..3d1ce03f 100644 --- a/tests/tuning/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -212,8 +212,7 @@ def test_invalid_cell(tune): def test_invalid_dtype_cell(tune): charges, _, positions = system() match = ( - r"type of `cell` \(torch.float64\) must be same as `positions` " - r"\(torch.float32\)" + r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)" ) with pytest.raises(TypeError, match=match): tune(