Skip to content

Commit

Permalink
Fix pytests and make linter happy
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX committed Dec 28, 2024
1 parent 5ed22a5 commit 5e2029d
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 124 deletions.
6 changes: 3 additions & 3 deletions examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@
# The ``sum_squared_charges`` is equal to ``2.0`` becaue each atom either has a charge
# of 1 or -1 in units of elementary charges.

smearing, pme_params, cutoff = torchpme.utils.tune_pme(
charges=charges, cell=cell, positions=positions
)
smearing, pme_params, cutoff = torchpme.utils.tuning.pme.PMETuner(
charges=charges, cell=cell, positions=positions, cutoff=4.4
).tune()

# %%
#
Expand Down
6 changes: 3 additions & 3 deletions examples/02-neighbor-lists-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@

sum_squared_charges = float(torch.sum(charges**2))

smearing, pme_params, cutoff = torchpme.utils.tune_pme(
sum_squared_charges=sum_squared_charges, cell=cell, positions=positions
)
smearing, pme_params, cutoff = torchpme.utils.tuning.pme.PMETuner(
charges=charges, cell=cell, positions=positions, cutoff=4.4
).tune()

# %%
#
Expand Down
9 changes: 0 additions & 9 deletions src/torchpme/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
from . import prefactors, tuning, splines # noqa
from .splines import CubicSpline, CubicSplineReciprocal
from .tuning.ewald import EwaldTuner, EwaldErrorBounds
from .tuning.pme import PMETuner, PMEErrorBounds
from .tuning.p3m import P3MTuner, P3MErrorBounds

__all__ = [
"EwaldTuner",
"EwaldErrorBounds",
"P3MTuner",
"P3MErrorBounds",
"PMETuner",
"PMEErrorBounds",
"CubicSpline",
"CubicSplineReciprocal",
]
2 changes: 1 addition & 1 deletion src/torchpme/utils/tuning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def _optimize_parameters(
accuracy: float,
learning_rate: float,
) -> None:

print("optimize ", params)
optimizer = torch.optim.Adam(params, lr=learning_rate)

Expand Down Expand Up @@ -143,6 +142,7 @@ def _validate_parameters(


class TuningErrorBounds(torch.nn.Module):
"""Base class for error bounds."""

def __init__(
self,
Expand Down
10 changes: 9 additions & 1 deletion src/torchpme/utils/tuning/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import numpy as np
import torch

from ...calculators import EwaldCalculator
from . import (
TuningErrorBounds,
)
from .grid_search import GridSearchBase
from ...calculators import EwaldCalculator

TWO_PI = 2 * math.pi

Expand Down Expand Up @@ -82,6 +82,12 @@ def forward(self, smearing, lr_wavelength, cutoff):


class EwaldTuner(GridSearchBase):
"""
Class for finding the optimal parameters for EwaldCalculator using a grid search.
For details of the parameters see :class:`torchpme.utils.tuning.GridSearchBase`.
"""

ErrorBounds = EwaldErrorBounds
CalculatorClass = EwaldCalculator
GridSearchParams = {"lr_wavelength": 1 / np.arange(1, 15)}
Expand All @@ -108,3 +114,5 @@ def __init__(
self.GridSearchParams["lr_wavelength"] *= float(
torch.min(self._cell_dimensions)
)

__doc__ = GridSearchBase.__doc__
96 changes: 43 additions & 53 deletions src/torchpme/utils/tuning/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,30 @@
from ...potentials import CoulombPotential
from . import (
TuningErrorBounds,
_estimate_smearing_cutoff,
_validate_parameters,
)

from . import _estimate_smearing_cutoff, _validate_parameters


class GridSearchBase:
ErrorBounds: TuningErrorBounds
CalculatorClass: Calculator
GridSearchParams: dict[str, torch.Tensor] # {"all_interpolation_nodes": ..., ...}
r"""
Base class for finding the optimal parameters for calculators using a grid search.
:param charges: torch.Tensor, atomic (pseudo-)charges
:param cell: torch.Tensor, periodic supercell for the system
:param positions: torch.Tensor, Cartesian coordinates of the particles within
the supercell.
:param cutoff: float, cutoff distance for the neighborlist
:param exponent :math:`p` in :math:`1/r^p` potentials
:param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for
which the potential should be computed in real space.
:param neighbor_distances: torch.Tensor with the pair distances of the neighbors
for which the potential should be computed in real space.
"""

ErrorBounds: type[TuningErrorBounds]
CalculatorClass: type[Calculator]
GridSearchParams: dict[str, torch.Tensor] # {"interpolation_nodes": ..., ...}

def __init__(
self,
Expand Down Expand Up @@ -70,6 +85,28 @@ def tune(
self,
accuracy: float = 1e-3,
):
r"""
The steps are: 1. Find the ``smearing`` parameter for the
:py:class:`CoulombPotential` that leads to a real space error of half the
desired accuracy. 2. Grid search for the kspace parameters, i.e. the
``lr_wavelength`` for Ewald and the ``mesh_spacing`` and ``interpolation_nodes``
for PME and P3M. For each combination of parameters, calculate the error. If the
error is smaller than the desired accuracy, use this combination for test runs
to get the calculation time needed. Return the combination that leads to the
shortest calculation time. If the desired accuracy is never reached, return the
combination that leads to the smallest error and throw a warning.
:param accuracy: Recomended values for a balance between the accuracy and speed
is :math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`.
:return: Tuple containing a float of the optimal smearing for the :py:class:
`CoulombPotential`, a dictionary with the parameters for the calculator of the
chosen method and a float of the optimal cutoff value for the neighborlist
computation.
"""
if not isinstance(accuracy, float):
raise ValueError(f"'{accuracy}' is not a float.")

smearing_opt = None
params_opt = None
cutoff_opt = None
Expand Down Expand Up @@ -151,52 +188,5 @@ def _timing(self, smearing: float, params: dict):
)
if self.device is torch.device("cuda"):
torch.cuda.synchronize()
execution_time = time.time() - t0

return execution_time


def grid_search(
method: str,
charges: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
cutoff: float,
exponent: int = 1,
accuracy: float = 1e-3,
neighbor_indices: Optional[torch.Tensor] = None,
neighbor_distances: Optional[torch.Tensor] = None,
):
r"""
Find the optimal parameters for calculators.
The steps are:
1. Find the ``smearing`` parameter for the :py:class:`CoulombPotential` that leads
to a real space error of half the desired accuracy.
2. Grid search for the kspace parameters, i.e. the ``lr_wavelength`` for Ewald and
the ``mesh_spacing`` and ``interpolation_nodes`` for PME and P3M.
For each combination of parameters, calculate the error. If the error is smaller
than the desired accuracy, use this combination for test runs to get the calculation
time needed. Return the combination that leads to the shortest calculation time. If
the desired accuracy is never reached, return the combination that leads to the
smallest error and throw a warning.
:param charges: torch.Tensor, atomic (pseudo-)charges
:param cell: torch.Tensor, periodic supercell for the system
:param positions: torch.Tensor, Cartesian coordinates of the particles within
the supercell.
:param cutoff: float, cutoff distance for the neighborlist
:param exponent :math:`p` in :math:`1/r^p` potentials
:param accuracy: Recomended values for a balance between the accuracy and speed is
:math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`.
:param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for
which the potential should be computed in real space.
:param neighbor_distances: torch.Tensor with the pair distances of the neighbors
for which the potential should be computed in real space.

:return: Tuple containing a float of the optimal smearing for the :py:class:
`CoulombPotential`, a dictionary with the parameters for the calculator of the
chosen method and a float of the optimal cutoff value for the neighborlist
computation.
"""
pass
return time.time() - t0
9 changes: 9 additions & 0 deletions src/torchpme/utils/tuning/p3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from torchpme import P3MCalculator

from . import (
TuningErrorBounds,
)
Expand Down Expand Up @@ -151,6 +152,12 @@ def forward(self, smearing, mesh_spacing, cutoff, interpolation_nodes):


class P3MTuner(GridSearchBase):
"""
Class for finding the optimal parameters for P3MCalculator using a grid search.
For details of the parameters see :class:`torchpme.utils.tuning.GridSearchBase`.
"""

ErrorBounds = P3MErrorBounds
CalculatorClass = P3MCalculator
GridSearchParams = {
Expand Down Expand Up @@ -178,3 +185,5 @@ def __init__(
neighbor_distances,
)
self.GridSearchParams["mesh_spacing"] *= float(torch.min(self._cell_dimensions))

__doc__ = GridSearchBase.__doc__
14 changes: 11 additions & 3 deletions src/torchpme/utils/tuning/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from torchpme import PMECalculator

from . import (
TuningErrorBounds,
)
Expand Down Expand Up @@ -84,7 +85,6 @@ def error(self, cutoff, smearing, mesh_spacing, interpolation_nodes):
piecewise polynomials of degree ``n - 1`` (e.g. ``n = 4`` for cubic
interpolation). Only the values ``3, 4, 5, 6, 7`` are supported.
"""

smearing = torch.as_tensor(smearing)
mesh_spacing = torch.as_tensor(mesh_spacing)
cutoff = torch.as_tensor(cutoff)
Expand All @@ -96,11 +96,17 @@ def error(self, cutoff, smearing, mesh_spacing, interpolation_nodes):


class PMETuner(GridSearchBase):
"""
Class for finding the optimal parameters for PMECalculator using a grid search.
For details of the parameters see :class:`torchpme.utils.tuning.GridSearchBase`.
"""

ErrorBounds = PMEErrorBounds
CalculatorClass = PMECalculator
GridSearchParams = {
"interpolation_nodes": [3, 4, 5, 6, 7],
"ns_mesh": 1 / ((np.exp2(np.arange(2, 8)) - 1) / 2),
"mesh_spacing": 1 / ((np.exp2(np.arange(2, 8)) - 1) / 2),
}

def __init__(
Expand All @@ -122,4 +128,6 @@ def __init__(
neighbor_indices,
neighbor_distances,
)
self.GridSearchParams["ns_mesh"] *= float(torch.min(self._cell_dimensions))
self.GridSearchParams["mesh_spacing"] *= float(torch.min(self._cell_dimensions))

__doc__ = GridSearchBase.__doc__
Loading

0 comments on commit 5e2029d

Please sign in to comment.