Skip to content

Commit

Permalink
fix device om calculator init
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 29, 2025
1 parent bf52afa commit c2952d2
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 20 deletions.
18 changes: 6 additions & 12 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,20 @@ def _validate_parameters(
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: Union[str, torch.device],
device: torch.device,
) -> None:
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

if isinstance(device, torch.device):
device = device.type

if positions.device.type != device:
if positions.device != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)

# We use `positions.device` because it includes the device type AND index, which the
# `device` parameter may lack

# check shape, dtype and device of positions
num_atoms = len(positions)
if list(positions.shape) != [len(positions), 3]:
Expand All @@ -51,7 +45,7 @@ def _validate_parameters(
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
)

if cell.device != positions.device:
if cell.device != device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as the class ({device})"
)
Expand Down Expand Up @@ -85,7 +79,7 @@ def _validate_parameters(
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
)

if charges.device != positions.device:
if charges.device != device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as the class "
f"({device})"
Expand All @@ -99,7 +93,7 @@ def _validate_parameters(
"structure"
)

if neighbor_indices.device != positions.device:
if neighbor_indices.device != device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as the class ({device})"
Expand All @@ -112,7 +106,7 @@ def _validate_parameters(
f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
)

if neighbor_distances.device != positions.device:
if neighbor_distances.device != device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as the class ({device})"
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = torch.get_default_device() if device is None else device
self.device = torch.get_default_device() if device is None else torch.device(device)
self.dtype = torch.get_default_dtype() if dtype is None else dtype

if self.dtype != potential.dtype:
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
):
super().__init__()
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = torch.get_default_device() if device is None else device
self.device = torch.get_default_device() if device is None else torch.device(device)
if smearing is not None:
self.register_buffer(
"smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/tuning/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
f"Only exponent = 1 is supported but got {exponent}."
)

self.device = torch.get_default_device() if device is None else device
self.device = torch.get_default_device() if device is None else torch.device(device)
self.dtype = torch.get_default_dtype() if dtype is None else dtype

_validate_parameters(
Expand Down Expand Up @@ -295,7 +295,7 @@ def __init__(
super().__init__()

self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = torch.get_default_device() if device is None else device
self.device = torch.get_default_device() if device is None else torch.device(device)

_validate_parameters(
charges=charges,
Expand Down
2 changes: 1 addition & 1 deletion tests/calculators/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
class TestWorkflow:
def cscl_system(self, device=None, dtype=None):
"""CsCl crystal. Same as in the madelung test"""
device = torch.get_default_device() if device is None else device
device = torch.get_default_device() if device is None else torch.device(device)
dtype = torch.get_default_dtype() if dtype is None else dtype

positions = torch.tensor(
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def define_crystal(crystal_name="CsCl", dtype=None, device=None):
device = torch.get_default_device() if device is None else device
device = torch.get_default_device() if device is None else torch.device(device)
dtype = torch.get_default_dtype() if dtype is None else dtype

# Define all relevant parameters (atom positions, charges, cell) of the reference
Expand Down
2 changes: 1 addition & 1 deletion tests/metatensor/test_workflow_metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)
class TestWorkflow:
def system(self, device=None, dtype=None):
device = torch.get_default_device() if device is None else device
device = torch.get_default_device() if device is None else torch.device(device)
dtype = torch.get_default_dtype() if dtype is None else dtype

system = mts_atomistic.System(
Expand Down
2 changes: 1 addition & 1 deletion tests/tuning/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def system(device=None, dtype=None):
device = torch.get_default_device() if device is None else device
device = torch.get_default_device() if device is None else torch.device(device)
dtype = torch.get_default_dtype() if dtype is None else dtype

charges = torch.ones((4, 1), dtype=dtype, device=device)
Expand Down

0 comments on commit c2952d2

Please sign in to comment.