Commit

Add PyTorch implementation of the exponential integral function and update references
E-Rum committed Jan 18, 2025
1 parent 3540a81 commit ba2d2c8
Showing 5 changed files with 55 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/src/references/changelog.rst
@@ -27,6 +27,7 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
 Added
 #####
 
+* Added a PyTorch implementation of the exponential integral function
 * Added ``dtype`` and ``device`` for ``Calculator`` classses
 
 Fixed
4 changes: 2 additions & 2 deletions src/torchpme/lib/__init__.py
@@ -4,7 +4,7 @@
     generate_kvectors_for_mesh,
     get_ns_mesh,
 )
-from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
+from .math import CustomExp1, exp1, gamma, gammaincc_over_powerlaw
 from .mesh_interpolator import MeshInterpolator
 
 __all__ = [
@@ -20,5 +20,5 @@
     "gamma",
     "CustomExp1",
     "gammaincc_over_powerlaw",
-    "torch_exp1",
+    "exp1",
 ]
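
With this change the function is exported under the public name ``exp1`` instead of ``torch_exp1``. A minimal usage sketch (assuming torchpme is installed):

    import torch

    from torchpme.lib import exp1  # formerly torch_exp1

    print(exp1(torch.tensor([1.0])))  # E1(1) ≈ 0.2194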
51 changes: 44 additions & 7 deletions src/torchpme/lib/math.py
@@ -1,5 +1,4 @@
 import torch
-from scipy.special import exp1
 from torch.special import gammaln
 
 
@@ -15,21 +14,59 @@ def gamma(x: torch.Tensor) -> torch.Tensor:
 
 
 class CustomExp1(torch.autograd.Function):
-    """Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
+    """
+    Compute the exponential integral E1(x) for x > 0.
+    :param input: Input tensor (x > 0)
+    :return: Exponential integral E1(x)
+    """
 
     @staticmethod
     def forward(ctx, input):
         ctx.save_for_backward(input)
-        input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
-        return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
+
+        # Constants
+        SCIPY_EULER = (
+            0.577215664901532860606512090082402431  # Euler-Mascheroni constant
+        )
+        inf = torch.inf
+
+        # Handle case when x == 0
+        result = torch.full_like(input, inf)
+        mask = input > 0
+
+        # Compute for x <= 1
+        x_small = input[mask & (input <= 1)]
+        if x_small.numel() > 0:
+            e1 = torch.ones_like(x_small)
+            r = torch.ones_like(x_small)
+            for k in range(1, 26):
+                r = -r * k * x_small / (k + 1.0) ** 2
+                e1 += r
+                if torch.all(torch.abs(r) <= torch.abs(e1) * 1e-15):
+                    break
+            result[mask & (input <= 1)] = (
+                -SCIPY_EULER - torch.log(x_small) + x_small * e1
+            )
+
+        # Compute for x > 1
+        x_large = input[mask & (input > 1)]
+        if x_large.numel() > 0:
+            m = 20 + (80.0 / x_large).to(torch.int32)
+            t0 = torch.zeros_like(x_large)
+            for k in range(m.max(), 0, -1):
+                t0 = k / (1.0 + k / (x_large + t0))
+            t = 1.0 / (x_large + t0)
+            result[mask & (input > 1)] = torch.exp(-x_large) * t
+
+        return result
 
     @staticmethod
     def backward(ctx, grad_output):
         (input,) = ctx.saved_tensors
         return -grad_output * torch.exp(-input) / input
 
 
-def torch_exp1(input):
+def exp1(input):
     """Wrapper for the custom exponential integral function."""
     return CustomExp1.apply(input)
 
@@ -41,13 +78,13 @@ def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
     if exponent == 2:
         return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
     if exponent == 3:
-        return torch_exp1(z)
+        return exp1(z)
     if exponent == 4:
         return 2 * (
             torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
         )
     if exponent == 5:
-        return torch.exp(-z) - z * torch_exp1(z)
+        return torch.exp(-z) - z * exp1(z)
     if exponent == 6:
         return (
             (2 - 4 * z) * torch.exp(-z)
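
The new forward pass evaluates E1 in two regimes: for 0 < x <= 1 it sums the power series E1(x) = -γ - ln(x) + Σ_{k≥1} (-1)^(k-1) x^k / (k · k!), with γ the Euler-Mascheroni constant, and for x > 1 it evaluates a continued-fraction expansion by backward recurrence. The backward pass uses the closed form d/dx E1(x) = -exp(-x)/x. A quick sanity-check sketch of both the values and the gradient (assuming torchpme is installed):

    import torch
    from scipy.special import exp1 as scipy_exp1

    from torchpme.lib import exp1

    x = torch.tensor([0.5, 1.0, 5.0], dtype=torch.float64, requires_grad=True)
    y = exp1(x)

    # Values should match SciPy's reference implementation.
    print(y.detach().numpy() - scipy_exp1(x.detach().numpy()))

    # The gradient should match the closed form -exp(-x)/x.
    y.sum().backward()
    print(x.grad + torch.exp(-x.detach()) / x.detach())  # ~0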
4 changes: 2 additions & 2 deletions src/torchpme/potentials/inversepowerlaw.py
@@ -137,12 +137,12 @@ def background_correction(self) -> torch.Tensor:
# "charge neutrality" correction for 1/r^p potential diverges for exponent p = 3
# and is not needed for p > 3 , so we set it to zero (see in
# https://doi.org/10.48550/arXiv.2412.03281 SI section)
if self.exponent >= 3:
return torch.tensor(0.0, dtype=self.dtype, device=self.device)
if self.smearing is None:
raise ValueError(
"Cannot compute background correction without specifying `smearing`."
)
if self.exponent >= 3:
return self.smearing * 0.0
prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2)
prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
return prefac
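
Moving the exponent check ahead of the smearing check means that for p >= 3 the correction is now an exact zero tensor with the calculator's dtype and device, and no longer requires ``smearing`` to be set. For p < 3 the prefactor above still applies; for the Coulomb case p = 1 it reduces to π·σ². A standalone check of that reduction (a sketch; ``background_prefac`` is a hypothetical helper mirroring the expression above):

    import math

    def background_prefac(p: float, smearing: float) -> float:
        # pi**1.5 * (2 * sigma**2)**((3 - p) / 2) / ((3 - p) * Gamma(p / 2))
        return (
            math.pi**1.5
            * (2 * smearing**2) ** ((3 - p) / 2)
            / ((3 - p) * math.gamma(p / 2))
        )

    # For p = 1: pi**1.5 * 2 * sigma**2 / (2 * sqrt(pi)) = pi * sigma**2
    assert abs(background_prefac(1.0, 0.5) - math.pi * 0.5**2) < 1e-12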
11 changes: 6 additions & 5 deletions tests/lib/test_math.py
@@ -2,18 +2,19 @@
 import torch
 from scipy.special import exp1
 
-from torchpme.lib import torch_exp1
+from torchpme.lib import exp1 as torch_exp1
 
 
 def finite_difference_derivative(func, x, h=1e-5):
     return (func(x + h) - func(x - h)) / (2 * h)
 
 
 def test_torch_exp1_consistency_with_scipy():
-    x = torch.rand(1000, dtype=torch.float64)
-    torch_result = torch_exp1(x)
-    scipy_result = exp1(x.numpy())
-    assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
+    random_tensor = torch.FloatTensor(100000).uniform_(0, 1000)
+    random_array = random_tensor.numpy()
+    scipy_result = exp1(random_array)
+    torch_result = torch_exp1(random_tensor)
+    assert np.allclose(scipy_result, torch_result.numpy(), atol=1e-15)
 
 
 def test_torch_exp1_derivative():
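
The body of ``test_torch_exp1_derivative`` is collapsed above. A minimal version of such a check, reusing the ``finite_difference_derivative`` helper, might look like this (hypothetical sketch, not the actual test):

    def test_torch_exp1_derivative():
        x = torch.rand(1000, dtype=torch.float64) + 0.1  # stay away from the x = 0 pole
        x.requires_grad_(True)
        torch_exp1(x).sum().backward()
        numerical = finite_difference_derivative(torch_exp1, x.detach())
        assert torch.allclose(x.grad, numerical, atol=1e-6)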
