From ba2d2c8e692f7db379bef14853f44a63dcab4306 Mon Sep 17 00:00:00 2001
From: E-Rum
Date: Sat, 18 Jan 2025 15:49:59 +0000
Subject: [PATCH] Add PyTorch implementation of the exponential integral
 function and update references

---
 docs/src/references/changelog.rst          |  1 +
 src/torchpme/lib/__init__.py               |  4 +-
 src/torchpme/lib/math.py                   | 51 +++++++++++++++++++---
 src/torchpme/potentials/inversepowerlaw.py |  4 +-
 tests/lib/test_math.py                     | 11 ++---
 5 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst
index 96407d07..12736a08 100644
--- a/docs/src/references/changelog.rst
+++ b/docs/src/references/changelog.rst
@@ -27,6 +27,7 @@ changelog `_ format. This project follows
 
 Added
 #####
+* Added a PyTorch implementation of the exponential integral function
 * Added ``dtype`` and ``device`` for ``Calculator`` classes
 
 Fixed
diff --git a/src/torchpme/lib/__init__.py b/src/torchpme/lib/__init__.py
index 28cea33b..b719cb1c 100644
--- a/src/torchpme/lib/__init__.py
+++ b/src/torchpme/lib/__init__.py
@@ -4,7 +4,7 @@
     generate_kvectors_for_mesh,
     get_ns_mesh,
 )
-from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
+from .math import CustomExp1, exp1, gamma, gammaincc_over_powerlaw
 from .mesh_interpolator import MeshInterpolator
 
 __all__ = [
@@ -20,5 +20,5 @@
     "gamma",
     "CustomExp1",
     "gammaincc_over_powerlaw",
-    "torch_exp1",
+    "exp1",
 ]
diff --git a/src/torchpme/lib/math.py b/src/torchpme/lib/math.py
index 871abbc2..886f84d6 100644
--- a/src/torchpme/lib/math.py
+++ b/src/torchpme/lib/math.py
@@ -1,5 +1,4 @@
 import torch
-from scipy.special import exp1
 from torch.special import gammaln
 
 
@@ -15,13 +14,51 @@ def gamma(x: torch.Tensor) -> torch.Tensor:
 
 
 class CustomExp1(torch.autograd.Function):
-    """Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
+    """
+    Compute the exponential integral E1(x) for x > 0.
+
+    :param input: Input tensor (x > 0)
+    :return: Exponential integral E1(x)
+    """
 
     @staticmethod
     def forward(ctx, input):
         ctx.save_for_backward(input)
-        input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
-        return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
+
+        # Constants
+        SCIPY_EULER = (
+            0.577215664901532860606512090082402431  # Euler-Mascheroni constant
+        )
+        inf = torch.inf
+
+        # E1(x) diverges as x -> 0+, so start from +inf and only fill x > 0 below
+        result = torch.full_like(input, inf)
+        mask = input > 0
+
+        # Power-series expansion for 0 < x <= 1
+        x_small = input[mask & (input <= 1)]
+        if x_small.numel() > 0:
+            e1 = torch.ones_like(x_small)
+            r = torch.ones_like(x_small)
+            for k in range(1, 26):
+                r = -r * k * x_small / (k + 1.0) ** 2
+                e1 += r
+                if torch.all(torch.abs(r) <= torch.abs(e1) * 1e-15):
+                    break
+            result[mask & (input <= 1)] = (
+                -SCIPY_EULER - torch.log(x_small) + x_small * e1
+            )
+
+        # Continued-fraction evaluation for x > 1
+        x_large = input[mask & (input > 1)]
+        if x_large.numel() > 0:
+            m = 20 + (80.0 / x_large).to(torch.int32)
+            t0 = torch.zeros_like(x_large)
+            for k in range(m.max(), 0, -1):
+                t0 = k / (1.0 + k / (x_large + t0))
+            t = 1.0 / (x_large + t0)
+            result[mask & (input > 1)] = torch.exp(-x_large) * t
+
+        return result
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -29,7 +66,7 @@ def backward(ctx, grad_output):
         return -grad_output * torch.exp(-input) / input
 
 
-def torch_exp1(input):
+def exp1(input):
     """Wrapper for the custom exponential integral function."""
     return CustomExp1.apply(input)
 
@@ -41,13 +78,13 @@ def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
     if exponent == 2:
         return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
     if exponent == 3:
-        return torch_exp1(z)
+        return exp1(z)
     if exponent == 4:
         return 2 * (
             torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
         )
     if exponent == 5:
-        return torch.exp(-z) - z * torch_exp1(z)
+        return torch.exp(-z) - z * exp1(z)
     if exponent == 6:
         return (
             (2 - 4 * z) * torch.exp(-z)
diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py
index 9abeb823..35ff7ac7 100644
--- a/src/torchpme/potentials/inversepowerlaw.py
+++ b/src/torchpme/potentials/inversepowerlaw.py
@@ -137,12 +137,12 @@ def background_correction(self) -> torch.Tensor:
         # "charge neutrality" correction for the 1/r^p potential diverges for
         # exponent p = 3 and is not needed for p > 3, so we set it to zero (see
         # https://doi.org/10.48550/arXiv.2412.03281 SI section)
         if self.smearing is None:
             raise ValueError(
                 "Cannot compute background correction without specifying `smearing`."
             )
+        if self.exponent >= 3:
+            return self.smearing * 0.0
         prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2)
         prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
         return prefac
diff --git a/tests/lib/test_math.py b/tests/lib/test_math.py
index 4ec7037c..b2c8cfb1 100644
--- a/tests/lib/test_math.py
+++ b/tests/lib/test_math.py
@@ -2,7 +2,7 @@
 import torch
 from scipy.special import exp1
 
-from torchpme.lib import torch_exp1
+from torchpme.lib import exp1 as torch_exp1
 
 
 def finite_difference_derivative(func, x, h=1e-5):
@@ -10,10 +10,11 @@ def finite_difference_derivative(func, x, h=1e-5):
 
 
 def test_torch_exp1_consistency_with_scipy():
-    x = torch.rand(1000, dtype=torch.float64)
-    torch_result = torch_exp1(x)
-    scipy_result = exp1(x.numpy())
-    assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
+    random_tensor = torch.FloatTensor(100000).uniform_(0, 1000)
+    random_array = random_tensor.numpy()
+    scipy_result = exp1(random_array)
+    torch_result = torch_exp1(random_tensor)
+    assert np.allclose(scipy_result, torch_result.numpy(), rtol=1e-5, atol=1e-15)
 
 
 def test_torch_exp1_derivative():
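
Reviewer note: below is a minimal, self-contained sketch of how the new
autograd-compatible `exp1` is meant to be used, checking the gradient against
the analytic derivative d/dx E1(x) = -exp(-x)/x that `CustomExp1.backward`
implements. The tensor values are arbitrary illustrative inputs, not part of
the test suite:

    import torch

    from torchpme.lib import exp1

    # E1 on positive inputs; float64 keeps the comparison tight
    x = torch.tensor([0.5, 1.0, 5.0], dtype=torch.float64, requires_grad=True)

    # Differentiate a scalar reduction of E1(x)
    exp1(x).sum().backward()

    # Analytic gradient of E1: -exp(-x) / x
    expected = -torch.exp(-x.detach()) / x.detach()
    print(torch.allclose(x.grad, expected))  # True

Since `backward` evaluates exactly this closed form, the comparison holds to
machine precision; a finite-difference check (as in `test_torch_exp1_derivative`)
would only agree to roughly the step size of the difference scheme.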