Skip to content

Commit

Permalink
Fix device and dtype not being specified in the __init__ of `P3…
Browse files Browse the repository at this point in the history
…MCalculator` (#159)
  • Loading branch information
GardevoirX authored Jan 29, 2025
1 parent b22c7a2 commit bf52afa
Show file tree
Hide file tree
Showing 27 changed files with 416 additions and 266 deletions.
6 changes: 6 additions & 0 deletions docs/src/references/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
Added
#####

* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Better documentation for for ``cell``, ``charges`` and ``positions`` parameters
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.

Fixed
#####

* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``

`Version 0.2.0 <https://github.com/lab-cosmo/torch-pme/releases/tag/v0.2.0>`_ - 2025-01-23
------------------------------------------------------------------------------------------

Expand Down
18 changes: 10 additions & 8 deletions examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@
import torchpme
from torchpme.tuning import tune_pme

dtype = torch.float64

# %%
#
# Create the properties CsCl unit cell

symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)
charges = torch.tensor([[1.0], [-1.0]], dtype=dtype)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=dtype)
cell = torch.eye(3, dtype=dtype)
pbc = torch.tensor([True, True, True])


Expand All @@ -72,6 +73,7 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -101,7 +103,7 @@
# will be used to *compute* the potential energy of the system.

calculator = torchpme.PMECalculator(
torchpme.CoulombPotential(smearing=smearing), **pme_params
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
)

# %%
Expand All @@ -112,7 +114,7 @@
# As a first application of multiple charge channels, we start simply by using the
# classic definition of one charge channel per atom.

charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
charges = torch.tensor([[1.0], [-1.0]], dtype=dtype)

# %%
#
Expand Down Expand Up @@ -160,7 +162,7 @@
# species-specific potentials and facilitating the learning process for machine learning
# algorithms.

charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=dtype)

# %%
#
Expand Down Expand Up @@ -205,7 +207,7 @@
# creating a new calculator with the metatensor interface.

calculator_metatensor = torchpme.metatensor.PMECalculator(
torchpme.CoulombPotential(smearing=smearing), **pme_params
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
)

# %%
Expand Down
8 changes: 6 additions & 2 deletions examples/02-neighbor-lists-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#
# As a test system, we use a 2x2x2 supercell of an CsCl crystal in a cubic cell.

dtype = torch.float64
atoms_unitcell = ase.Atoms(
symbols=["Cs", "Cl"],
positions=np.array([(0, 0, 0), (0.5, 0.5, 0.5)]),
Expand Down Expand Up @@ -97,7 +98,7 @@
nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
neighbor_indices, neighbor_distances = nl.compute(
points=positions.to(dtype=torch.float64, device="cpu"),
box=cell.to(dtype=torch.float64, device="cpu"),
box=cell.to(dtype=dtype, device="cpu"),
periodic=True,
quantities="Pd",
)
Expand All @@ -109,6 +110,7 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -193,7 +195,9 @@ def distances(
# compute the potential.

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing), **pme_params
potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype),
dtype=dtype,
**pme_params,
)
potential = pme(
charges=charges,
Expand Down
15 changes: 8 additions & 7 deletions examples/08-combined-potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from torchpme import CombinedPotential, EwaldCalculator, InversePowerLawPotential
from torchpme.prefactors import eV_A

dtype = torch.float64

# %%
# Combined potentials
# -------------------
Expand Down Expand Up @@ -65,10 +67,10 @@
# evaluation, and so one has to set it also for the combined potential, even if it is
# not used explicitly in the evaluation of the combination.

pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype)

potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)
potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype)

# Note also that :class:`CombinedPotential` can be used with any combination of
# potentials, as long they are all either direct or range separated. For instance, one
Expand All @@ -80,7 +82,7 @@
# We now plot of the individual and combined ``potential`` functions together with an
# explicit sum of the two potentials.

dist = torch.logspace(-3, 2, 1000)
dist = torch.logspace(-3, 2, 1000, dtype=dtype)

fig, ax = plt.subplots()

Expand Down Expand Up @@ -115,7 +117,7 @@
# combines all terms in a range-separated potential, including the k-space
# kernel.

k = torch.logspace(-2, 2, 1000)
k = torch.logspace(-2, 2, 1000, dtype=dtype)

fig, ax = plt.subplots()

Expand Down Expand Up @@ -154,9 +156,8 @@
# much bigger system.

calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype
)
calculator.to(dtype=torch.float64)


# %%
Expand Down
32 changes: 25 additions & 7 deletions examples/10-tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@
pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4}

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing),
potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype),
device=device,
dtype=dtype,
**pme_params, # type: ignore[arg-type]
)

Expand Down Expand Up @@ -168,6 +170,8 @@
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
run_backward=True,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)

Expand Down Expand Up @@ -210,15 +214,19 @@ def filter_neighbors(cutoff, neighbor_indices, neighbor_distances):
return neighbor_indices[filter_idx], neighbor_distances[filter_idx]


def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, dtype):
filter_indices, filter_distances = filter_neighbors(
cutoff, neighbor_indices, neighbor_distances
)

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing),
potential=torchpme.CoulombPotential(
smearing=smearing, device=device, dtype=dtype
),
mesh_spacing=mesh_spacing,
interpolation_nodes=interpolation_nodes,
device=device,
dtype=dtype,
)
potential = pme(
charges=charges,
Expand All @@ -239,6 +247,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
run_backward=True,
n_warmup=1,
n_repeat=4,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)
return madelung, estimated_timing
Expand All @@ -251,7 +261,9 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
bounds = np.zeros((len(smearing_grid), len(spacing_grid)))
for ism, smearing in enumerate(smearing_grid):
for isp, spacing in enumerate(spacing_grid):
results[ism, isp], timings[ism, isp] = timed_madelung(8.0, smearing, spacing, 4)
results[ism, isp], timings[ism, isp] = timed_madelung(
8.0, smearing, spacing, 4, device, dtype
)
bounds[ism, isp] = error_bounds(8.0, smearing, spacing, 4)

# %%
Expand Down Expand Up @@ -374,7 +386,7 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
for inint, nint in enumerate(nint_grid):
for isp, spacing in enumerate(spacing_grid):
results[inint, isp], timings[inint, isp] = timed_madelung(
5.0, 1.0, spacing, nint
5.0, 1.0, spacing, nint, device=device, dtype=dtype
)
bounds[inint, isp] = error_bounds(5.0, 1.0, spacing, nint)

Expand Down Expand Up @@ -445,15 +457,19 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
cutoff=5.0,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
device=device,
dtype=dtype,
)

print(f"""
print(
f"""
Estimated PME parameters (cutoff={5.0} Å):
Smearing: {smearing} Å
Mesh spacing: {parameters["mesh_spacing"]} Å
Interpolation order: {parameters["interpolation_nodes"]}
Estimated time per step: {timing} s
""")
"""
)

# %%
# What is the best cutoff?
Expand All @@ -476,6 +492,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
cutoff=cutoff,
neighbor_indices=filter_indices,
neighbor_distances=filter_distances,
device=device,
dtype=dtype,
)
timings_grid.append(timing)

Expand Down
6 changes: 4 additions & 2 deletions examples/basic-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
# contains all the necessary functions (such as those defining the short-range and
# long-range splits) for this potential and makes them useable in the rest of the code.

potential = CoulombPotential(smearing=smearing)
potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype)

# %%
#
Expand Down Expand Up @@ -193,7 +193,9 @@
# Since our structure is relatively small, we use the :class:`EwaldCalculator`.
# We start by the initialization of the class.

calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength)
calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype
)

# %%
#
Expand Down
55 changes: 36 additions & 19 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,26 @@ def _validate_parameters(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: Union[str, torch.device],
) -> None:
device = positions.device
dtype = positions.dtype
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

if isinstance(device, torch.device):
device = device.type

if positions.device.type != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)

# We use `positions.device` because it includes the device type AND index, which the
# `device` parameter may lack

# check shape, dtype and device of positions
num_atoms = len(positions)
Expand All @@ -29,14 +46,14 @@ def _validate_parameters(
f"{list(cell.shape)}"
)

if cell.dtype != dtype:
raise ValueError(
f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
if cell.dtype != positions.dtype:
raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
)

if cell.device != device:
if cell.device != positions.device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as `positions` ({device})"
f"device of `cell` ({cell.device}) must be same as the class ({device})"
)

if smearing is not None and torch.equal(
Expand All @@ -63,14 +80,14 @@ def _validate_parameters(
f"{len(positions)} atoms"
)

if charges.dtype != dtype:
raise ValueError(
f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})"
if charges.dtype != positions.dtype:
raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
)

if charges.device != device:
if charges.device != positions.device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as `positions` "
f"device of `charges` ({charges.device}) must be same as the class "
f"({device})"
)

Expand All @@ -82,10 +99,10 @@ def _validate_parameters(
"structure"
)

if neighbor_indices.device != device:
if neighbor_indices.device != positions.device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as `positions` ({device})"
f"same as the class ({device})"
)

if neighbor_distances.shape != neighbor_indices[:, 0].shape:
Expand All @@ -95,14 +112,14 @@ def _validate_parameters(
f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
)

if neighbor_distances.device != device:
if neighbor_distances.device != positions.device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as `positions` ({device})"
f"same as the class ({device})"
)

if neighbor_distances.dtype != dtype:
raise ValueError(
if neighbor_distances.dtype != positions.dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as `positions` ({dtype})"
f"as the class ({dtype})"
)
Loading

0 comments on commit bf52afa

Please sign in to comment.