diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst index d1e86621..96407d07 100644 --- a/docs/src/references/changelog.rst +++ b/docs/src/references/changelog.rst @@ -32,6 +32,8 @@ Added Fixed ##### +* Refactor the ``InversePowerLawPotential`` class to restrict the exponent to integer + values * Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and ``Calculator`` classses * Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class diff --git a/examples/8-combined-potential.py b/examples/8-combined-potential.py index 5c196f5c..301da8f7 100644 --- a/examples/8-combined-potential.py +++ b/examples/8-combined-potential.py @@ -65,8 +65,8 @@ # evaluation, and so one has to set it also for the combined potential, even if it is # not used explicitly in the evaluation of the combination. -pot_1 = InversePowerLawPotential(exponent=1.0, smearing=smearing) -pot_2 = InversePowerLawPotential(exponent=2.0, smearing=smearing) +pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing) +pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing) potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing) diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py index d009c4d1..95898e87 100644 --- a/src/torchpme/calculators/ewald.py +++ b/src/torchpme/calculators/ewald.py @@ -138,6 +138,5 @@ def _compute_kspace( charge_tot = torch.sum(charges, dim=0) prefac = self.potential.background_correction() energy -= 2 * prefac * charge_tot * ivolume - # Compensate for double counting of pairs (i,j) and (j,i) return energy / 2 diff --git a/src/torchpme/lib/__init__.py b/src/torchpme/lib/__init__.py index aa4b8bf8..28cea33b 100644 --- a/src/torchpme/lib/__init__.py +++ b/src/torchpme/lib/__init__.py @@ -4,6 +4,7 @@ generate_kvectors_for_mesh, get_ns_mesh, ) +from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1 from .mesh_interpolator import MeshInterpolator __all__ = [ @@ -16,4 +17,8 @@ "generate_kvectors_for_ewald", "generate_kvectors_for_mesh", "get_ns_mesh", + "gamma", + "CustomExp1", + "gammaincc_over_powerlaw", + "torch_exp1", ] diff --git a/src/torchpme/lib/math.py b/src/torchpme/lib/math.py new file mode 100644 index 00000000..871abbc2 --- /dev/null +++ b/src/torchpme/lib/math.py @@ -0,0 +1,56 @@ +import torch +from scipy.special import exp1 +from torch.special import gammaln + + +def gamma(x: torch.Tensor) -> torch.Tensor: + """ + (Complete) Gamma function. + + pytorch has not implemented the commonly used (complete) Gamma function. We define + it in a custom way to make autograd work as in + https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122 + """ + return torch.exp(gammaln(x)) + + +class CustomExp1(torch.autograd.Function): + """Custom exponential integral function Exp1(x) to have an autograd-compatible version.""" + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy() + return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype) + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + return -grad_output * torch.exp(-input) / input + + +def torch_exp1(input): + """Wrapper for the custom exponential integral function.""" + return CustomExp1.apply(input) + + +def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + """Function to compute the regularized incomplete gamma function complement for integer exponents.""" + if exponent == 1: + return torch.exp(-z) / z + if exponent == 2: + return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z)) + if exponent == 3: + return torch_exp1(z) + if exponent == 4: + return 2 * ( + torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z)) + ) + if exponent == 5: + return torch.exp(-z) - z * torch_exp1(z) + if exponent == 6: + return ( + (2 - 4 * z) * torch.exp(-z) + + 4 * torch.sqrt(torch.pi * z**3) * torch.erfc(torch.sqrt(z)) + ) / 3 + raise ValueError(f"Unsupported exponent: {exponent}") diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index 8b2449ca..9abeb823 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -1,20 +1,11 @@ from typing import Optional import torch -from torch.special import gammainc, gammaincc, gammaln +from torch.special import gammainc -from .potential import Potential - - -def gamma(x: torch.Tensor) -> torch.Tensor: - """ - (Complete) Gamma function. +from torchpme.lib import gamma, gammaincc_over_powerlaw - pytorch has not implemented the commonly used (complete) Gamma function. We define - it in a custom way to make autograd work as in - https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122 - """ - return torch.exp(gammaln(x)) +from .potential import Potential class InversePowerLawPotential(Potential): @@ -46,7 +37,7 @@ class InversePowerLawPotential(Potential): def __init__( self, - exponent: float, + exponent: int, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, @@ -54,8 +45,8 @@ def __init__( ): super().__init__(smearing, exclusion_radius, dtype, device) - if exponent <= 0 or exponent > 3: - raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3") + # function call to check the validity of the exponent + gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device)) self.register_buffer( "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device) ) @@ -130,9 +121,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: # for consistency reasons. masked = torch.where(x == 0, 1.0, x) # avoid NaNs in backwards, see Coulomb return torch.where( - k_sq == 0, - 0.0, - prefac * gammaincc(peff, masked) / masked**peff * gamma(peff), + k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked) ) def self_contribution(self) -> torch.Tensor: @@ -145,7 +134,11 @@ def self_contribution(self) -> torch.Tensor: return 1 / gamma(phalf + 1) / (2 * self.smearing**2) ** phalf def background_correction(self) -> torch.Tensor: - # "charge neutrality" correction for 1/r^p potential + # "charge neutrality" correction for 1/r^p potential diverges for exponent p = 3 + # and is not needed for p > 3 , so we set it to zero (see in + # https://doi.org/10.48550/arXiv.2412.03281 SI section) + if self.exponent >= 3: + return torch.tensor(0.0, dtype=self.dtype, device=self.device) if self.smearing is None: raise ValueError( "Cannot compute background correction without specifying `smearing`." diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py index 208d937d..6b405011 100644 --- a/tests/calculators/test_values_ewald.py +++ b/tests/calculators/test_values_ewald.py @@ -100,7 +100,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): lr_wavelength = 0.5 * smearing calc = EwaldCalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smearing, ), lr_wavelength=lr_wavelength, @@ -111,7 +111,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): smearing = sr_cutoff / 5.0 calc = PMECalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smearing, ), mesh_spacing=smearing / 8, @@ -198,7 +198,7 @@ def test_wigner(crystal_name, scaling_factor): # Compute potential and compare against reference calc = EwaldCalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smeareff, ), lr_wavelength=smeareff / 2, diff --git a/tests/lib/test_math.py b/tests/lib/test_math.py new file mode 100644 index 00000000..4ec7037c --- /dev/null +++ b/tests/lib/test_math.py @@ -0,0 +1,25 @@ +import numpy as np +import torch +from scipy.special import exp1 + +from torchpme.lib import torch_exp1 + + +def finite_difference_derivative(func, x, h=1e-5): + return (func(x + h) - func(x - h)) / (2 * h) + + +def test_torch_exp1_consistency_with_scipy(): + x = torch.rand(1000, dtype=torch.float64) + torch_result = torch_exp1(x) + scipy_result = exp1(x.numpy()) + assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6) + + +def test_torch_exp1_derivative(): + x = torch.rand(1, dtype=torch.float64, requires_grad=True) + torch_result = torch_exp1(x) + torch_result.backward() + torch_exp1_prime = x.grad + finite_diff_result = finite_difference_derivative(exp1, x.detach().numpy()) + assert np.allclose(torch_exp1_prime.numpy(), finite_diff_result, atol=1e-6) diff --git a/tests/test_potentials.py b/tests/test_potentials.py index 87670b3b..08793749 100644 --- a/tests/test_potentials.py +++ b/tests/test_potentials.py @@ -30,7 +30,7 @@ def gamma(x): # Electron mass m_e = 9.1094 * 1e-31 kg # TODO: for the moment, InversePowerLawPotential only works for exponent 0<p<3 # ps = [1.0, 2.0, 3.0, 6.0] + [0.12345, 0.54321, 2.581304, 4.835909, 6.674311, 9.109431] -ps = [1.0, 2.0, 3.0] + [0.12345, 0.54321, 2.581304] +ps = [1, 2, 3] # Define range of smearing parameters covering relevant values smearinges = [0.1, 0.5, 1.0, 1.56] @@ -83,7 +83,7 @@ def test_sr_lr_split(exponent, smearing): assert_close(potential_from_dist, potential_from_sum, rtol=rtol, atol=atol) -@pytest.mark.parametrize("exponent", [1.0, 2.0, 3.0]) +@pytest.mark.parametrize("exponent", [1, 2, 3]) @pytest.mark.parametrize("smearing", smearinges) def test_exact_sr(exponent, smearing): """ @@ -103,11 +103,11 @@ def test_exact_sr(exponent, smearing): # Compute exact analytical expression obtained for relevant exponents potential_1 = erfc(dists / SQRT2 / smearing) / dists potential_2 = torch.exp(-0.5 * dists_sq / smearing**2) / dists_sq - if exponent == 1.0: + if exponent == 1: potential_exact = potential_1 - elif exponent == 2.0: + elif exponent == 2: potential_exact = potential_2 - elif exponent == 3.0: + elif exponent == 3: prefac = SQRT2 / torch.sqrt(PI) / smearing potential_exact = potential_1 / dists_sq + prefac * potential_2 @@ -117,7 +117,7 @@ def test_exact_sr(exponent, smearing): assert_close(potential_sr_from_dist, potential_exact, rtol=rtol, atol=atol) -@pytest.mark.parametrize("exponent", [1.0, 2.0, 3.0]) +@pytest.mark.parametrize("exponent", [1, 2, 3]) @pytest.mark.parametrize("smearing", smearinges) def test_exact_lr(exponent, smearing): """ @@ -137,11 +137,11 @@ def test_exact_lr(exponent, smearing): # Compute exact analytical expression obtained for relevant exponents potential_1 = erf(dists / SQRT2 / smearing) / dists potential_2 = torch.exp(-0.5 * dists_sq / smearing**2) / dists_sq - if exponent == 1.0: + if exponent == 1: potential_exact = potential_1 - elif exponent == 2.0: + elif exponent == 2: potential_exact = 1 / dists_sq - potential_2 - elif exponent == 3.0: + elif exponent == 3: prefac = SQRT2 / torch.sqrt(PI) / smearing potential_exact = potential_1 / dists_sq - prefac * potential_2 @@ -151,7 +151,7 @@ def test_exact_lr(exponent, smearing): assert_close(potential_lr_from_dist, potential_exact, rtol=rtol, atol=atol) -@pytest.mark.parametrize("exponent", [1.0, 2.0]) +@pytest.mark.parametrize("exponent", [1, 2, 3]) @pytest.mark.parametrize("smearing", smearinges) def test_exact_fourier(exponent, smearing): """ @@ -169,12 +169,12 @@ def test_exact_fourier(exponent, smearing): fourier_from_class = ipl.lr_from_k_sq(ks_sq) # Compute exact analytical expression obtained for relevant exponents - if exponent == 1.0: + if exponent == 1: fourier_exact = 4 * PI / ks_sq * torch.exp(-0.5 * smearing**2 * ks_sq) - elif exponent == 2.0: + elif exponent == 2: fourier_exact = 2 * PI**2 / ks * erfc(smearing * ks / SQRT2) - elif exponent == 3.0: - fourier_exact = -2 * PI * expi(-0.5 * smearing**2 * ks_sq) + elif exponent == 3: + fourier_exact = -2 * PI * torch.tensor(expi(-0.5 * smearing**2 * ks_sq.numpy())) # Compare results. Large tolerance due to singular division rtol = 1e-14 @@ -183,7 +183,7 @@ def test_exact_fourier(exponent, smearing): @pytest.mark.parametrize("smearing", smearinges) -@pytest.mark.parametrize("exponent", ps[:-1]) # for p=9.11, the results are unstable +@pytest.mark.parametrize("exponent", ps[:-1]) def test_lr_value_at_zero(exponent, smearing): """ The LR part of the potential should no longer have a singularity as r-->0. Instead, @@ -213,18 +213,18 @@ def test_lr_value_at_zero(exponent, smearing): def test_exponent_out_of_range(): - match = r"`exponent` p=.* has to satisfy 0 < p <= 3" + match = r"Unsupported exponent: .*" with pytest.raises(ValueError, match=match): InversePowerLawPotential(exponent=-1.0, smearing=0.0) with pytest.raises(ValueError, match=match): - InversePowerLawPotential(exponent=4, smearing=0.0) + InversePowerLawPotential(exponent=7, smearing=0.0) @pytest.mark.parametrize("potential", [CoulombPotential, InversePowerLawPotential]) def test_range_none(potential): if potential is InversePowerLawPotential: - pot = potential(exponent=2.0) + pot = potential(exponent=2) else: pot = potential() @@ -250,16 +250,16 @@ class NoImplPotential(Potential): with pytest.raises( NotImplementedError, match="from_dist is not implemented for NoImplPotential" ): - mypot.from_dist(torch.tensor([1, 2.0, 3.0])) + mypot.from_dist(torch.tensor([1, 2, 3])) with pytest.raises( NotImplementedError, match="lr_from_dist is not implemented for NoImplPotential" ): - mypot.lr_from_dist(torch.tensor([1, 2.0, 3.0])) + mypot.lr_from_dist(torch.tensor([1, 2, 3])) with pytest.raises( NotImplementedError, match="lr_from_k_sq is not implemented for NoImplPotential", ): - mypot.lr_from_k_sq(torch.tensor([1, 2.0, 3.0])) + mypot.lr_from_k_sq(torch.tensor([1, 2, 3])) with pytest.raises( NotImplementedError, match="self_contribution is not implemented for NoImplPotential", @@ -274,7 +274,7 @@ class NoImplPotential(Potential): ValueError, match="Cannot compute cutoff function when `exclusion_radius` is not set", ): - mypot.f_cutoff(torch.tensor([1, 2.0, 3.0])) + mypot.f_cutoff(torch.tensor([1, 2, 3])) @pytest.mark.parametrize("exclusion_radius", [0.5, 1.0, 2.0]) @@ -294,7 +294,7 @@ def test_inverserp_coulomb(smearing): """ # Compute LR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=1.0, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) coul = CoulombPotential(smearing=smearing, dtype=dtype) ipl_from_dist = ipl.from_dist(dists) @@ -439,8 +439,8 @@ def forward(self, x: torch.Tensor): @pytest.mark.parametrize("smearing", smearinges) def test_combined_potential(smearing): - ipl_1 = InversePowerLawPotential(exponent=1.0, smearing=smearing, dtype=dtype) - ipl_2 = InversePowerLawPotential(exponent=2.0, smearing=smearing, dtype=dtype) + ipl_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) + ipl_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype) ipl_1_from_dist = ipl_1.from_dist(dists) ipl_1_sr_from_dist = ipl_1.sr_from_dist(dists) @@ -584,7 +584,7 @@ def test_potential_device_dtype(potential_class, device, dtype): pytest.skip("CUDA is not available") smearing = 1.0 - exponent = 1.0 + exponent = 2 if potential_class is InversePowerLawPotential: potential = potential_class( @@ -604,3 +604,34 @@ def test_potential_device_dtype(potential_class, device, dtype): assert potential_lr.device.type == device assert potential_lr.dtype == dtype + + +@pytest.mark.parametrize("exponent", [4, 5, 6]) +@pytest.mark.parametrize("smearing", smearinges) +def test_inverserp_vs_spline(exponent, smearing): + """ + Compare values from InversePowerLawPotential and InversePowerLawPotentialSpline + with exponents 4, 5, 6. + """ + ks_sq_grad1 = ks_sq.clone().requires_grad_(True) + ks_sq_grad2 = ks_sq.clone().requires_grad_(True) + # Create InversePowerLawPotential + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl_fourier = ipl.lr_from_k_sq(ks_sq_grad1) + + # Create PotentialSpline + r_grid = torch.logspace(-5, 2, 1000, dtype=dtype) + y_grid = ipl.lr_from_dist(r_grid) + spline = SplinePotential(r_grid=r_grid, y_grid=y_grid, dtype=dtype) + spline_fourier = spline.lr_from_k_sq(ks_sq_grad2) + + # Test agreement between InversePowerLawPotential and SplinePotential + atol = 3e-5 + rtol = 2 * machine_epsilon + + assert_close(ipl_fourier, spline_fourier, rtol=rtol, atol=atol) + # Check that gradients are the same + atol = 1e-2 + ipl_fourier.sum().backward() + spline_fourier.sum().backward() + assert_close(ks_sq_grad1.grad, ks_sq_grad2.grad, rtol=rtol, atol=atol)