Skip to content

Commit

Permalink
Make timing test less strict
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 24, 2025
1 parent 6c982c8 commit 14f6771
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 45 deletions.
38 changes: 16 additions & 22 deletions tests/tuning/test_timer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
from pathlib import Path

import ase
import torch

from torchpme import (
Expand All @@ -10,7 +11,7 @@
from torchpme.tuning.tuner import TuningTimings

sys.path.append(str(Path(__file__).parents[1]))
from helpers import compute_distances, define_crystal, neighbor_list
from helpers import define_crystal, neighbor_list

DTYPE = torch.float32
DEFAULT_CUTOFF = 4.4
Expand All @@ -19,30 +20,23 @@
CELL_1 = torch.eye(3, dtype=DTYPE)


def _nl_calculation(pos, cell):
neighbor_indices, neighbor_shifts = neighbor_list(
positions=pos,
periodic=True,
box=cell,
cutoff=DEFAULT_CUTOFF,
neighbor_shifts=True,
)

neighbor_distances = compute_distances(
positions=pos,
neighbor_indices=neighbor_indices,
cell=cell,
neighbor_shifts=neighbor_shifts,
)
def test_timer():
n_repeat_1 = 10
n_repeat_2 = 100
pos, charges, cell, _, _ = define_crystal()

return neighbor_indices, neighbor_distances
# use ase to make system bigger
atoms = ase.Atoms("H" * len(pos), positions=pos.numpy(), cell=cell.numpy())
atoms.set_initial_charges(charges.numpy().flatten())
atoms.repeat((4, 4, 4))

pos = torch.tensor(atoms.positions, dtype=DTYPE)
charges = torch.tensor(atoms.get_initial_charges(), dtype=DTYPE).reshape(-1, 1)
cell = torch.tensor(atoms.cell.array, dtype=DTYPE)

def test_timer():
n_repeat_1 = 8
n_repeat_2 = 16
pos, charges, cell, madelung_ref, num_units = define_crystal()
neighbor_indices, neighbor_distances = _nl_calculation(pos, cell)
neighbor_indices, neighbor_distances = neighbor_list(
positions=pos, box=cell, cutoff=DEFAULT_CUTOFF
)

calculator = EwaldCalculator(
potential=CoulombPotential(smearing=1.0),
Expand Down
33 changes: 10 additions & 23 deletions tests/tuning/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchpme.tuning.tuner import TunerBase

sys.path.append(str(Path(__file__).parents[1]))
from helpers import compute_distances, define_crystal, neighbor_list
from helpers import define_crystal, neighbor_list

DTYPE = torch.float32
DEVICE = "cpu"
Expand All @@ -24,25 +24,6 @@
CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE)


def _nl_calculation(pos, cell):
neighbor_indices, neighbor_shifts = neighbor_list(
positions=pos,
periodic=True,
box=cell,
cutoff=DEFAULT_CUTOFF,
neighbor_shifts=True,
)

neighbor_distances = compute_distances(
positions=pos,
neighbor_indices=neighbor_indices,
cell=cell,
neighbor_shifts=neighbor_shifts,
)

return neighbor_indices, neighbor_distances


def test_TunerBase_double():
"""
Check that `TunerBase` initilizes with double precisions tensors.
Expand Down Expand Up @@ -78,7 +59,9 @@ def test_parameter_choose(calculator, tune, param_length, accuracy):
pos, charges, cell, madelung_ref, num_units = define_crystal()

# Compute neighbor list
neighbor_indices, neighbor_distances = _nl_calculation(pos, cell)
neighbor_indices, neighbor_distances = neighbor_list(
positions=pos, box=cell, cutoff=DEFAULT_CUTOFF
)

smearing, params, _ = tune(
charges,
Expand Down Expand Up @@ -115,7 +98,9 @@ def test_accuracy_error(tune):
pos, charges, cell, _, _ = define_crystal()

match = "'foo' is not a float."
neighbor_indices, neighbor_distances = _nl_calculation(pos, cell)
neighbor_indices, neighbor_distances = neighbor_list(
positions=pos, box=cell, cutoff=DEFAULT_CUTOFF
)
with pytest.raises(ValueError, match=match):
tune(
charges,
Expand All @@ -133,7 +118,9 @@ def test_exponent_not_1_error(tune):
pos, charges, cell, _, _ = define_crystal()

match = "Only exponent = 1 is supported but got 2."
neighbor_indices, neighbor_distances = _nl_calculation(pos, cell)
neighbor_indices, neighbor_distances = neighbor_list(
positions=pos, box=cell, cutoff=DEFAULT_CUTOFF
)
with pytest.raises(NotImplementedError, match=match):
tune(
charges,
Expand Down

0 comments on commit 14f6771

Please sign in to comment.