Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Feb 9, 2025
1 parent 973c9e8 commit c763a09
Show file tree
Hide file tree
Showing 24 changed files with 114 additions and 206 deletions.
3 changes: 1 addition & 2 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union

import torch

Expand All @@ -11,7 +11,6 @@ def _validate_parameters(
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
) -> None:

dtype = positions.dtype
device = positions.device

Expand Down
2 changes: 0 additions & 2 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional, Union

import torch
from torch import profiler

Expand Down
2 changes: 0 additions & 2 deletions src/torchpme/calculators/ewald.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional, Union

import torch

from ..lib import generate_kvectors_for_ewald
Expand Down
8 changes: 5 additions & 3 deletions src/torchpme/calculators/p3m.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional, Union

import torch

from ..lib.kspace_filter import P3MKSpaceFilter
Expand Down Expand Up @@ -66,7 +64,11 @@ def __init__(
prefactor=prefactor,
)

cell = torch.eye(3, device=self.potential.smearing.device, dtype=self.potential.smearing.dtype)
cell = torch.eye(
3,
device=self.potential.smearing.device,
dtype=self.potential.smearing.dtype,
)
ns_mesh = torch.ones(3, dtype=int, device=cell.device)

self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(
Expand Down
8 changes: 5 additions & 3 deletions src/torchpme/calculators/pme.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional, Union

import torch
from torch import profiler

Expand Down Expand Up @@ -70,7 +68,11 @@ def __init__(

self.mesh_spacing: float = mesh_spacing

cell = torch.eye(3, device=self.potential.smearing.device, dtype=self.potential.smearing.dtype)
cell = torch.eye(
3,
device=self.potential.smearing.device,
dtype=self.potential.smearing.dtype,
)
ns_mesh = torch.ones(3, dtype=int, device=cell.device)

self.kspace_filter: KSpaceFilter = KSpaceFilter(
Expand Down
14 changes: 2 additions & 12 deletions src/torchpme/lib/splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,30 +151,25 @@ def _solve_tridiagonal(a, b, c, d):
def compute_second_derivatives(
x_points: torch.Tensor,
y_points: torch.Tensor,
high_precision: Optional[bool] = True,
):
"""
Computes second derivatives given the grid points of a cubic spline.
:param x_points: Abscissas of the splining points for the real-space function
:param y_points: Ordinates of the splining points for the real-space function
:param high_accuracy: bool, perform calculation in double precision
:return: The second derivatives for the spline points
"""
# Do the calculation in float64 if required
x = x_points
y = y_points
if high_precision:
x = x.to(dtype=torch.float64)
y = y.to(dtype=torch.float64)

# Calculate intervals
intervals = x[1:] - x[:-1]
dy = (y[1:] - y[:-1]) / intervals

# Create zero boundary conditions (natural spline)
d2y = torch.zeros_like(x)
torch.zeros_like(x)

n = len(x)
a = torch.zeros_like(x) # Sub-diagonal (a[1..n-1])
Expand All @@ -195,18 +190,16 @@ def compute_second_derivatives(
c[i] = intervals[i] / 6
d[i] = dy[i] - dy[i - 1]

d2y = _solve_tridiagonal(a, b, c, d)
return _solve_tridiagonal(a, b, c, d)

# Converts back to the original dtype
return d2y


def compute_spline_ft(
k_points: torch.Tensor,
x_points: torch.Tensor,
y_points: torch.Tensor,
d2y_points: torch.Tensor,
high_precision: Optional[bool] = True,
):
r"""
Computes the Fourier transform of a splined radial function.
Expand All @@ -228,7 +221,6 @@ def compute_spline_ft(
:param x_points: Abscissas of the splining points for the real-space function
:param y_points: Ordinates of the splining points for the real-space function
:param d2y_points: Second derivatives for the spline points
:param high_accuracy: bool, perform calculation in double precision
:return: The radial Fourier transform :math:`\hat{f}(k)` computed
at the ``k_points`` provided.
Expand All @@ -244,8 +236,6 @@ def compute_spline_ft(

# chooses precision for the FT evaluation
dtype = x_points.dtype
if high_precision:
dtype = torch.float64

# broadcast to compute at once on all k values.
# all these are terms that enter the analytical integral.
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/potentials/combined.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional

import torch

Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/potentials/coulomb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional

import torch

Expand Down Expand Up @@ -59,7 +59,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
"Cannot compute long-range contribution without specifying `smearing`."
)

return torch.erf(dist / self.smearing / 2.0 ** 0.5) / dist
return torch.erf(dist / self.smearing / 2.0**0.5) / dist

def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
r"""
Expand Down
10 changes: 5 additions & 5 deletions src/torchpme/potentials/inversepowerlaw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional

import torch
from torch.special import gammainc
Expand Down Expand Up @@ -43,9 +43,7 @@ def __init__(

# function call to check the validity of the exponent
gammaincc_over_powerlaw(exponent, torch.tensor(1.0))
self.register_buffer(
"exponent", torch.tensor(float(exponent))
)
self.register_buffer("exponent", torch.tensor(exponent, dtype=torch.float64))

@torch.jit.export
def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -99,7 +97,9 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
)

peff = (3 - self.exponent) / 2
prefac = torch.pi**1.5 / gamma(self.exponent / 2) * (2 * self.smearing**2) ** peff
prefac = (
torch.pi**1.5 / gamma(self.exponent / 2) * (2 * self.smearing**2) ** peff
)
x = 0.5 * self.smearing**2 * k_sq

# The k=0 term often needs to be set separately since for exponents p<=3
Expand Down
7 changes: 5 additions & 2 deletions src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional, Union
from typing import Optional

import torch


class Potential(torch.nn.Module):
r"""
Base class defining the interface for a pair potential energy function
Expand Down Expand Up @@ -39,7 +40,9 @@ def __init__(
super().__init__()

if smearing is not None:
self.register_buffer("smearing", torch.tensor(smearing))
self.register_buffer(
"smearing", torch.tensor(smearing, dtype=torch.float64)
)
else:
self.smearing = None

Expand Down
8 changes: 4 additions & 4 deletions src/torchpme/potentials/spline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional

import torch

Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(

if len(y_grid) != len(r_grid):
raise ValueError("Length of radial grid and value array mismatch.")

self.register_buffer("r_grid", r_grid)
self.register_buffer("y_grid", y_grid)

Expand Down Expand Up @@ -106,14 +106,14 @@ def __init__(

if y_at_zero is None:
self._y_at_zero = self._spline(
torch.zeros(1, dtype=self.dtype, device=self.device)
torch.zeros(1, dtype=self.r_grid.dtype, device=self.r_grid.device)
)
else:
self._y_at_zero = y_at_zero

if yhat_at_zero is None:
self._yhat_at_zero = self._krn_spline(
torch.zeros(1, dtype=self.dtype, device=self.device)
torch.zeros(1, dtype=self.k_grid.dtype, device=self.k_grid.device)
)
else:
self._yhat_at_zero = yhat_at_zero
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/tuning/ewald.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, Optional, Union
from typing import Any
from warnings import warn

import torch
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/tuning/p3m.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import Any, Optional, Union
from typing import Any
from warnings import warn

import torch
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/tuning/pme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import Any, Optional, Union
from typing import Any
from warnings import warn

import torch
Expand Down
4 changes: 2 additions & 2 deletions src/torchpme/tuning/tuner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import time
from typing import Optional, Union
from typing import Optional

import torch

Expand Down Expand Up @@ -234,7 +234,7 @@ def _timing(self, smearing: float, k_space_params: dict):
),
**k_space_params,
)

calculator.to(device=self.positions.device, dtype=self.positions.dtype)
return self.time_func(calculator)


Expand Down
4 changes: 1 addition & 3 deletions tests/calculators/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def test_invalid_shape_cell():

def test_invalid_dtype_cell():
calculator = CalculatorTest()
match = (
r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)"
)
match = r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)"
with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand Down
3 changes: 2 additions & 1 deletion tests/calculators/test_values_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class CalculatorTest(Calculator):
def __init__(self, **kwargs):
super().__init__(
potential=CoulombPotential(
smearing=None, exclusion_radius=None,
smearing=None,
exclusion_radius=None,
),
**kwargs,
)
Expand Down
Loading

0 comments on commit c763a09

Please sign in to comment.