Skip to content

Commit

Permalink
Fix tests for the case when CUDA is available on the system
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Jan 29, 2025
1 parent bf52afa commit 04d16ba
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/metatensor/test_workflow_metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def system(self, device=None, dtype=None):
properties=mts_torch.Labels.range("distance", 1),
)

return system.to(device=device), neighbors.to(device=device)
return system.to(device=device, dtype=dtype), neighbors.to(
device=device, dtype=dtype
)

def check_operation(self, calculator, device, dtype):
"""Make sure computation runs and returns a metatensor.TensorMap."""
Expand All @@ -107,12 +109,16 @@ def check_operation(self, calculator, device, dtype):

def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
calculator = CalculatorClass(**params)
params["potential"].device = device
params["potential"].dtype = dtype
calculator = CalculatorClass(**params, device=device, dtype=dtype)
self.check_operation(calculator=calculator, device=device, dtype=dtype)

def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
calculator = CalculatorClass(**params)
params["potential"].device = device
params["potential"].dtype = dtype
calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
self.check_operation(calculator=scripted, device=device, dtype=dtype)

Expand Down

0 comments on commit 04d16ba

Please sign in to comment.