Fix `device` and `dtype` not being specified in the `__init__` of `P3MCalculator` #159
Conversation
Thanks! Does this break anything if you don't pass it?
Maybe we should add a test to prevent this...
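For context, a minimal sketch of the kind of fix this PR makes, assuming a base class that stores `dtype` and `device`; the class and parameter names below are illustrative, not torch-pme's exact API:

```python
import torch


class BaseCalculator:
    """Stand-in base class that stores the requested dtype and device."""

    def __init__(self, dtype=None, device=None):
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.device = torch.device(device) if device is not None else torch.device("cpu")


class P3MLikeCalculator(BaseCalculator):
    def __init__(self, mesh_spacing: float, dtype=None, device=None):
        # The bug pattern: calling super().__init__() without forwarding
        # dtype/device silently falls back to the defaults (float32 on CPU).
        super().__init__(dtype=dtype, device=device)
        self.mesh_spacing = mesh_spacing
```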
Yes, I was doing some plots yesterday (on CUDA), and it broke... I guess this is mainly because we only run our tests on CPU, which is also the default device of tensors created without an explicit device, so the current tests cannot catch these issues.
If you run the test suite on a machine with a CUDA device, do the tests also not fail? I thought we added a parametrize over CUDA if it's available.
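As a side note, a minimal sketch (assuming pytest) of parametrizing a test over the available devices, so that CPU-only defaults no longer mask a dropped `device` argument:

```python
import pytest
import torch

DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


@pytest.mark.parametrize("device", DEVICES)
def test_tensor_lands_on_requested_device(device):
    # Fails if `device` is dropped somewhere and the tensor defaults to CPU.
    x = torch.zeros(3, device=device)
    assert x.device.type == device
```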
Yes, true. I tried to run the tuning tests on CUDA and it worked well. You can try to reproduce the error with the following code. It fails on the main branch but works in this branch:

```python
import sys

sys.path.append("/home/qxu/repos/torch-pme/tests")
from helpers import define_crystal

import ase
import torch
import vesin.torch

from torchpme.tuning import tune_p3m

device = "cuda"
dtype = torch.float64
CUTOFF = 4.4

# Build a supercell of a test crystal from the test helpers.
positions, charges, cell, madelung_ref, num_formula_units = define_crystal()
atoms_unitcell = ase.Atoms(
    symbols=len(positions) * ["H"],
    positions=positions,
    charges=charges.flatten(),
    pbc=True,
    cell=cell,
)
rep = [8, 8, 8]
atoms = atoms_unitcell.repeat(rep)

positions = torch.tensor(atoms.positions, dtype=dtype, device=device)
charges = torch.tensor(
    atoms.get_initial_charges(), dtype=dtype, device=device
).unsqueeze(1)
cell = torch.tensor(atoms.cell.array, dtype=dtype, device=device)

# Compute the neighbor list on CPU in double precision, then move the
# results to the target device.
nl = vesin.torch.NeighborList(cutoff=CUTOFF, full_list=False)
i, j, neighbor_distances = nl.compute(
    points=positions.to(dtype=torch.float64, device="cpu"),
    box=cell.to(dtype=torch.float64, device="cpu"),
    periodic=True,
    quantities="ijd",
)
neighbor_indices = torch.stack([i, j], dim=1).to(device=device)
neighbor_distances = neighbor_distances.to(dtype=dtype, device=device)

smearing, params, _ = tune_p3m(
    charges=charges,
    cell=cell,
    positions=positions,
    cutoff=CUTOFF,
    neighbor_indices=neighbor_indices,
    neighbor_distances=neighbor_distances,
    dtype=dtype,
    device=device,
    accuracy=1e-7,
)
```
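Note that the reproduction computes the neighbor list on CPU in double precision and only afterwards moves the indices and distances to the target device, so all inputs to `tune_p3m` are explicitly on CUDA; any device mismatch that surfaces therefore comes from tensors created inside the tuning/calculator code, not from the inputs.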
Okay, thanks. I will just add a loop over devices in the …
Okay, turns out we didn't really test this consistently. I added those tests. @GardevoirX, can you maybe check if the changes you made are covered by the tests? Basically, comment your changes out on a CUDA machine and run the updated test suite from this PR.
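A hedged sketch of what such a consistency test might look like; the toy calculator below is a stand-in for the real classes, whose constructors differ:

```python
import pytest
import torch

DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class ToyCalculator(torch.nn.Module):
    """Stand-in for a calculator that should forward dtype/device to its buffers."""

    def __init__(self, dtype=torch.float32, device="cpu"):
        super().__init__()
        self.register_buffer("kvectors", torch.zeros(3, dtype=dtype, device=device))


@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_dtype_and_device_are_forwarded(device, dtype):
    calc = ToyCalculator(dtype=dtype, device=device)
    assert calc.kvectors.device.type == device
    assert calc.kvectors.dtype == dtype
```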
Fix again
Cool, these tests can catch the error.
Looks very good! Thank you very much for the enormous cleanup work! Could you also confirm that everything works correctly when we simultaneously pass `device` as a string for the calculator and as a `torch.device` for the potential, and vice versa?
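One way to see why mixing the two spellings can be made safe is to normalize both through `torch.device()`, which makes them comparable; this is a sketch of the idea, not necessarily what torch-pme does internally:

```python
import torch


def normalize_device(device):
    """Accept either a string like "cpu" or a torch.device; return a torch.device."""
    return torch.device(device)


# Both spellings normalize to the same object, so they compare equal.
assert normalize_device("cpu") == normalize_device(torch.device("cpu"))
```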
src/torchpme/_utils.py (outdated)

```python
        f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
    )
```
Should we also change the `ValueError`/`TypeError` message for `cell`, `charges`, `neighbor_indices`, and `neighbor_distances`, stating that they should have the same `dtype` and `device` as specified, but not the same as `positions`?
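For illustration, a hedged sketch of a check whose message names the expected `dtype`/`device` directly instead of phrasing it relative to `positions`; the helper name and signature are hypothetical:

```python
import torch


def _check_tensor(name, tensor, expected_dtype, expected_device):
    # Hypothetical helper: report the expected values directly in the message.
    if tensor.dtype != expected_dtype:
        raise TypeError(
            f"type of `{name}` ({tensor.dtype}) must be {expected_dtype}"
        )
    # Compare device types so "cuda" matches a tensor living on "cuda:0".
    if tensor.device.type != torch.device(expected_device).type:
        raise ValueError(
            f"device of `{name}` ({tensor.device}) must be {expected_device}"
        )
```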
Yes, I can change the error message and add a comment explaining why we still compare against `positions.dtype`.
Thanks for your reviews. I updated the test message and moved the …
📚 Documentation preview 📚: https://torch-pme--159.org.readthedocs.build/en/159/