From ec6d55e08702f97f319dfa12e0c903d4db4dc52c Mon Sep 17 00:00:00 2001 From: E-Rum Date: Fri, 31 Jan 2025 14:00:43 +0000 Subject: [PATCH] Merge "main" --- src/torchpme/_utils.py | 35 ++++++---- src/torchpme/calculators/calculator.py | 6 +- src/torchpme/potentials/potential.py | 8 ++- src/torchpme/tuning/tuner.py | 10 +-- tests/calculators/test_workflow.py | 72 +++++++++++++------- tests/helpers.py | 6 +- tests/metatensor/test_workflow_metatensor.py | 36 +++++++--- tests/tuning/test_tuning.py | 5 +- 8 files changed, 117 insertions(+), 61 deletions(-) diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index 8d42e45c..b7b4063c 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -1,8 +1,23 @@ -from typing import Union +from typing import Optional, Union import torch +def _get_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: + return torch.get_default_dtype() if dtype is None else dtype + + +def _get_device(device: Union[None, str, torch.device]) -> torch.device: + new_device = torch.get_default_device() if device is None else torch.device(device) + + # Add default index of 0 to a cuda device to avoid errors when comparing with + # devices from tensors + if new_device.type == "cuda" and new_device.index is None: + new_device = torch.device("cuda:0") + + return new_device + + def _validate_parameters( charges: torch.Tensor, cell: torch.Tensor, @@ -11,7 +26,7 @@ def _validate_parameters( neighbor_distances: torch.Tensor, smearing: Union[float, None], dtype: torch.dtype, - device: Union[str, torch.device], + device: torch.device, ) -> None: if positions.dtype != dtype: raise TypeError( @@ -19,18 +34,12 @@ def _validate_parameters( f"type ({dtype})" ) - if isinstance(device, torch.device): - device = device.type - - if positions.device.type != device: + if positions.device != device: raise ValueError( f"device of `positions` ({positions.device}) must be same as the class " f"device ({device})" ) - # We use `positions.device` because it includes the device type AND index, which the - # `device` parameter may lack - # check shape, dtype and device of positions num_atoms = len(positions) if list(positions.shape) != [len(positions), 3]: @@ -51,7 +60,7 @@ def _validate_parameters( f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})" ) - if cell.device != positions.device: + if cell.device != device: raise ValueError( f"device of `cell` ({cell.device}) must be same as the class ({device})" ) @@ -85,7 +94,7 @@ def _validate_parameters( f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})" ) - if charges.device != positions.device: + if charges.device != device: raise ValueError( f"device of `charges` ({charges.device}) must be same as the class " f"({device})" @@ -99,7 +108,7 @@ def _validate_parameters( "structure" ) - if neighbor_indices.device != positions.device: + if neighbor_indices.device != device: raise ValueError( f"device of `neighbor_indices` ({neighbor_indices.device}) must be " f"same as the class ({device})" @@ -112,7 +121,7 @@ def _validate_parameters( f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}" ) - if neighbor_distances.device != positions.device: + if neighbor_distances.device != device: raise ValueError( f"device of `neighbor_distances` ({neighbor_distances.device}) must be " f"same as the class ({device})" diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index a7681fea..3dfbeb38 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -3,7 +3,7 @@ import torch from torch import profiler -from .._utils import _validate_parameters +from .._utils import _get_device, _get_dtype, _validate_parameters from ..potentials import Potential @@ -46,8 +46,8 @@ def __init__( f"Potential must be an instance of Potential, got {type(potential)}" ) - self.device = torch.get_default_device() if device is None else device - self.dtype = torch.get_default_dtype() if dtype is None else dtype + self.device = _get_device(device) + self.dtype = _get_dtype(dtype) if self.dtype != potential.dtype: raise TypeError( diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index b6e7fb36..af715505 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -2,6 +2,8 @@ import torch +from .._utils import _get_device, _get_dtype + class Potential(torch.nn.Module): r""" @@ -42,8 +44,10 @@ def __init__( device: Union[None, str, torch.device] = None, ): super().__init__() - self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.device = torch.get_default_device() if device is None else device + + self.device = _get_device(device) + self.dtype = _get_dtype(dtype) + if smearing is not None: self.register_buffer( "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype) diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 7ac463d6..546b3995 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -4,7 +4,7 @@ import torch -from .._utils import _validate_parameters +from .._utils import _get_device, _get_dtype, _validate_parameters from ..calculators import Calculator from ..potentials import InversePowerLawPotential @@ -91,8 +91,8 @@ def __init__( f"Only exponent = 1 is supported but got {exponent}." ) - self.device = torch.get_default_device() if device is None else device - self.dtype = torch.get_default_dtype() if dtype is None else dtype + self.device = _get_device(device) + self.dtype = _get_dtype(dtype) _validate_parameters( charges=charges, @@ -294,8 +294,8 @@ def __init__( ): super().__init__() - self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.device = torch.get_default_device() if device is None else device + self.device = _get_device(device) + self.dtype = _get_dtype(dtype) _validate_parameters( charges=charges, diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 209355d4..3b62ae3b 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -15,6 +15,7 @@ P3MCalculator, PMECalculator, ) +from torchpme._utils import _get_device, _get_dtype DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] DTYPES = [torch.float32, torch.float64] @@ -31,27 +32,35 @@ ( Calculator, { - "potential": CoulombPotential(smearing=None), + "potential": lambda dtype, device: CoulombPotential( + smearing=None, dtype=dtype, device=device + ), }, ), ( EwaldCalculator, { - "potential": CoulombPotential(smearing=SMEARING), + "potential": lambda dtype, device: CoulombPotential( + smearing=SMEARING, dtype=dtype, device=device + ), "lr_wavelength": LR_WAVELENGTH, }, ), ( PMECalculator, { - "potential": CoulombPotential(smearing=SMEARING), + "potential": lambda dtype, device: CoulombPotential( + smearing=SMEARING, dtype=dtype, device=device + ), "mesh_spacing": MESH_SPACING, }, ), ( P3MCalculator, { - "potential": CoulombPotential(smearing=SMEARING), + "potential": lambda dtype, device: CoulombPotential( + smearing=SMEARING, dtype=dtype, device=device + ), "mesh_spacing": MESH_SPACING, }, ), @@ -60,8 +69,8 @@ class TestWorkflow: def cscl_system(self, device=None, dtype=None): """CsCl crystal. Same as in the madelung test""" - device = torch.get_default_device() if device is None else device - dtype = torch.get_default_dtype() if dtype is None else dtype + device = _get_device(device) + dtype = _get_dtype(dtype) positions = torch.tensor( [[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device @@ -75,6 +84,7 @@ def cscl_system(self, device=None, dtype=None): def test_smearing_non_positive(self, CalculatorClass, params, device, dtype): params = params.copy() + params["potential"] = params["potential"](dtype, device) match = r"`smearing` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator, PMECalculator]: params["smearing"] = 0 @@ -86,6 +96,7 @@ def test_smearing_non_positive(self, CalculatorClass, params, device, dtype): def test_interpolation_order_error(self, CalculatorClass, params, device, dtype): params = params.copy() + params["potential"] = params["potential"](dtype, device) if type(CalculatorClass) in [PMECalculator]: match = "Only `interpolation_nodes` from 1 to 5" params["interpolation_nodes"] = 10 @@ -94,6 +105,7 @@ def test_interpolation_order_error(self, CalculatorClass, params, device, dtype) def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype): params = params.copy() + params["potential"] = params["potential"](dtype, device) match = r"`lr_wavelength` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator]: params["lr_wavelength"] = 0 @@ -106,8 +118,7 @@ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype def test_dtype_device(self, CalculatorClass, params, device, dtype): """Test that the output dtype and device are the same as the input.""" params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) calculator = CalculatorClass(**params, device=device, dtype=dtype) potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype)) @@ -127,8 +138,7 @@ def check_operation(self, calculator, device, dtype): def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) calculator = CalculatorClass(**params, device=device, dtype=dtype) self.check_operation(calculator=calculator, device=device, dtype=dtype) @@ -136,8 +146,7 @@ def test_operation_as_python(self, CalculatorClass, params, device, dtype): def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) @@ -145,8 +154,7 @@ def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype) def test_save_load(self, CalculatorClass, params, device, dtype): params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) @@ -158,8 +166,7 @@ def test_save_load(self, CalculatorClass, params, device, dtype): def test_prefactor(self, CalculatorClass, params, device, dtype): """Test if the prefactor is applied correctly.""" params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) prefactor = 2.0 calculator1 = CalculatorClass(**params, device=device, dtype=dtype) @@ -175,8 +182,7 @@ def test_prefactor(self, CalculatorClass, params, device, dtype): def test_not_nan(self, CalculatorClass, params, device, dtype): """Make sure derivatives are not NaN.""" params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) calculator = CalculatorClass(**params, device=device, dtype=dtype) system = self.cscl_system(device=device, dtype=dtype) @@ -212,9 +218,7 @@ def test_dtype_and_device_incompatability( params = params.copy() other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 - - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) match = ( rf"dtype of `potential` \({params['potential'].dtype}\) must be same as " @@ -239,11 +243,33 @@ def test_potential_and_calculator_incompatability( ): """Test that the calculator raises an error if the potential and calculator are incompatible.""" params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) params["potential"] = torch.jit.script(params["potential"]) with pytest.raises( TypeError, match="Potential must be an instance of Potential, got.*" ): CalculatorClass(**params, device=device, dtype=dtype) + + def test_device_string_compatability(self, CalculatorClass, params, dtype, device): + """Test that the calculator works with device strings.""" + params = params.copy() + params["potential"] = params["potential"](dtype, "cpu") + calculator = CalculatorClass( + **params, + device=torch.device("cpu"), + dtype=dtype, + ) + + assert calculator.device == params["potential"].device + + def test_device_index_compatability(self, CalculatorClass, params, dtype, device): + """Test that the calculator works with no index on the device.""" + if torch.cuda.is_available(): + params = params.copy() + params["potential"] = params["potential"](dtype, "cuda") + calculator = CalculatorClass( + **params, device=torch.device("cuda:0"), dtype=dtype + ) + + assert calculator.device == params["potential"].device diff --git a/tests/helpers.py b/tests/helpers.py index bca5feb7..4682b906 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,8 @@ import torch from vesin import NeighborList +from torchpme._utils import _get_device, _get_dtype + SQRT3 = math.sqrt(3) DIR_PATH = Path(__file__).parent @@ -15,8 +17,8 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None): - device = torch.get_default_device() if device is None else device - dtype = torch.get_default_dtype() if dtype is None else dtype + device = _get_device(device) + dtype = _get_dtype(dtype) # Define all relevant parameters (atom positions, charges, cell) of the reference # crystal structures for which the Madelung constants obtained from the Ewald sums diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 6beed927..141977d2 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -7,6 +7,7 @@ from packaging import version import torchpme +from torchpme._utils import _get_device, _get_dtype mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") @@ -26,27 +27,35 @@ ( torchpme.metatensor.Calculator, { - "potential": torchpme.CoulombPotential(smearing=None), + "potential": lambda dtype, device: torchpme.CoulombPotential( + smearing=None, dtype=dtype, device=device + ), }, ), ( torchpme.metatensor.EwaldCalculator, { - "potential": torchpme.CoulombPotential(smearing=SMEARING), + "potential": lambda dtype, device: torchpme.CoulombPotential( + smearing=SMEARING, dtype=dtype, device=device + ), "lr_wavelength": LR_WAVELENGTH, }, ), ( torchpme.metatensor.PMECalculator, { - "potential": torchpme.CoulombPotential(smearing=SMEARING), + "potential": lambda dtype, device: torchpme.CoulombPotential( + smearing=SMEARING, dtype=dtype, device=device + ), "mesh_spacing": MESH_SPACING, }, ), ( torchpme.metatensor.P3MCalculator, { - "potential": torchpme.CoulombPotential(smearing=SMEARING), + "potential": lambda dtype, device: torchpme.CoulombPotential( + smearing=SMEARING, dtype=dtype, device=device + ), "mesh_spacing": MESH_SPACING, }, ), @@ -54,8 +63,8 @@ ) class TestWorkflow: def system(self, device=None, dtype=None): - device = torch.get_default_device() if device is None else device - dtype = torch.get_default_dtype() if dtype is None else dtype + device = _get_device(device) + dtype = _get_dtype(dtype) system = mts_atomistic.System( types=torch.tensor([1, 2, 2]), @@ -94,7 +103,9 @@ def system(self, device=None, dtype=None): properties=mts_torch.Labels.range("distance", 1), ) - return system.to(device=device), neighbors.to(device=device) + return system.to(device=device, dtype=dtype), neighbors.to( + device=device, dtype=dtype + ) def check_operation(self, calculator, device, dtype): """Make sure computation runs and returns a metatensor.TensorMap.""" @@ -107,19 +118,22 @@ def check_operation(self, calculator, device, dtype): def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" - calculator = CalculatorClass(**params) + params = params.copy() + params["potential"] = params["potential"](dtype, device) + calculator = CalculatorClass(**params, device=device, dtype=dtype) self.check_operation(calculator=calculator, device=device, dtype=dtype) def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" - calculator = CalculatorClass(**params) + params = params.copy() + params["potential"] = params["potential"](dtype, device) + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) self.check_operation(calculator=scripted, device=device, dtype=dtype) def test_save_load(self, CalculatorClass, params, device, dtype): params = params.copy() - params["potential"].device = device - params["potential"].dtype = dtype + params["potential"] = params["potential"](dtype, device) calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py index 3d1ce03f..cabea6d9 100644 --- a/tests/tuning/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -10,6 +10,7 @@ P3MCalculator, PMECalculator, ) +from torchpme._utils import _get_device, _get_dtype from torchpme.tuning import tune_ewald, tune_p3m, tune_pme from torchpme.tuning.tuner import TunerBase @@ -22,8 +23,8 @@ def system(device=None, dtype=None): - device = torch.get_default_device() if device is None else device - dtype = torch.get_default_dtype() if dtype is None else dtype + device = _get_device(device) + dtype = _get_dtype(dtype) charges = torch.ones((4, 1), dtype=dtype, device=device) cell = torch.eye(3, dtype=dtype, device=device)