diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 6beed927..bb76f0ad 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -94,7 +94,9 @@ def system(self, device=None, dtype=None): properties=mts_torch.Labels.range("distance", 1), ) - return system.to(device=device), neighbors.to(device=device) + return system.to(device=device, dtype=dtype), neighbors.to( + device=device, dtype=dtype + ) def check_operation(self, calculator, device, dtype): """Make sure computation runs and returns a metatensor.TensorMap.""" @@ -107,12 +109,16 @@ def check_operation(self, calculator, device, dtype): def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" - calculator = CalculatorClass(**params) + params["potential"].device = device + params["potential"].dtype = dtype + calculator = CalculatorClass(**params, device=device, dtype=dtype) self.check_operation(calculator=calculator, device=device, dtype=dtype) def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" - calculator = CalculatorClass(**params) + params["potential"].device = device + params["potential"].dtype = dtype + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) self.check_operation(calculator=scripted, device=device, dtype=dtype)