From df9a8a6282037a12043505aca3944f7c78d975fd Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 4 Feb 2025 14:56:50 +0100 Subject: [PATCH] Add a built-in filter for NL calculation results --- src/torchpme/tuning/tuner.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 546b3995..69595c61 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -144,6 +144,25 @@ def estimate_smearing( return float(smearing) + @staticmethod + def filter_neighbors( + cutoff: float, neighbor_indices: torch.Tensor, neighbor_distances: torch.Tensor + ): + """ + Filter neighbor indices and distances based on a user given cutoff. This allows + users pre-computing the neighbor list with a larger cutoff and then filtering + the neighbors based on a smaller cutoff, leading to a faster tuning on the + cutoff. + + :param cutoff: real space cutoff + :param neighbor_indices: torch.tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.tensor with the pair distances of the neighbors + for which the potential should be computed in real space.""" + + filter_idx = torch.where(neighbor_distances <= cutoff) + return neighbor_indices[filter_idx], neighbor_distances[filter_idx] + class GridSearchTuner(TunerBase): """ @@ -204,6 +223,9 @@ def __init__( ) self.error_bounds = error_bounds self.params = params + neighbor_indices, neighbor_distances = self.filter_neighbors( + cutoff, neighbor_indices, neighbor_distances + ) self.time_func = TuningTimings( charges, cell,