Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SplinePotential device compatibility #138

Merged
merged 7 commits into from
Jan 8, 2025
Merged

SplinePotential device compatibility #138

merged 7 commits into from
Jan 8, 2025

Conversation

E-Rum
Copy link
Contributor

@E-Rum E-Rum commented Jan 7, 2025

This PR addresses issue #137. It fixes the compatibility of the SplinePotential class with GPU. Additionally, it adds a new test to verify the compatibility between the Potential class and the output device.


📚 Documentation preview 📚: https://torch-pme--138.org.readthedocs.build/en/138/

Copy link
Contributor

@PicoCentauri PicoCentauri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @E-Rum. I left some cosmetic comments and a question.

tests/test_potentials.py Outdated Show resolved Hide resolved

smearing = 1.0
exponent = 1.0
dtype = torch.float64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also parametrize over the dtype?

@@ -108,12 +115,16 @@ def __init__(
self._krn_spline = CubicSpline(k_grid**2, yhat_grid)

if y_at_zero is None:
self._y_at_zero = self._spline(torch.tensor([0.0]))
self._y_at_zero = self._spline(
torch.tensor([0.0], dtype=dtype, device=device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does not really matter, but looks cleaner to me. Feel free to ignore.

Suggested change
torch.tensor([0.0], dtype=dtype, device=device)
torch.zeros(1), dtype=dtype, device=device)

else:
self._y_at_zero = y_at_zero

if yhat_at_zero is None:
self._yhat_at_zero = self._krn_spline(torch.tensor([0.0]))
self._yhat_at_zero = self._krn_spline(
torch.tensor([0.0], dtype=dtype, device=device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
torch.tensor([0.0], dtype=dtype, device=device)
torch.zeros(1, dtype=dtype, device=device)

src/torchpme/potentials/spline.py Show resolved Hide resolved
@PicoCentauri
Copy link
Contributor

And could you update Unreleased section in the changelog (docs/src/references/changelog.rst) with a description of the fix? THANKS!

Co-authored-by: Philip Loche <philip.loche@posteo.de>
@E-Rum
Copy link
Contributor Author

E-Rum commented Jan 8, 2025

Okay, I resolved the comments left by @PicoCentauri. If everything is alright and we decide to address the remaining issues in the next PR, we can proceed with the merge.

Copy link
Contributor

@PicoCentauri PicoCentauri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @E-Rum and @ceriottm.

Merge when you want and CI is happy.

@E-Rum E-Rum merged commit 04edb22 into main Jan 8, 2025
13 checks passed
@E-Rum E-Rum deleted the splinegpu branch January 8, 2025 14:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

The current version does not correctly handle the SplinePotential class when using GPU.
3 participants