Skip to content

Commit

Permalink
Continue changing
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Feb 4, 2025
1 parent 0cd2e80 commit 973c9e8
Show file tree
Hide file tree
Showing 14 changed files with 66 additions and 153 deletions.
11 changes: 7 additions & 4 deletions src/torchpme/calculators/p3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ def __init__(
prefactor=prefactor,
)

cell = torch.eye(3, device=self.potential.smearing.device, dtype=self.potential.smearing.dtype)
ns_mesh = torch.ones(3, dtype=int, device=cell.device)

self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(
cell=torch.eye(3, dtype=self.dtype, device=self.device),
ns_mesh=torch.ones(3, dtype=int, device=self.device),
cell=cell,
ns_mesh=ns_mesh,
interpolation_nodes=self.interpolation_nodes,
kernel=self.potential,
mode=0, # Green's function for point-charge potentials
Expand All @@ -78,8 +81,8 @@ def __init__(
)

self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
cell=torch.eye(3, dtype=self.dtype, device=self.device),
ns_mesh=torch.ones(3, dtype=int, device=self.device),
cell=cell,
ns_mesh=ns_mesh,
interpolation_nodes=self.interpolation_nodes,
method="P3M",
)
8 changes: 4 additions & 4 deletions src/torchpme/calculators/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def __init__(

self.mesh_spacing: float = mesh_spacing

self.register_buffer("cell", torch.eye(3))
ns_mesh = torch.ones(3, dtype=int, device=self.cell.device)
cell = torch.eye(3, device=self.potential.smearing.device, dtype=self.potential.smearing.dtype)
ns_mesh = torch.ones(3, dtype=int, device=cell.device)

self.kspace_filter: KSpaceFilter = KSpaceFilter(
cell=self.cell,
cell=cell,
ns_mesh=ns_mesh,
kernel=self.potential,
fft_norm="backward",
Expand All @@ -84,7 +84,7 @@ def __init__(
self.interpolation_nodes: int = interpolation_nodes

self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
cell=self.cell,
cell=cell,
ns_mesh=ns_mesh,
interpolation_nodes=self.interpolation_nodes,
method="Lagrange", # convention for classic PME
Expand Down
2 changes: 2 additions & 0 deletions src/torchpme/potentials/coulomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
# https://github.com/jax-ml/jax/issues/1052
# https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
masked = torch.where(k_sq == 0, 1.0, k_sq)
print(self.smearing.device)
print(k_sq.device)
return torch.where(
k_sq == 0,
0.0,
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/potentials/inversepowerlaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
# function call to check the validity of the exponent
gammaincc_over_powerlaw(exponent, torch.tensor(1.0))
self.register_buffer(
"exponent", torch.tensor(exponent)
"exponent", torch.tensor(float(exponent))
)

@torch.jit.export
Expand Down
6 changes: 5 additions & 1 deletion src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def __init__(
):
super().__init__()

self.smearing = smearing
if smearing is not None:
self.register_buffer("smearing", torch.tensor(smearing))
else:
self.smearing = None

self.exclusion_radius = exclusion_radius

@torch.jit.export
Expand Down
4 changes: 0 additions & 4 deletions src/torchpme/tuning/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ def tune_ewald(
ns_lo: int = 1,
ns_hi: int = 14,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.EwaldCalculator`.
Expand Down Expand Up @@ -96,8 +94,6 @@ def tune_ewald(
calculator=EwaldCalculator,
error_bounds=EwaldErrorBounds(charges=charges, cell=cell, positions=positions),
params=params,
dtype=dtype,
device=device,
)
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)
Expand Down
4 changes: 0 additions & 4 deletions src/torchpme/tuning/p3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def tune_p3m(
mesh_lo: int = 2,
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`.
Expand Down Expand Up @@ -169,8 +167,6 @@ def tune_p3m(
calculator=P3MCalculator,
error_bounds=P3MErrorBounds(charges=charges, cell=cell, positions=positions),
params=params,
dtype=dtype,
device=device,
)
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)
Expand Down
4 changes: 0 additions & 4 deletions src/torchpme/tuning/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def tune_pme(
mesh_lo: int = 2,
mesh_hi: int = 7,
accuracy: float = 1e-3,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.PMECalculator`.
Expand Down Expand Up @@ -112,8 +110,6 @@ def tune_pme(
calculator=PMECalculator,
error_bounds=PMEErrorBounds(charges=charges, cell=cell, positions=positions),
params=params,
dtype=dtype,
device=device,
)
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)
Expand Down
28 changes: 1 addition & 27 deletions src/torchpme/tuning/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from .._utils import _get_device, _get_dtype, _validate_parameters
from .._utils import _validate_parameters
from ..calculators import Calculator
from ..potentials import InversePowerLawPotential

Expand Down Expand Up @@ -83,17 +83,12 @@ def __init__(
cutoff: float,
calculator: type[Calculator],
exponent: int = 1,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
if exponent != 1:
raise NotImplementedError(
f"Only exponent = 1 is supported but got {exponent}."
)

self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

_validate_parameters(
charges=charges,
cell=cell,
Expand All @@ -103,8 +98,6 @@ def __init__(
[1.0], device=positions.device, dtype=positions.dtype
),
smearing=1.0, # dummy value because; always have range-seperated potentials
dtype=self.dtype,
device=self.device,
)
self.charges = charges
self.cell = cell
Expand Down Expand Up @@ -189,8 +182,6 @@ def __init__(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
exponent: int = 1,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(
charges=charges,
Expand All @@ -199,8 +190,6 @@ def __init__(
cutoff=cutoff,
calculator=calculator,
exponent=exponent,
dtype=dtype,
device=device,
)
self.error_bounds = error_bounds
self.params = params
Expand All @@ -211,8 +200,6 @@ def __init__(
neighbor_indices,
neighbor_distances,
True,
dtype=dtype,
device=device,
)

def tune(self, accuracy: float = 1e-3) -> tuple[list[float], list[float]]:
Expand Down Expand Up @@ -244,11 +231,7 @@ def _timing(self, smearing: float, k_space_params: dict):
potential=InversePowerLawPotential(
exponent=self.exponent, # but only exponent = 1 is supported
smearing=smearing,
device=self.device,
dtype=self.dtype,
),
device=self.device,
dtype=self.dtype,
**k_space_params,
)

Expand Down Expand Up @@ -289,23 +272,16 @@ def __init__(
n_repeat: int = 4,
n_warmup: int = 4,
run_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__()

self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

_validate_parameters(
charges=charges,
cell=cell,
positions=positions,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=1.0, # dummy value because; always have range-seperated potentials
device=self.device,
dtype=self.dtype,
)

self.charges = charges
Expand Down Expand Up @@ -351,8 +327,6 @@ def forward(self, calculator: torch.nn.Module):
if self.run_backward:
value.backward(retain_graph=True)

if self.device is torch.device("cuda"):
torch.cuda.synchronize()
execution_time += time.monotonic()

return execution_time / self.n_repeat
43 changes: 7 additions & 36 deletions tests/calculators/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,35 +43,6 @@ def test_compute_output_shapes():
assert result.shape == charges.shape


def test_wrong_device_positions():
calculator = CalculatorTest()
match = r"device of `positions` \(meta\) must be same as the class device \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1.to(device="meta"),
charges=CHARGES_1,
cell=CELL_1,
neighbor_indices=NEIGHBOR_INDICES,
neighbor_distances=NEIGHBOR_DISTANCES,
)


def test_wrong_dtype_positions():
calculator = CalculatorTest()
match = (
r"type of `positions` \(torch.float64\) must be same as the class type "
r"\(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1.to(dtype=torch.float64),
charges=CHARGES_1,
cell=CELL_1,
neighbor_indices=NEIGHBOR_INDICES,
neighbor_distances=NEIGHBOR_DISTANCES,
)


# Tests for invalid shape, dtype and device of positions
def test_invalid_shape_positions():
calculator = CalculatorTest()
Expand Down Expand Up @@ -108,7 +79,7 @@ def test_invalid_shape_cell():
def test_invalid_dtype_cell():
calculator = CalculatorTest()
match = (
r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)"
r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
calculator.forward(
Expand All @@ -122,7 +93,7 @@ def test_invalid_dtype_cell():

def test_invalid_device_cell():
calculator = CalculatorTest()
match = r"device of `cell` \(meta\) must be same as the class \(cpu\)"
match = r"device of `cell` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand Down Expand Up @@ -188,7 +159,7 @@ def test_invalid_shape_charges():
def test_invalid_dtype_charges():
calculator = CalculatorTest()
match = (
r"type of `charges` \(torch.float64\) must be same as the class "
r"type of `charges` \(torch.float64\) must be same as that of the `positions` class "
r"\(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
Expand All @@ -203,7 +174,7 @@ def test_invalid_dtype_charges():

def test_invalid_device_charges():
calculator = CalculatorTest()
match = r"device of `charges` \(meta\) must be same as the class \(cpu\)"
match = r"device of `charges` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand Down Expand Up @@ -248,7 +219,7 @@ def test_invalid_shape_neighbor_indices_neighbor_distances():

def test_invalid_device_neighbor_indices():
calculator = CalculatorTest()
match = r"device of `neighbor_indices` \(meta\) must be same as the class \(cpu\)"
match = r"device of `neighbor_indices` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand All @@ -261,7 +232,7 @@ def test_invalid_device_neighbor_indices():

def test_invalid_device_neighbor_distances():
calculator = CalculatorTest()
match = r"device of `neighbor_distances` \(meta\) must be same as the class \(cpu\)"
match = r"device of `neighbor_distances` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand All @@ -276,7 +247,7 @@ def test_invalid_dtype_neighbor_distances():
calculator = CalculatorTest()
match = (
r"type of `neighbor_distances` \(torch.float64\) must be same "
r"as the class \(torch.float32\)"
r"as that of the `positions` class \(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
calculator.forward(
Expand Down
3 changes: 1 addition & 2 deletions tests/calculators/test_values_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ class CalculatorTest(Calculator):
def __init__(self, **kwargs):
super().__init__(
potential=CoulombPotential(
smearing=None, exclusion_radius=None, dtype=DTYPE
smearing=None, exclusion_radius=None,
),
**kwargs,
dtype=DTYPE,
)


Expand Down
Loading

0 comments on commit 973c9e8

Please sign in to comment.