From c2952d2d321243cdca5b7e0177fa071675b95677 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 29 Jan 2025 16:02:04 +0100 Subject: [PATCH] fix device om calculator init --- src/torchpme/_utils.py | 18 ++++++------------ src/torchpme/calculators/calculator.py | 2 +- src/torchpme/potentials/potential.py | 2 +- src/torchpme/tuning/tuner.py | 4 ++-- tests/calculators/test_workflow.py | 2 +- tests/helpers.py | 2 +- tests/metatensor/test_workflow_metatensor.py | 2 +- tests/tuning/test_tuning.py | 2 +- 8 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index 8d42e45c..2ce9f68c 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -11,7 +11,7 @@ def _validate_parameters( neighbor_distances: torch.Tensor, smearing: Union[float, None], dtype: torch.dtype, - device: Union[str, torch.device], + device: torch.device, ) -> None: if positions.dtype != dtype: raise TypeError( @@ -19,18 +19,12 @@ def _validate_parameters( f"type ({dtype})" ) - if isinstance(device, torch.device): - device = device.type - - if positions.device.type != device: + if positions.device != device: raise ValueError( f"device of `positions` ({positions.device}) must be same as the class " f"device ({device})" ) - # We use `positions.device` because it includes the device type AND index, which the - # `device` parameter may lack - # check shape, dtype and device of positions num_atoms = len(positions) if list(positions.shape) != [len(positions), 3]: @@ -51,7 +45,7 @@ def _validate_parameters( f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})" ) - if cell.device != positions.device: + if cell.device != device: raise ValueError( f"device of `cell` ({cell.device}) must be same as the class ({device})" ) @@ -85,7 +79,7 @@ def _validate_parameters( f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})" ) - if charges.device != positions.device: + if charges.device != device: raise ValueError( f"device of `charges` ({charges.device}) must be same as the class " f"({device})" @@ -99,7 +93,7 @@ def _validate_parameters( "structure" ) - if neighbor_indices.device != positions.device: + if neighbor_indices.device != device: raise ValueError( f"device of `neighbor_indices` ({neighbor_indices.device}) must be " f"same as the class ({device})" @@ -112,7 +106,7 @@ def _validate_parameters( f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}" ) - if neighbor_distances.device != positions.device: + if neighbor_distances.device != device: raise ValueError( f"device of `neighbor_distances` ({neighbor_distances.device}) must be " f"same as the class ({device})" diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index a7681fea..44d08280 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -46,7 +46,7 @@ def __init__( f"Potential must be an instance of Potential, got {type(potential)}" ) - self.device = torch.get_default_device() if device is None else device + self.device = torch.get_default_device() if device is None else torch.device(device) self.dtype = torch.get_default_dtype() if dtype is None else dtype if self.dtype != potential.dtype: diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index b6e7fb36..d9fbadb5 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -43,7 +43,7 @@ def __init__( ): super().__init__() self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.device = torch.get_default_device() if device is None else device + self.device = torch.get_default_device() if device is None else torch.device(device) if smearing is not None: self.register_buffer( "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype) diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 7ac463d6..1d885408 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -91,7 +91,7 @@ def __init__( f"Only exponent = 1 is supported but got {exponent}." ) - self.device = torch.get_default_device() if device is None else device + self.device = torch.get_default_device() if device is None else torch.device(device) self.dtype = torch.get_default_dtype() if dtype is None else dtype _validate_parameters( @@ -295,7 +295,7 @@ def __init__( super().__init__() self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.device = torch.get_default_device() if device is None else device + self.device = torch.get_default_device() if device is None else torch.device(device) _validate_parameters( charges=charges, diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 209355d4..bc9914ad 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -60,7 +60,7 @@ class TestWorkflow: def cscl_system(self, device=None, dtype=None): """CsCl crystal. Same as in the madelung test""" - device = torch.get_default_device() if device is None else device + device = torch.get_default_device() if device is None else torch.device(device) dtype = torch.get_default_dtype() if dtype is None else dtype positions = torch.tensor( diff --git a/tests/helpers.py b/tests/helpers.py index bca5feb7..c20a29f2 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -15,7 +15,7 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None): - device = torch.get_default_device() if device is None else device + device = torch.get_default_device() if device is None else torch.device(device) dtype = torch.get_default_dtype() if dtype is None else dtype # Define all relevant parameters (atom positions, charges, cell) of the reference diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 6beed927..a1df0530 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -54,7 +54,7 @@ ) class TestWorkflow: def system(self, device=None, dtype=None): - device = torch.get_default_device() if device is None else device + device = torch.get_default_device() if device is None else torch.device(device) dtype = torch.get_default_dtype() if dtype is None else dtype system = mts_atomistic.System( diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py index 3d1ce03f..5ceb7b52 100644 --- a/tests/tuning/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -22,7 +22,7 @@ def system(device=None, dtype=None): - device = torch.get_default_device() if device is None else device + device = torch.get_default_device() if device is None else torch.device(device) dtype = torch.get_default_dtype() if dtype is None else dtype charges = torch.ones((4, 1), dtype=dtype, device=device)