Skip to content

Commit

Permalink
Make dtype and device implicit in classes (#166)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Philip Loche <ploche@physik.fu-berlin.de>
  • Loading branch information
E-Rum and PicoCentauri authored Feb 11, 2025
1 parent ddff22b commit fb760cd
Show file tree
Hide file tree
Showing 32 changed files with 237 additions and 547 deletions.
10 changes: 4 additions & 6 deletions docs/src/references/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
Added
#####

* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Better documentation for for ``cell``, ``charges`` and ``positions`` parameters
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.

Fixed
#####
Removed
#######

* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``
* Remove ``device`` and ``dtype`` from init of ``Calculator``, ``Potential`` and
``Tuning`` classes

`Version 0.2.0 <https://github.com/lab-cosmo/torch-pme/releases/tag/v0.2.0>`_ - 2025-01-23
------------------------------------------------------------------------------------------
Expand Down
9 changes: 4 additions & 5 deletions examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -103,9 +102,9 @@
# will be used to *compute* the potential energy of the system.

calculator = torchpme.PMECalculator(
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
torchpme.CoulombPotential(smearing=smearing), **pme_params
)

calculator.to(dtype=dtype)
# %%
#
# Single Charge Channel
Expand Down Expand Up @@ -207,9 +206,9 @@
# creating a new calculator with the metatensor interface.

calculator_metatensor = torchpme.metatensor.PMECalculator(
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
torchpme.CoulombPotential(smearing=smearing), **pme_params
)

calculator_metatensor.to(dtype=dtype)
# %%
#
# Computation with metatensor involves using Metatensor's :class:`System
Expand Down
4 changes: 1 addition & 3 deletions examples/02-neighbor-lists-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -195,8 +194,7 @@ def distances(
# compute the potential.

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype),
dtype=dtype,
potential=torchpme.CoulombPotential(smearing=smearing),
**pme_params,
)
potential = pme(
Expand Down
2 changes: 1 addition & 1 deletion examples/07-lode-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __init__(self, potential: Potential, n_grid: int = 3):
)

# assumes a smooth exclusion region so sets the integration cutoff to half that
nodes, weights = get_full_grid(n_grid, potential.exclusion_radius.item() / 2)
nodes, weights = get_full_grid(n_grid, potential.exclusion_radius / 2)

# these are the "stencils" used to project the potential
# on an atom-centered basis. NB: weights might also be incorporated
Expand Down
14 changes: 8 additions & 6 deletions examples/08-combined-potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@
# evaluation, and so one has to set it also for the combined potential, even if it is
# not used explicitly in the evaluation of the combination.

pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype)

potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype)
pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
pot_1 = pot_1.to(dtype=dtype)
pot_2 = pot_2.to(dtype=dtype)
potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)
potential = potential.to(dtype=dtype)

# Note also that :class:`CombinedPotential` can be used with any combination of
# potentials, as long they are all either direct or range separated. For instance, one
Expand Down Expand Up @@ -156,9 +158,9 @@
# much bigger system.

calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A
)

calculator.to(dtype=dtype)

# %%
#
Expand Down
21 changes: 4 additions & 17 deletions examples/10-tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,10 @@
pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4}

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype),
device=device,
dtype=dtype,
potential=torchpme.CoulombPotential(smearing=smearing),
**pme_params, # type: ignore[arg-type]
)

pme.to(device=device, dtype=dtype)
# %%
# Run the calculator
# ~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -170,8 +168,6 @@
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
run_backward=True,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)

Expand Down Expand Up @@ -220,14 +216,11 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
)

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(
smearing=smearing, device=device, dtype=dtype
),
potential=torchpme.CoulombPotential(smearing=smearing),
mesh_spacing=mesh_spacing,
interpolation_nodes=interpolation_nodes,
device=device,
dtype=dtype,
)
pme.to(device=device, dtype=dtype)
potential = pme(
charges=charges,
cell=cell,
Expand All @@ -247,8 +240,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
run_backward=True,
n_warmup=1,
n_repeat=4,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)
return madelung, estimated_timing
Expand Down Expand Up @@ -457,8 +448,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
cutoff=5.0,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
device=device,
dtype=dtype,
)

print(
Expand Down Expand Up @@ -492,8 +481,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
cutoff=cutoff,
neighbor_indices=filter_indices,
neighbor_distances=filter_distances,
device=device,
dtype=dtype,
)
timings_grid.append(timing)

Expand Down
9 changes: 4 additions & 5 deletions examples/basic-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@
# contains all the necessary functions (such as those defining the short-range and
# long-range splits) for this potential and makes them useable in the rest of the code.

potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype)
potential = CoulombPotential(smearing=smearing)
potential.to(device=device, dtype=dtype)

# %%
#
Expand Down Expand Up @@ -193,10 +194,8 @@
# Since our structure is relatively small, we use the :class:`EwaldCalculator`.
# We start by the initialization of the class.

calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype
)

calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength)
calculator.to(device=device, dtype=dtype)
# %%
#
# Compute Energy
Expand Down
52 changes: 13 additions & 39 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,18 @@
from typing import Optional, Union
from typing import Union

import torch


def _get_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
return torch.get_default_dtype() if dtype is None else dtype


def _get_device(device: Union[None, str, torch.device]) -> torch.device:
new_device = torch.get_default_device() if device is None else torch.device(device)

# Add default index of 0 to a cuda device to avoid errors when comparing with
# devices from tensors
if new_device.type == "cuda" and new_device.index is None:
new_device = torch.device("cuda:0")

return new_device


def _validate_parameters(
charges: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: torch.device,
) -> None:
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

if positions.device != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)
dtype = positions.dtype
device = positions.device

# check shape, dtype and device of positions
num_atoms = len(positions)
Expand All @@ -55,14 +29,14 @@ def _validate_parameters(
f"{list(cell.shape)}"
)

if cell.dtype != positions.dtype:
if cell.dtype != dtype:
raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
f"type of `cell` ({cell.dtype}) must be same as that of the `positions` class ({dtype})"
)

if cell.device != device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as the class ({device})"
f"device of `cell` ({cell.device}) must be same as that of the `positions` class ({device})"
)

if smearing is not None and torch.equal(
Expand All @@ -89,14 +63,14 @@ def _validate_parameters(
f"{len(positions)} atoms"
)

if charges.dtype != positions.dtype:
if charges.dtype != dtype:
raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
f"type of `charges` ({charges.dtype}) must be same as that of the `positions` class ({dtype})"
)

if charges.device != device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as the class "
f"device of `charges` ({charges.device}) must be same as that of the `positions` class "
f"({device})"
)

Expand All @@ -111,7 +85,7 @@ def _validate_parameters(
if neighbor_indices.device != device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as the class ({device})"
f"same as that of the `positions` class ({device})"
)

if neighbor_distances.shape != neighbor_indices[:, 0].shape:
Expand All @@ -124,11 +98,11 @@ def _validate_parameters(
if neighbor_distances.device != device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as the class ({device})"
f"same as that of the `positions` class ({device})"
)

if neighbor_distances.dtype != positions.dtype:
if neighbor_distances.dtype != dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as the class ({dtype})"
f"as that of the `positions` class ({dtype})"
)
26 changes: 1 addition & 25 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional, Union

import torch
from torch import profiler

from .._utils import _get_device, _get_dtype, _validate_parameters
from .._utils import _validate_parameters
from ..potentials import Potential


Expand All @@ -27,17 +25,13 @@ class Calculator(torch.nn.Module):
will come from a full (True) or half (False, default) neighbor list.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
self,
potential: Potential,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__()

Expand All @@ -46,24 +40,8 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

if self.dtype != potential.dtype:
raise TypeError(
f"dtype of `potential` ({potential.dtype}) must be same as of "
f"`calculator` ({self.dtype})"
)

if self.device != potential.device:
raise ValueError(
f"device of `potential` ({potential.device}) must be same as of "
f"`calculator` ({self.device})"
)

self.potential = potential
self.full_neighbor_list = full_neighbor_list

self.prefactor = prefactor

def _compute_rspace(
Expand Down Expand Up @@ -164,8 +142,6 @@ def forward(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
dtype=self.dtype,
device=self.device,
)

# Compute short-range (SR) part using a real space sum
Expand Down
Loading

0 comments on commit fb760cd

Please sign in to comment.