Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device and dtype not being specified in the __init__ of P3MCalculator #159

Merged
merged 5 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/references/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
Added
#####

* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Better documentation for for ``cell``, ``charges`` and ``positions`` parameters
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.

Fixed
#####

* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``

`Version 0.2.0 <https://github.com/lab-cosmo/torch-pme/releases/tag/v0.2.0>`_ - 2025-01-23
------------------------------------------------------------------------------------------

Expand Down
18 changes: 10 additions & 8 deletions examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@
import torchpme
from torchpme.tuning import tune_pme

dtype = torch.float64

# %%
#
# Create the properties CsCl unit cell

symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)
charges = torch.tensor([[1.0], [-1.0]], dtype=dtype)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=dtype)
cell = torch.eye(3, dtype=dtype)
pbc = torch.tensor([True, True, True])


Expand All @@ -72,6 +73,7 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -101,7 +103,7 @@
# will be used to *compute* the potential energy of the system.

calculator = torchpme.PMECalculator(
torchpme.CoulombPotential(smearing=smearing), **pme_params
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
)

# %%
Expand All @@ -112,7 +114,7 @@
# As a first application of multiple charge channels, we start simply by using the
# classic definition of one charge channel per atom.

charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
charges = torch.tensor([[1.0], [-1.0]], dtype=dtype)

# %%
#
Expand Down Expand Up @@ -160,7 +162,7 @@
# species-specific potentials and facilitating the learning process for machine learning
# algorithms.

charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
charges_one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=dtype)

# %%
#
Expand Down Expand Up @@ -205,7 +207,7 @@
# creating a new calculator with the metatensor interface.

calculator_metatensor = torchpme.metatensor.PMECalculator(
torchpme.CoulombPotential(smearing=smearing), **pme_params
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
)

# %%
Expand Down
8 changes: 6 additions & 2 deletions examples/02-neighbor-lists-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#
# As a test system, we use a 2x2x2 supercell of an CsCl crystal in a cubic cell.

dtype = torch.float64
atoms_unitcell = ase.Atoms(
symbols=["Cs", "Cl"],
positions=np.array([(0, 0, 0), (0.5, 0.5, 0.5)]),
Expand Down Expand Up @@ -97,7 +98,7 @@
nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
neighbor_indices, neighbor_distances = nl.compute(
points=positions.to(dtype=torch.float64, device="cpu"),
box=cell.to(dtype=torch.float64, device="cpu"),
box=cell.to(dtype=dtype, device="cpu"),
periodic=True,
quantities="Pd",
)
Expand All @@ -109,6 +110,7 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -193,7 +195,9 @@ def distances(
# compute the potential.

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing), **pme_params
potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype),
dtype=dtype,
**pme_params,
)
potential = pme(
charges=charges,
Expand Down
15 changes: 8 additions & 7 deletions examples/08-combined-potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from torchpme import CombinedPotential, EwaldCalculator, InversePowerLawPotential
from torchpme.prefactors import eV_A

dtype = torch.float64

# %%
# Combined potentials
# -------------------
Expand Down Expand Up @@ -65,10 +67,10 @@
# evaluation, and so one has to set it also for the combined potential, even if it is
# not used explicitly in the evaluation of the combination.

pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype)

potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)
potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype)

# Note also that :class:`CombinedPotential` can be used with any combination of
# potentials, as long they are all either direct or range separated. For instance, one
Expand All @@ -80,7 +82,7 @@
# We now plot of the individual and combined ``potential`` functions together with an
# explicit sum of the two potentials.

dist = torch.logspace(-3, 2, 1000)
dist = torch.logspace(-3, 2, 1000, dtype=dtype)

fig, ax = plt.subplots()

Expand Down Expand Up @@ -115,7 +117,7 @@
# combines all terms in a range-separated potential, including the k-space
# kernel.

k = torch.logspace(-2, 2, 1000)
k = torch.logspace(-2, 2, 1000, dtype=dtype)

fig, ax = plt.subplots()

Expand Down Expand Up @@ -154,9 +156,8 @@
# much bigger system.

calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype
)
calculator.to(dtype=torch.float64)


# %%
Expand Down
32 changes: 25 additions & 7 deletions examples/10-tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@
pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4}

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing),
potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype),
device=device,
dtype=dtype,
**pme_params, # type: ignore[arg-type]
)

Expand Down Expand Up @@ -168,6 +170,8 @@
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
run_backward=True,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)

Expand Down Expand Up @@ -210,15 +214,19 @@ def filter_neighbors(cutoff, neighbor_indices, neighbor_distances):
return neighbor_indices[filter_idx], neighbor_distances[filter_idx]


def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, dtype):
filter_indices, filter_distances = filter_neighbors(
cutoff, neighbor_indices, neighbor_distances
)

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing),
potential=torchpme.CoulombPotential(
smearing=smearing, device=device, dtype=dtype
),
mesh_spacing=mesh_spacing,
interpolation_nodes=interpolation_nodes,
device=device,
dtype=dtype,
)
potential = pme(
charges=charges,
Expand All @@ -239,6 +247,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
run_backward=True,
n_warmup=1,
n_repeat=4,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)
return madelung, estimated_timing
Expand All @@ -251,7 +261,9 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
bounds = np.zeros((len(smearing_grid), len(spacing_grid)))
for ism, smearing in enumerate(smearing_grid):
for isp, spacing in enumerate(spacing_grid):
results[ism, isp], timings[ism, isp] = timed_madelung(8.0, smearing, spacing, 4)
results[ism, isp], timings[ism, isp] = timed_madelung(
8.0, smearing, spacing, 4, device, dtype
)
bounds[ism, isp] = error_bounds(8.0, smearing, spacing, 4)

# %%
Expand Down Expand Up @@ -374,7 +386,7 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
for inint, nint in enumerate(nint_grid):
for isp, spacing in enumerate(spacing_grid):
results[inint, isp], timings[inint, isp] = timed_madelung(
5.0, 1.0, spacing, nint
5.0, 1.0, spacing, nint, device=device, dtype=dtype
)
bounds[inint, isp] = error_bounds(5.0, 1.0, spacing, nint)

Expand Down Expand Up @@ -445,15 +457,19 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
cutoff=5.0,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
device=device,
dtype=dtype,
)

print(f"""
print(
f"""
Estimated PME parameters (cutoff={5.0} Å):
Smearing: {smearing} Å
Mesh spacing: {parameters["mesh_spacing"]} Å
Interpolation order: {parameters["interpolation_nodes"]}
Estimated time per step: {timing} s
""")
"""
)

# %%
# What is the best cutoff?
Expand All @@ -476,6 +492,8 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
cutoff=cutoff,
neighbor_indices=filter_indices,
neighbor_distances=filter_distances,
device=device,
dtype=dtype,
)
timings_grid.append(timing)

Expand Down
6 changes: 4 additions & 2 deletions examples/basic-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
# contains all the necessary functions (such as those defining the short-range and
# long-range splits) for this potential and makes them useable in the rest of the code.

potential = CoulombPotential(smearing=smearing)
potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype)

# %%
#
Expand Down Expand Up @@ -193,7 +193,9 @@
# Since our structure is relatively small, we use the :class:`EwaldCalculator`.
# We start by the initialization of the class.

calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength)
calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype
)

# %%
#
Expand Down
55 changes: 36 additions & 19 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,26 @@ def _validate_parameters(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: Union[str, torch.device],
) -> None:
device = positions.device
dtype = positions.dtype
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

if isinstance(device, torch.device):
device = device.type

if positions.device.type != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)

# We use `positions.device` because it includes the device type AND index, which the
# `device` parameter may lack

# check shape, dtype and device of positions
num_atoms = len(positions)
Expand All @@ -29,14 +46,14 @@ def _validate_parameters(
f"{list(cell.shape)}"
)

if cell.dtype != dtype:
raise ValueError(
f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
if cell.dtype != positions.dtype:
raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
)

if cell.device != device:
if cell.device != positions.device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as `positions` ({device})"
f"device of `cell` ({cell.device}) must be same as the class ({device})"
)

if smearing is not None and torch.equal(
Expand All @@ -63,14 +80,14 @@ def _validate_parameters(
f"{len(positions)} atoms"
)

if charges.dtype != dtype:
raise ValueError(
f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})"
if charges.dtype != positions.dtype:
raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
)

if charges.device != device:
if charges.device != positions.device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as `positions` "
f"device of `charges` ({charges.device}) must be same as the class "
f"({device})"
)

Expand All @@ -82,10 +99,10 @@ def _validate_parameters(
"structure"
)

if neighbor_indices.device != device:
if neighbor_indices.device != positions.device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as `positions` ({device})"
f"same as the class ({device})"
)

if neighbor_distances.shape != neighbor_indices[:, 0].shape:
Expand All @@ -95,14 +112,14 @@ def _validate_parameters(
f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
)

if neighbor_distances.device != device:
if neighbor_distances.device != positions.device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as `positions` ({device})"
f"same as the class ({device})"
)

if neighbor_distances.dtype != dtype:
raise ValueError(
if neighbor_distances.dtype != positions.dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as `positions` ({dtype})"
f"as the class ({dtype})"
)
Loading