From 1513c2d4953eb0283af534c1f4519d7e337b11ad Mon Sep 17 00:00:00 2001
From: E-Rum
Date: Wed, 29 Jan 2025 16:55:50 +0000
Subject: [PATCH] Fix device initialization for CUDA in Calculator and Potential classes

Normalize an index-less CUDA device (e.g. "cuda") to a device with an
explicit index in Calculator, Potential and the tuning classes, so that it
compares equal to an indexed device such as torch.device("cuda:0"). The
index is taken from torch.cuda.current_device() rather than hard-coded to
0, so a non-default current device is respected.

The workflow tests now build the potential for every test case from a
(dtype, device) factory instead of mutating the attributes of an instance
shared through the parametrization.

---
 src/torchpme/calculators/calculator.py       |  6 +-
 src/torchpme/potentials/potential.py         |  6 +-
 src/torchpme/tuning/tuner.py                 | 12 +++-
 tests/calculators/test_workflow.py           | 67 ++++++++++++++------
 tests/metatensor/test_workflow_metatensor.py | 27 +++++---
 5 files changed, 83 insertions(+), 35 deletions(-)

diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py
index 44d08280..e9ee124c 100644
--- a/src/torchpme/calculators/calculator.py
+++ b/src/torchpme/calculators/calculator.py
@@ -46,7 +46,11 @@ def __init__(
                 f"Potential must be an instance of Potential, got {type(potential)}"
             )
 
-        self.device = torch.get_default_device() if device is None else torch.device(device)
+        self.device = (
+            torch.get_default_device() if device is None else torch.device(device)
+        )
+        if self.device.type == "cuda" and self.device.index is None:
+            self.device = torch.device("cuda", torch.cuda.current_device())
         self.dtype = torch.get_default_dtype() if dtype is None else dtype
 
         if self.dtype != potential.dtype:
diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py
index d9fbadb5..8d0cada0 100644
--- a/src/torchpme/potentials/potential.py
+++ b/src/torchpme/potentials/potential.py
@@ -43,7 +43,11 @@ def __init__(
     ):
         super().__init__()
         self.dtype = torch.get_default_dtype() if dtype is None else dtype
-        self.device = torch.get_default_device() if device is None else torch.device(device)
+        self.device = (
+            torch.get_default_device() if device is None else torch.device(device)
+        )
+        if self.device.type == "cuda" and self.device.index is None:
+            self.device = torch.device("cuda", torch.cuda.current_device())
         if smearing is not None:
             self.register_buffer(
                 "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py
index 1d885408..929b4616 100644
--- a/src/torchpme/tuning/tuner.py
+++ b/src/torchpme/tuning/tuner.py
@@ -91,7 +91,11 @@ def __init__(
                 f"Only exponent = 1 is supported but got {exponent}."
             )
 
-        self.device = torch.get_default_device() if device is None else torch.device(device)
+        self.device = (
+            torch.get_default_device() if device is None else torch.device(device)
+        )
+        if self.device.type == "cuda" and self.device.index is None:
+            self.device = torch.device("cuda", torch.cuda.current_device())
         self.dtype = torch.get_default_dtype() if dtype is None else dtype
 
         _validate_parameters(
@@ -295,7 +299,11 @@ def __init__(
         super().__init__()
 
         self.dtype = torch.get_default_dtype() if dtype is None else dtype
-        self.device = torch.get_default_device() if device is None else torch.device(device)
+        self.device = (
+            torch.get_default_device() if device is None else torch.device(device)
+        )
+        if self.device.type == "cuda" and self.device.index is None:
+            self.device = torch.device("cuda", torch.cuda.current_device())
 
         _validate_parameters(
             charges=charges,
diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py
index bc9914ad..a3ada714 100644
--- a/tests/calculators/test_workflow.py
+++ b/tests/calculators/test_workflow.py
@@ -31,27 +31,35 @@
         (
             Calculator,
             {
-                "potential": CoulombPotential(smearing=None),
+                "potential": lambda dtype, device: CoulombPotential(
+                    smearing=None, dtype=dtype, device=device
+                ),
             },
         ),
         (
             EwaldCalculator,
             {
-                "potential": CoulombPotential(smearing=SMEARING),
+                "potential": lambda dtype, device: CoulombPotential(
+                    smearing=SMEARING, dtype=dtype, device=device
+                ),
                 "lr_wavelength": LR_WAVELENGTH,
             },
         ),
         (
             PMECalculator,
             {
-                "potential": CoulombPotential(smearing=SMEARING),
+                "potential": lambda dtype, device: CoulombPotential(
+                    smearing=SMEARING, dtype=dtype, device=device
+                ),
                 "mesh_spacing": MESH_SPACING,
             },
         ),
         (
             P3MCalculator,
             {
-                "potential": CoulombPotential(smearing=SMEARING),
+                "potential": lambda dtype, device: CoulombPotential(
+                    smearing=SMEARING, dtype=dtype, device=device
+                ),
                 "mesh_spacing": MESH_SPACING,
             },
         ),
@@ -75,6 +83,7 @@ def cscl_system(self, device=None, dtype=None):
 
     def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):
         params = params.copy()
+        params["potential"] = params["potential"](dtype, device)
         match = r"`smearing` .* has to be positive"
         if type(CalculatorClass) in [EwaldCalculator, PMECalculator]:
             params["smearing"] = 0
@@ -86,6 +95,7 @@ def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):
 
     def test_interpolation_order_error(self, CalculatorClass, params, device, dtype):
         params = params.copy()
+        params["potential"] = params["potential"](dtype, device)
         if type(CalculatorClass) in [PMECalculator]:
             match = "Only `interpolation_nodes` from 1 to 5"
             params["interpolation_nodes"] = 10
@@ -94,6 +104,7 @@ def test_interpolation_order_error(self, CalculatorClass, params, device, dtype)
 
     def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype):
         params = params.copy()
+        params["potential"] = params["potential"](dtype, device)
         match = r"`lr_wavelength` .* has to be positive"
         if type(CalculatorClass) in [EwaldCalculator]:
             params["lr_wavelength"] = 0
@@ -106,8 +117,7 @@ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype
     def test_dtype_device(self, CalculatorClass, params, device, dtype):
         """Test that the output dtype and device are the same as the input."""
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
@@ -127,8 +137,7 @@ def check_operation(self, calculator, device, dtype):
     def test_operation_as_python(self, CalculatorClass, params, device, dtype):
         """Run `check_operation` as a normal python script"""
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         self.check_operation(calculator=calculator, device=device, dtype=dtype)
@@ -136,8 +145,7 @@ def test_operation_as_python(self, CalculatorClass, params, device, dtype):
     def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
         """Run `check_operation` as a compiled torch script module."""
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         scripted = torch.jit.script(calculator)
@@ -145,8 +153,7 @@ def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype)
 
     def test_save_load(self, CalculatorClass, params, device, dtype):
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         scripted = torch.jit.script(calculator)
@@ -158,8 +165,7 @@ def test_save_load(self, CalculatorClass, params, device, dtype):
     def test_prefactor(self, CalculatorClass, params, device, dtype):
         """Test if the prefactor is applied correctly."""
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         prefactor = 2.0
         calculator1 = CalculatorClass(**params, device=device, dtype=dtype)
@@ -175,8 +181,7 @@ def test_prefactor(self, CalculatorClass, params, device, dtype):
     def test_not_nan(self, CalculatorClass, params, device, dtype):
         """Make sure derivatives are not NaN."""
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         system = self.cscl_system(device=device, dtype=dtype)
@@ -212,9 +217,7 @@ def test_dtype_and_device_incompatability(
         params = params.copy()
 
         other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
-
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         match = (
             rf"dtype of `potential` \({params['potential'].dtype}\) must be same as "
@@ -239,11 +242,33 @@ def test_potential_and_calculator_incompatability(
     ):
         """Test that the calculator raises an error if the potential and calculator are incompatible."""
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
         params["potential"] = torch.jit.script(params["potential"])
 
         with pytest.raises(
             TypeError, match="Potential must be an instance of Potential, got.*"
         ):
             CalculatorClass(**params, device=device, dtype=dtype)
+
+    def test_device_string_compatability(self, CalculatorClass, params, dtype, device):
+        """Test that the calculator works with device strings."""
+        params = params.copy()
+        params["potential"] = params["potential"](dtype, "cpu")
+        calculator = CalculatorClass(
+            **params,
+            device=torch.device("cpu"),
+            dtype=dtype,
+        )
+
+        assert calculator.device == params["potential"].device
+
+    def test_device_index_compatability(self, CalculatorClass, params, dtype, device):
+        """Test that the calculator works with no index on the device."""
+        if torch.cuda.is_available():
+            params = params.copy()
+            params["potential"] = params["potential"](dtype, "cuda")
+            calculator = CalculatorClass(
+                **params, device=torch.device("cuda:0"), dtype=dtype
+            )
+
+            assert calculator.device == params["potential"].device
diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py
index 2d7b7f16..80fe7e70 100644
--- a/tests/metatensor/test_workflow_metatensor.py
+++ b/tests/metatensor/test_workflow_metatensor.py
@@ -26,27 +26,35 @@
         (
             torchpme.metatensor.Calculator,
             {
-                "potential": torchpme.CoulombPotential(smearing=None),
+                "potential": lambda dtype, device: torchpme.CoulombPotential(
+                    smearing=None, dtype=dtype, device=device
+                ),
             },
         ),
         (
             torchpme.metatensor.EwaldCalculator,
             {
-                "potential": torchpme.CoulombPotential(smearing=SMEARING),
+                "potential": lambda dtype, device: torchpme.CoulombPotential(
+                    smearing=SMEARING, dtype=dtype, device=device
+                ),
                 "lr_wavelength": LR_WAVELENGTH,
             },
         ),
         (
             torchpme.metatensor.PMECalculator,
             {
-                "potential": torchpme.CoulombPotential(smearing=SMEARING),
+                "potential": lambda dtype, device: torchpme.CoulombPotential(
+                    smearing=SMEARING, dtype=dtype, device=device
+                ),
                 "mesh_spacing": MESH_SPACING,
             },
         ),
         (
             torchpme.metatensor.P3MCalculator,
             {
-                "potential": torchpme.CoulombPotential(smearing=SMEARING),
+                "potential": lambda dtype, device: torchpme.CoulombPotential(
+                    smearing=SMEARING, dtype=dtype, device=device
+                ),
                 "mesh_spacing": MESH_SPACING,
             },
         ),
@@ -109,23 +117,22 @@ def check_operation(self, calculator, device, dtype):
 
     def test_operation_as_python(self, CalculatorClass, params, device, dtype):
         """Run `check_operation` as a normal python script"""
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params = params.copy()
+        params["potential"] = params["potential"](dtype, device)
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         self.check_operation(calculator=calculator, device=device, dtype=dtype)
 
     def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
         """Run `check_operation` as a compiled torch script module."""
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params = params.copy()
+        params["potential"] = params["potential"](dtype, device)
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         scripted = torch.jit.script(calculator)
         self.check_operation(calculator=scripted, device=device, dtype=dtype)
 
     def test_save_load(self, CalculatorClass, params, device, dtype):
         params = params.copy()
-        params["potential"].device = device
-        params["potential"].dtype = dtype
+        params["potential"] = params["potential"](dtype, device)
 
         calculator = CalculatorClass(**params, device=device, dtype=dtype)
         scripted = torch.jit.script(calculator)