Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device and dtype not being specified in the __init__ of P3MCalculator #159

Merged
merged 5 commits into from
Jan 29, 2025

Conversation

GardevoirX
Copy link
Contributor

@GardevoirX GardevoirX commented Jan 27, 2025

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?

📚 Documentation preview 📚: https://torch-pme--159.org.readthedocs.build/en/159/

Copy link
Contributor

@PicoCentauri PicoCentauri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Does this break anything if you don't pass it?

Maybe we should add a test to prevent this...

@GardevoirX
Copy link
Contributor Author

GardevoirX commented Jan 27, 2025

Thanks! Does this break anything if you don't pass it?

Maybe we should add a test to prevent this...

Yes I was doing some plots yesterday (on cuda), and it broke... I guess this is mainly because we only do our tests on cpu? which is also the default device of the device-not-specified tensors, so current tests cannot catch these potential issues

@PicoCentauri
Copy link
Contributor

If you run the test suite on a machine with the CUDA device they also don't fail? I though we added a parametrize over CUDA if its available.

@GardevoirX
Copy link
Contributor Author

Yes true, I tried to run the tunning tests on cuda and it worked well. You can try to reproduce the error with the following code. It fails in the main branch but works in this branch

import sys

sys.path.append("/home/qxu/repos/torch-pme/tests")
from helpers import define_crystal
import ase
import torch
import vesin.torch
from torchpme.tuning import tune_p3m

device = "cuda"
dtype = torch.float64
CUTOFF = 4.4
positions, charges, cell, madelung_ref, num_formula_units = define_crystal()

atoms_unitcell = ase.Atoms(
    symbols=len(positions) * ["H"],
    positions=positions,
    charges=charges.flatten(),
    pbc=True,
    cell=cell,
)
rep = [8, 8, 8]
atoms = atoms_unitcell.repeat(rep)
positions = torch.tensor(atoms.positions, dtype=dtype, device=device)
charges = torch.tensor(
    atoms.get_initial_charges(), dtype=dtype, device=device
).unsqueeze(1)
cell = torch.tensor(atoms.cell.array, dtype=dtype, device=device)

nl = vesin.torch.NeighborList(cutoff=CUTOFF, full_list=False)
i, j, neighbor_distances = nl.compute(
    points=positions.to(dtype=torch.float64, device="cpu"),
    box=cell.to(dtype=torch.float64, device="cpu"),
    periodic=True,
    quantities="ijd",
)
neighbor_indices = torch.stack([i, j], dim=1).to(device=device)
neighbor_distances = neighbor_distances.to(dtype=dtype, device=device)

smearing, params, _ = tune_p3m(
    charges=charges,
    cell=cell,
    positions=positions,
    cutoff=CUTOFF,
    neighbor_indices=neighbor_indices,
    neighbor_distances=neighbor_distances,
    dtype=dtype,
    device=device,
    accuracy=1e-7,
)

@PicoCentauri
Copy link
Contributor

Okay thanks. I will just add a loop over devices in the test_tuning.py. That should test for it.

@PicoCentauri
Copy link
Contributor

Okay turns out we didnt't really test consistent dtype and device at all.

I added those tests.

@GardevoirX can you maybe check if the changes you made are covered by the tests. Basically commenting your changes out on a CUDA machine and run the updated testsuite from this PR.

@PicoCentauri PicoCentauri requested a review from E-Rum January 28, 2025 15:20
@GardevoirX
Copy link
Contributor Author

Cool, these tests can catch the error

Copy link
Contributor

@E-Rum E-Rum left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks very good! Thank you very much for the enormous work on cleaning up the front! Could you also confirm if everything works correctly when we simultaneously pass device as a string for the calculator and as a torch.device for the potential, and vice versa?

Comment on lines 48 to 50
f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also change the ValueError/TypeError message for cell, charges, neighbor_indices, and neighbor_distances, stating that they should have the same dtype and device as specified, but not the same as positions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I can change the error. Message and add a comment why we still compare against the positions.dtype.

@PicoCentauri
Copy link
Contributor

Thanks for your reviews. I updated the test message and moved the dtype to the very beginning of the examples. As we do in all examples.

@E-Rum E-Rum merged commit bf52afa into main Jan 29, 2025
13 checks passed
@E-Rum E-Rum deleted the fix-p3m branch January 29, 2025 12:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants