Skip to content

Commit

Permalink
Added private function to get device
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 30, 2025
1 parent 1513c2d commit 4011f22
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 22 deletions.
11 changes: 11 additions & 0 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
import torch


def _get_device(device: Union[None, str, torch.device]) -> torch.device:
new_device = torch.get_default_device() if device is None else torch.device(device)

# Add default index of 0 to a cuda device to avoid errors when comparing with
# devices from tensors
if new_device.type == "cuda" and new_device.index is None:
new_device = torch.device("cuda:0")

return new_device


def _validate_parameters(
charges: torch.Tensor,
cell: torch.Tensor,
Expand Down
8 changes: 2 additions & 6 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import profiler

from .._utils import _validate_parameters
from .._utils import _get_device, _validate_parameters
from ..potentials import Potential


Expand Down Expand Up @@ -46,11 +46,7 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")
self.device = _get_device(device)
self.dtype = torch.get_default_dtype() if dtype is None else dtype

if self.dtype != potential.dtype:
Expand Down
10 changes: 5 additions & 5 deletions src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from .._utils import _get_device


class Potential(torch.nn.Module):
r"""
Expand Down Expand Up @@ -42,12 +44,10 @@ def __init__(
device: Union[None, str, torch.device] = None,
):
super().__init__()

self.device = _get_device(device)
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")

if smearing is not None:
self.register_buffer(
"smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
Expand Down
14 changes: 3 additions & 11 deletions src/torchpme/tuning/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from .._utils import _validate_parameters
from .._utils import _get_device, _validate_parameters
from ..calculators import Calculator
from ..potentials import InversePowerLawPotential

Expand Down Expand Up @@ -91,11 +91,7 @@ def __init__(
f"Only exponent = 1 is supported but got {exponent}."
)

self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")
self.device = _get_device(device)
self.dtype = torch.get_default_dtype() if dtype is None else dtype

_validate_parameters(
Expand Down Expand Up @@ -298,12 +294,8 @@ def __init__(
):
super().__init__()

self.device = _get_device(device)
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")

_validate_parameters(
charges=charges,
Expand Down

0 comments on commit 4011f22

Please sign in to comment.