Skip to content

Commit

Permalink
Merge "main"
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Jan 31, 2025
1 parent a31ed9b commit ec6d55e
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 61 deletions.
35 changes: 22 additions & 13 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from typing import Union
from typing import Optional, Union

import torch


def _get_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
return torch.get_default_dtype() if dtype is None else dtype


def _get_device(device: Union[None, str, torch.device]) -> torch.device:
new_device = torch.get_default_device() if device is None else torch.device(device)

# Add default index of 0 to a cuda device to avoid errors when comparing with
# devices from tensors
if new_device.type == "cuda" and new_device.index is None:
new_device = torch.device("cuda:0")

return new_device


def _validate_parameters(
charges: torch.Tensor,
cell: torch.Tensor,
Expand All @@ -11,26 +26,20 @@ def _validate_parameters(
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: Union[str, torch.device],
device: torch.device,
) -> None:
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

if isinstance(device, torch.device):
device = device.type

if positions.device.type != device:
if positions.device != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)

# We use `positions.device` because it includes the device type AND index, which the
# `device` parameter may lack

# check shape, dtype and device of positions
num_atoms = len(positions)
if list(positions.shape) != [len(positions), 3]:
Expand All @@ -51,7 +60,7 @@ def _validate_parameters(
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
)

if cell.device != positions.device:
if cell.device != device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as the class ({device})"
)
Expand Down Expand Up @@ -85,7 +94,7 @@ def _validate_parameters(
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
)

if charges.device != positions.device:
if charges.device != device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as the class "
f"({device})"
Expand All @@ -99,7 +108,7 @@ def _validate_parameters(
"structure"
)

if neighbor_indices.device != positions.device:
if neighbor_indices.device != device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as the class ({device})"
Expand All @@ -112,7 +121,7 @@ def _validate_parameters(
f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
)

if neighbor_distances.device != positions.device:
if neighbor_distances.device != device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as the class ({device})"
Expand Down
6 changes: 3 additions & 3 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import profiler

from .._utils import _validate_parameters
from .._utils import _get_device, _get_dtype, _validate_parameters
from ..potentials import Potential


Expand Down Expand Up @@ -46,8 +46,8 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = torch.get_default_device() if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

if self.dtype != potential.dtype:
raise TypeError(
Expand Down
8 changes: 6 additions & 2 deletions src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from .._utils import _get_device, _get_dtype


class Potential(torch.nn.Module):
r"""
Expand Down Expand Up @@ -42,8 +44,10 @@ def __init__(
device: Union[None, str, torch.device] = None,
):
super().__init__()
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = torch.get_default_device() if device is None else device

self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

if smearing is not None:
self.register_buffer(
"smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
Expand Down
10 changes: 5 additions & 5 deletions src/torchpme/tuning/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from .._utils import _validate_parameters
from .._utils import _get_device, _get_dtype, _validate_parameters
from ..calculators import Calculator
from ..potentials import InversePowerLawPotential

Expand Down Expand Up @@ -91,8 +91,8 @@ def __init__(
f"Only exponent = 1 is supported but got {exponent}."
)

self.device = torch.get_default_device() if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

_validate_parameters(
charges=charges,
Expand Down Expand Up @@ -294,8 +294,8 @@ def __init__(
):
super().__init__()

self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = torch.get_default_device() if device is None else device
self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

_validate_parameters(
charges=charges,
Expand Down
72 changes: 49 additions & 23 deletions tests/calculators/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
P3MCalculator,
PMECalculator,
)
from torchpme._utils import _get_device, _get_dtype

DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
DTYPES = [torch.float32, torch.float64]
Expand All @@ -31,27 +32,35 @@
(
Calculator,
{
"potential": CoulombPotential(smearing=None),
"potential": lambda dtype, device: CoulombPotential(
smearing=None, dtype=dtype, device=device
),
},
),
(
EwaldCalculator,
{
"potential": CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"lr_wavelength": LR_WAVELENGTH,
},
),
(
PMECalculator,
{
"potential": CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"mesh_spacing": MESH_SPACING,
},
),
(
P3MCalculator,
{
"potential": CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"mesh_spacing": MESH_SPACING,
},
),
Expand All @@ -60,8 +69,8 @@
class TestWorkflow:
def cscl_system(self, device=None, dtype=None):
"""CsCl crystal. Same as in the madelung test"""
device = torch.get_default_device() if device is None else device
dtype = torch.get_default_dtype() if dtype is None else dtype
device = _get_device(device)
dtype = _get_dtype(dtype)

positions = torch.tensor(
[[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device
Expand All @@ -75,6 +84,7 @@ def cscl_system(self, device=None, dtype=None):

def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"] = params["potential"](dtype, device)
match = r"`smearing` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator, PMECalculator]:
params["smearing"] = 0
Expand All @@ -86,6 +96,7 @@ def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):

def test_interpolation_order_error(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"] = params["potential"](dtype, device)
if type(CalculatorClass) in [PMECalculator]:
match = "Only `interpolation_nodes` from 1 to 5"
params["interpolation_nodes"] = 10
Expand All @@ -94,6 +105,7 @@ def test_interpolation_order_error(self, CalculatorClass, params, device, dtype)

def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"] = params["potential"](dtype, device)
match = r"`lr_wavelength` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator]:
params["lr_wavelength"] = 0
Expand All @@ -106,8 +118,7 @@ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype
def test_dtype_device(self, CalculatorClass, params, device, dtype):
"""Test that the output dtype and device are the same as the input."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
Expand All @@ -127,26 +138,23 @@ def check_operation(self, calculator, device, dtype):
def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
self.check_operation(calculator=calculator, device=device, dtype=dtype)

def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
self.check_operation(calculator=scripted, device=device, dtype=dtype)

def test_save_load(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
Expand All @@ -158,8 +166,7 @@ def test_save_load(self, CalculatorClass, params, device, dtype):
def test_prefactor(self, CalculatorClass, params, device, dtype):
"""Test if the prefactor is applied correctly."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

prefactor = 2.0
calculator1 = CalculatorClass(**params, device=device, dtype=dtype)
Expand All @@ -175,8 +182,7 @@ def test_prefactor(self, CalculatorClass, params, device, dtype):
def test_not_nan(self, CalculatorClass, params, device, dtype):
"""Make sure derivatives are not NaN."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
system = self.cscl_system(device=device, dtype=dtype)
Expand Down Expand Up @@ -212,9 +218,7 @@ def test_dtype_and_device_incompatability(
params = params.copy()

other_dtype = torch.float32 if dtype == torch.float64 else torch.float64

params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

match = (
rf"dtype of `potential` \({params['potential'].dtype}\) must be same as "
Expand All @@ -239,11 +243,33 @@ def test_potential_and_calculator_incompatability(
):
"""Test that the calculator raises an error if the potential and calculator are incompatible."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

params["potential"] = torch.jit.script(params["potential"])
with pytest.raises(
TypeError, match="Potential must be an instance of Potential, got.*"
):
CalculatorClass(**params, device=device, dtype=dtype)

def test_device_string_compatability(self, CalculatorClass, params, dtype, device):
"""Test that the calculator works with device strings."""
params = params.copy()
params["potential"] = params["potential"](dtype, "cpu")
calculator = CalculatorClass(
**params,
device=torch.device("cpu"),
dtype=dtype,
)

assert calculator.device == params["potential"].device

def test_device_index_compatability(self, CalculatorClass, params, dtype, device):
"""Test that the calculator works with no index on the device."""
if torch.cuda.is_available():
params = params.copy()
params["potential"] = params["potential"](dtype, "cuda")
calculator = CalculatorClass(
**params, device=torch.device("cuda:0"), dtype=dtype
)

assert calculator.device == params["potential"].device
6 changes: 4 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch
from vesin import NeighborList

from torchpme._utils import _get_device, _get_dtype

SQRT3 = math.sqrt(3)

DIR_PATH = Path(__file__).parent
Expand All @@ -15,8 +17,8 @@


def define_crystal(crystal_name="CsCl", dtype=None, device=None):
device = torch.get_default_device() if device is None else device
dtype = torch.get_default_dtype() if dtype is None else dtype
device = _get_device(device)
dtype = _get_dtype(dtype)

# Define all relevant parameters (atom positions, charges, cell) of the reference
# crystal structures for which the Madelung constants obtained from the Ewald sums
Expand Down
Loading

0 comments on commit ec6d55e

Please sign in to comment.