diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index 2ce9f68c..1656d88a 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -3,6 +3,17 @@ import torch +def _get_device(device: Union[None, str, torch.device]) -> torch.device: + new_device = torch.get_default_device() if device is None else torch.device(device) + + # Add default index of 0 to a cuda device to avoid errors when comparing with + # devices from tensors + if new_device.type == "cuda" and new_device.index is None: + new_device = torch.device("cuda:0") + + return new_device + + def _validate_parameters( charges: torch.Tensor, cell: torch.Tensor, diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index e9ee124c..1fef3e9a 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -3,7 +3,7 @@ import torch from torch import profiler -from .._utils import _validate_parameters +from .._utils import _get_device, _validate_parameters from ..potentials import Potential @@ -46,11 +46,7 @@ def __init__( f"Potential must be an instance of Potential, got {type(potential)}" ) - self.device = ( - torch.get_default_device() if device is None else torch.device(device) - ) - if self.device.type == "cuda" and self.device.index is None: - self.device = torch.device("cuda:0") + self.device = _get_device(device) self.dtype = torch.get_default_dtype() if dtype is None else dtype if self.dtype != potential.dtype: diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index 8d0cada0..ae65f888 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -2,6 +2,8 @@ import torch +from .._utils import _get_device + class Potential(torch.nn.Module): r""" @@ -42,12 +44,10 @@ def __init__( device: Union[None, str, torch.device] = None, ): super().__init__() + + self.device = _get_device(device) self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.device = ( - torch.get_default_device() if device is None else torch.device(device) - ) - if self.device.type == "cuda" and self.device.index is None: - self.device = torch.device("cuda:0") + if smearing is not None: self.register_buffer( "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype) diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 929b4616..1191b5a4 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -4,7 +4,7 @@ import torch -from .._utils import _validate_parameters +from .._utils import _get_device, _validate_parameters from ..calculators import Calculator from ..potentials import InversePowerLawPotential @@ -91,11 +91,7 @@ def __init__( f"Only exponent = 1 is supported but got {exponent}." ) - self.device = ( - torch.get_default_device() if device is None else torch.device(device) - ) - if self.device.type == "cuda" and self.device.index is None: - self.device = torch.device("cuda:0") + self.device = _get_device(device) self.dtype = torch.get_default_dtype() if dtype is None else dtype _validate_parameters( @@ -298,12 +294,8 @@ def __init__( ): super().__init__() + self.device = _get_device(device) self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.device = ( - torch.get_default_device() if device is None else torch.device(device) - ) - if self.device.type == "cuda" and self.device.index is None: - self.device = torch.device("cuda:0") _validate_parameters( charges=charges,