Skip to content

Commit

Permalink
Add a built-in filter for NL calculation results
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX committed Feb 4, 2025
1 parent ddff22b commit df9a8a6
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/torchpme/tuning/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ def estimate_smearing(

return float(smearing)

@staticmethod
def filter_neighbors(
cutoff: float, neighbor_indices: torch.Tensor, neighbor_distances: torch.Tensor
):
"""
Filter neighbor indices and distances based on a user given cutoff. This allows
users pre-computing the neighbor list with a larger cutoff and then filtering
the neighbors based on a smaller cutoff, leading to a faster tuning on the
cutoff.
:param cutoff: real space cutoff
:param neighbor_indices: torch.tensor with the ``i,j`` indices of neighbors for
which the potential should be computed in real space.
:param neighbor_distances: torch.tensor with the pair distances of the neighbors
for which the potential should be computed in real space."""

filter_idx = torch.where(neighbor_distances <= cutoff)
return neighbor_indices[filter_idx], neighbor_distances[filter_idx]


class GridSearchTuner(TunerBase):
"""
Expand Down Expand Up @@ -204,6 +223,9 @@ def __init__(
)
self.error_bounds = error_bounds
self.params = params
neighbor_indices, neighbor_distances = self.filter_neighbors(
cutoff, neighbor_indices, neighbor_distances
)
self.time_func = TuningTimings(
charges,
cell,
Expand Down

0 comments on commit df9a8a6

Please sign in to comment.