diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index 55c38e713..91568ecf9 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union import torch @@ -11,7 +11,6 @@ def _validate_parameters( neighbor_distances: torch.Tensor, smearing: Union[float, None], ) -> None: - dtype = positions.dtype device = positions.device diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 2af449f6d..0bdfc4211 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch from torch import profiler diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py index 0f6f2a340..4c5d19068 100644 --- a/src/torchpme/calculators/ewald.py +++ b/src/torchpme/calculators/ewald.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch from ..lib import generate_kvectors_for_ewald diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py index 24d38e252..a2f9015c2 100644 --- a/src/torchpme/calculators/p3m.py +++ b/src/torchpme/calculators/p3m.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch from ..lib.kspace_filter import P3MKSpaceFilter @@ -66,7 +64,11 @@ def __init__( prefactor=prefactor, ) - cell = torch.eye(3, device=self.potential.smearing.device, dtype=self.potential.smearing.dtype) + cell = torch.eye( + 3, + device=self.potential.smearing.device, + dtype=self.potential.smearing.dtype, + ) ns_mesh = torch.ones(3, dtype=int, device=cell.device) self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter( diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py index c952b7162..77bfcd8a4 100644 --- a/src/torchpme/calculators/pme.py +++ b/src/torchpme/calculators/pme.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch from torch import profiler @@ -70,7 +68,11 @@ def __init__( self.mesh_spacing: float = mesh_spacing - cell = torch.eye(3, device=self.potential.smearing.device, dtype=self.potential.smearing.dtype) + cell = torch.eye( + 3, + device=self.potential.smearing.device, + dtype=self.potential.smearing.dtype, + ) ns_mesh = torch.ones(3, dtype=int, device=cell.device) self.kspace_filter: KSpaceFilter = KSpaceFilter( diff --git a/src/torchpme/lib/splines.py b/src/torchpme/lib/splines.py index d99bdf1e5..5cee00583 100644 --- a/src/torchpme/lib/splines.py +++ b/src/torchpme/lib/splines.py @@ -151,30 +151,25 @@ def _solve_tridiagonal(a, b, c, d): def compute_second_derivatives( x_points: torch.Tensor, y_points: torch.Tensor, - high_precision: Optional[bool] = True, ): """ Computes second derivatives given the grid points of a cubic spline. :param x_points: Abscissas of the splining points for the real-space function :param y_points: Ordinates of the splining points for the real-space function - :param high_accuracy: bool, perform calculation in double precision :return: The second derivatives for the spline points """ # Do the calculation in float64 if required x = x_points y = y_points - if high_precision: - x = x.to(dtype=torch.float64) - y = y.to(dtype=torch.float64) # Calculate intervals intervals = x[1:] - x[:-1] dy = (y[1:] - y[:-1]) / intervals # Create zero boundary conditions (natural spline) - d2y = torch.zeros_like(x) + torch.zeros_like(x) n = len(x) a = torch.zeros_like(x) # Sub-diagonal (a[1..n-1]) @@ -195,10 +190,9 @@ def compute_second_derivatives( c[i] = intervals[i] / 6 d[i] = dy[i] - dy[i - 1] - d2y = _solve_tridiagonal(a, b, c, d) + return _solve_tridiagonal(a, b, c, d) # Converts back to the original dtype - return d2y def compute_spline_ft( @@ -206,7 +200,6 @@ def compute_spline_ft( x_points: torch.Tensor, y_points: torch.Tensor, d2y_points: torch.Tensor, - high_precision: Optional[bool] = True, ): r""" Computes the Fourier transform of a splined radial function. @@ -228,7 +221,6 @@ def compute_spline_ft( :param x_points: Abscissas of the splining points for the real-space function :param y_points: Ordinates of the splining points for the real-space function :param d2y_points: Second derivatives for the spline points - :param high_accuracy: bool, perform calculation in double precision :return: The radial Fourier transform :math:`\hat{f}(k)` computed at the ``k_points`` provided. @@ -244,8 +236,6 @@ def compute_spline_ft( # chooses precision for the FT evaluation dtype = x_points.dtype - if high_precision: - dtype = torch.float64 # broadcast to compute at once on all k values. # all these are terms that enter the analytical integral. diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py index 010314657..dc67cdc76 100644 --- a/src/torchpme/potentials/combined.py +++ b/src/torchpme/potentials/combined.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py index 0f33bb9e4..862dd0e5f 100644 --- a/src/torchpme/potentials/coulomb.py +++ b/src/torchpme/potentials/coulomb.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch @@ -59,7 +59,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: "Cannot compute long-range contribution without specifying `smearing`." ) - return torch.erf(dist / self.smearing / 2.0 ** 0.5) / dist + return torch.erf(dist / self.smearing / 2.0**0.5) / dist def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: r""" diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index 3ac0318f5..1e7f945f0 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch from torch.special import gammainc @@ -43,9 +43,7 @@ def __init__( # function call to check the validity of the exponent gammaincc_over_powerlaw(exponent, torch.tensor(1.0)) - self.register_buffer( - "exponent", torch.tensor(float(exponent)) - ) + self.register_buffer("exponent", torch.tensor(exponent, dtype=torch.float64)) @torch.jit.export def from_dist(self, dist: torch.Tensor) -> torch.Tensor: @@ -99,7 +97,9 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: ) peff = (3 - self.exponent) / 2 - prefac = torch.pi**1.5 / gamma(self.exponent / 2) * (2 * self.smearing**2) ** peff + prefac = ( + torch.pi**1.5 / gamma(self.exponent / 2) * (2 * self.smearing**2) ** peff + ) x = 0.5 * self.smearing**2 * k_sq # The k=0 term often needs to be set separately since for exponents p<=3 diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index 79e4e8fc3..538acf04b 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -1,7 +1,8 @@ -from typing import Optional, Union +from typing import Optional import torch + class Potential(torch.nn.Module): r""" Base class defining the interface for a pair potential energy function @@ -39,7 +40,9 @@ def __init__( super().__init__() if smearing is not None: - self.register_buffer("smearing", torch.tensor(smearing)) + self.register_buffer( + "smearing", torch.tensor(smearing, dtype=torch.float64) + ) else: self.smearing = None diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py index 09fca0594..2419f641a 100644 --- a/src/torchpme/potentials/spline.py +++ b/src/torchpme/potentials/spline.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch @@ -63,7 +63,7 @@ def __init__( if len(y_grid) != len(r_grid): raise ValueError("Length of radial grid and value array mismatch.") - + self.register_buffer("r_grid", r_grid) self.register_buffer("y_grid", y_grid) @@ -106,14 +106,14 @@ def __init__( if y_at_zero is None: self._y_at_zero = self._spline( - torch.zeros(1, dtype=self.dtype, device=self.device) + torch.zeros(1, dtype=self.r_grid.dtype, device=self.r_grid.device) ) else: self._y_at_zero = y_at_zero if yhat_at_zero is None: self._yhat_at_zero = self._krn_spline( - torch.zeros(1, dtype=self.dtype, device=self.device) + torch.zeros(1, dtype=self.k_grid.dtype, device=self.k_grid.device) ) else: self._yhat_at_zero = yhat_at_zero diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py index d2776ec0d..b1116245d 100644 --- a/src/torchpme/tuning/ewald.py +++ b/src/torchpme/tuning/ewald.py @@ -1,5 +1,5 @@ import math -from typing import Any, Optional, Union +from typing import Any from warnings import warn import torch diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py index 20c403d9a..7ad6e4db1 100644 --- a/src/torchpme/tuning/p3m.py +++ b/src/torchpme/tuning/p3m.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional, Union +from typing import Any from warnings import warn import torch diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py index f654fbd7e..68d80f0ba 100644 --- a/src/torchpme/tuning/pme.py +++ b/src/torchpme/tuning/pme.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional, Union +from typing import Any from warnings import warn import torch diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 192058600..3887c1882 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -1,6 +1,6 @@ import math import time -from typing import Optional, Union +from typing import Optional import torch @@ -234,7 +234,7 @@ def _timing(self, smearing: float, k_space_params: dict): ), **k_space_params, ) - + calculator.to(device=self.positions.device, dtype=self.positions.dtype) return self.time_func(calculator) diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index 8c953e429..b5e8bfe85 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -78,9 +78,7 @@ def test_invalid_shape_cell(): def test_invalid_dtype_cell(): calculator = CalculatorTest() - match = ( - r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)" - ) + match = r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)" with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, diff --git a/tests/calculators/test_values_direct.py b/tests/calculators/test_values_direct.py index f2925f2ab..867137c00 100644 --- a/tests/calculators/test_values_direct.py +++ b/tests/calculators/test_values_direct.py @@ -18,7 +18,8 @@ class CalculatorTest(Calculator): def __init__(self, **kwargs): super().__init__( potential=CoulombPotential( - smearing=None, exclusion_radius=None, + smearing=None, + exclusion_radius=None, ), **kwargs, ) diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index c8ae2d507..a4fc4174e 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -16,7 +16,6 @@ PMECalculator, ) - DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] DTYPES = [torch.float32, torch.float64] SMEARING = 0.1 @@ -32,35 +31,27 @@ ( Calculator, { - "potential": CoulombPotential( - smearing=None - ), + "potential": CoulombPotential(smearing=None), }, ), ( EwaldCalculator, { - "potential": CoulombPotential( - smearing=SMEARING - ), + "potential": CoulombPotential(smearing=SMEARING), "lr_wavelength": LR_WAVELENGTH, }, ), ( PMECalculator, { - "potential": CoulombPotential( - smearing=SMEARING - ), + "potential": CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), ( P3MCalculator, { - "potential": CoulombPotential( - smearing=SMEARING - ), + "potential": CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), @@ -69,7 +60,6 @@ class TestWorkflow: def cscl_system(self, device=None, dtype=None): """CsCl crystal. Same as in the madelung test""" - positions = torch.tensor( [[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device ) @@ -109,7 +99,6 @@ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype def test_dtype_device(self, CalculatorClass, params, device, dtype): """Test that the output dtype and device are the same as the input.""" - print(params["potential"].smearing) calculator = CalculatorClass(**params) calculator.to(device=device, dtype=dtype) potential = calculator(*self.cscl_system(device=device, dtype=dtype)) @@ -163,10 +152,8 @@ def test_prefactor(self, CalculatorClass, params, device, dtype): def test_not_nan(self, CalculatorClass, params, device, dtype): """Make sure derivatives are not NaN.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) system = self.cscl_system(device=device, dtype=dtype) system[0].requires_grad = True system[1].requires_grad = True @@ -193,29 +180,6 @@ def test_not_nan(self, CalculatorClass, params, device, dtype): torch.autograd.grad(energy, system[2], retain_graph=True)[0] ).any() - def test_dtype_and_device_incompatability( - self, CalculatorClass, params, device, dtype - ): - """Test that the calculator raises an error if the dtype and device are incompatible with potential.""" - params = params.copy() - - other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 - params["potential"] = params["potential"](dtype, device) - - match = ( - rf"dtype of `potential` \({params['potential'].dtype}\) must be same as " - rf"of `calculator` \({other_dtype}\)" - ) - with pytest.raises(TypeError, match=match): - CalculatorClass(**params, dtype=other_dtype, device=device) - - match = ( - rf"device of `potential` \({params['potential'].device}\) must be same as " - rf"of `calculator` \(meta\)" - ) - with pytest.raises(ValueError, match=match): - CalculatorClass(**params, dtype=dtype, device=torch.device("meta")) - def test_potential_and_calculator_incompatability( self, CalculatorClass, @@ -224,34 +188,8 @@ def test_potential_and_calculator_incompatability( dtype, ): """Test that the calculator raises an error if the potential and calculator are incompatible.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - params["potential"] = torch.jit.script(params["potential"]) with pytest.raises( TypeError, match="Potential must be an instance of Potential, got.*" ): - CalculatorClass(**params, device=device, dtype=dtype) - - def test_device_string_compatability(self, CalculatorClass, params, dtype, device): - """Test that the calculator works with device strings.""" - params = params.copy() - params["potential"] = params["potential"](dtype, "cpu") - calculator = CalculatorClass( - **params, - device=torch.device("cpu"), - dtype=dtype, - ) - - assert calculator.device == params["potential"].device - - def test_device_index_compatability(self, CalculatorClass, params, dtype, device): - """Test that the calculator works with no index on the device.""" - if torch.cuda.is_available(): - params = params.copy() - params["potential"] = params["potential"](dtype, "cuda") - calculator = CalculatorClass( - **params, device=torch.device("cuda:0"), dtype=dtype - ) - - assert calculator.device == params["potential"].device + CalculatorClass(**params) diff --git a/tests/helpers.py b/tests/helpers.py index 9374e242a..f553bff64 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -15,7 +15,6 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None): - # Define all relevant parameters (atom positions, charges, cell) of the reference # crystal structures for which the Madelung constants obtained from the Ewald sums # are compared with reference values. diff --git a/tests/lib/test_splines.py b/tests/lib/test_splines.py index e85761887..f5595e3ed 100644 --- a/tests/lib/test_splines.py +++ b/tests/lib/test_splines.py @@ -59,8 +59,12 @@ def test_inverse_spline(function): @pytest.mark.parametrize("high_accuracy", [True, False]) def test_ft_accuracy(high_accuracy): - x_grid = torch.linspace(0, 20, 2000, dtype=torch.float32) - y_grid = torch.exp(-(x_grid**2) * 0.5) + if high_accuracy: + x_grid = torch.linspace(0, 20, 2000, dtype=torch.float64) + y_grid = torch.exp(-(x_grid**2) * 0.5) + else: + x_grid = torch.linspace(0, 20, 2000, dtype=torch.float32) + y_grid = torch.exp(-(x_grid**2) * 0.5) k_grid = torch.linspace(0, 20, 20, dtype=torch.float32) krn = compute_spline_ft( @@ -68,9 +72,9 @@ def test_ft_accuracy(high_accuracy): x_points=x_grid, y_points=y_grid, d2y_points=compute_second_derivatives( - x_points=x_grid, y_points=y_grid, high_precision=high_accuracy + x_points=x_grid, + y_points=y_grid, ), - high_precision=high_accuracy, ) krn_ref = torch.exp(-(k_grid**2) * 0.5) * (2 * torch.pi) ** (3 / 2) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 141977d2a..c8f9681f3 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -7,7 +7,6 @@ from packaging import version import torchpme -from torchpme._utils import _get_device, _get_dtype mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") @@ -27,35 +26,27 @@ ( torchpme.metatensor.Calculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=None, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=None), }, ), ( torchpme.metatensor.EwaldCalculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=SMEARING), "lr_wavelength": LR_WAVELENGTH, }, ), ( torchpme.metatensor.PMECalculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), ( torchpme.metatensor.P3MCalculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), @@ -63,9 +54,6 @@ ) class TestWorkflow: def system(self, device=None, dtype=None): - device = _get_device(device) - dtype = _get_dtype(dtype) - system = mts_atomistic.System( types=torch.tensor([1, 2, 2]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.2], [0.0, 0.0, 0.5]]), @@ -118,24 +106,21 @@ def check_operation(self, calculator, device, dtype): def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) self.check_operation(calculator=calculator, device=device, dtype=dtype) def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) scripted = torch.jit.script(calculator) self.check_operation(calculator=scripted, device=device, dtype=dtype) def test_save_load(self, CalculatorClass, params, device, dtype): - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) + """Save and load a compiled torch script module.""" + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) diff --git a/tests/test_potentials.py b/tests/test_potentials.py index 087937490..54398aa50 100644 --- a/tests/test_potentials.py +++ b/tests/test_potentials.py @@ -67,8 +67,8 @@ def test_sr_lr_split(exponent, smearing): potential. """ # Compute diverse potentials for this inverse power law - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) - + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_from_dist = ipl.from_dist(dists) potential_sr_from_dist = ipl.sr_from_dist(dists) potential_lr_from_dist = ipl.lr_from_dist(dists) @@ -96,10 +96,9 @@ def test_exact_sr(exponent, smearing): """ # Compute SR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) - + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_sr_from_dist = ipl.sr_from_dist(dists) - # Compute exact analytical expression obtained for relevant exponents potential_1 = erfc(dists / SQRT2 / smearing) / dists potential_2 = torch.exp(-0.5 * dists_sq / smearing**2) / dists_sq @@ -110,7 +109,6 @@ def test_exact_sr(exponent, smearing): elif exponent == 3: prefac = SQRT2 / torch.sqrt(PI) / smearing potential_exact = potential_1 / dists_sq + prefac * potential_2 - # Compare results. Large tolerance due to singular division rtol = 1e2 * machine_epsilon atol = 4e-15 @@ -130,7 +128,8 @@ def test_exact_lr(exponent, smearing): """ # Compute LR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_lr_from_dist = ipl.lr_from_dist(dists) @@ -164,7 +163,8 @@ def test_exact_fourier(exponent, smearing): """ # Compute LR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) fourier_from_class = ipl.lr_from_k_sq(ks_sq) @@ -202,7 +202,8 @@ def test_lr_value_at_zero(exponent, smearing): """ # Get atomic density at tiny distance dist_small = torch.tensor(1e-8, dtype=dtype) - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_close_to_zero = ipl.lr_from_dist(dist_small) @@ -279,7 +280,8 @@ class NoImplPotential(Potential): @pytest.mark.parametrize("exclusion_radius", [0.5, 1.0, 2.0]) def test_f_cutoff(exclusion_radius): - coul = CoulombPotential(exclusion_radius=exclusion_radius, dtype=dtype) + coul = CoulombPotential(exclusion_radius=exclusion_radius) + coul.to(dtype=dtype) dist = torch.tensor([0.3]) fcut = coul.f_cutoff(dist) @@ -294,8 +296,10 @@ def test_inverserp_coulomb(smearing): """ # Compute LR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) - coul = CoulombPotential(smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=1, smearing=smearing) + ipl.to(dtype=dtype) + coul = CoulombPotential(smearing=smearing) + coul.to(dtype=dtype) ipl_from_dist = ipl.from_dist(dists) ipl_sr_from_dist = ipl.sr_from_dist(dists) @@ -371,16 +375,17 @@ def test_spline_potential_vs_coulomb(): # the approximation is not super-accurate coulomb = CoulombPotential(smearing=1.0) - x_grid = torch.logspace(-3.0, 3.0, 1000) + coulomb.to(dtype=dtype) + x_grid = torch.logspace(-3.0, 3.0, 1000, dtype=dtype) y_grid = coulomb.lr_from_dist(x_grid) spline = SplinePotential(r_grid=x_grid, y_grid=y_grid, reciprocal=True) - t_grid = torch.logspace(-torch.pi / 2, torch.pi / 2, 100) + t_grid = torch.logspace(-torch.pi / 2, torch.pi / 2, 100, dtype=dtype) z_coul = coulomb.lr_from_dist(t_grid) z_spline = spline.lr_from_dist(t_grid) assert_close(z_coul, z_spline, atol=5e-5, rtol=0) - k_grid2 = torch.logspace(-2, 1, 40) + k_grid2 = torch.logspace(-2, 1, 40, dtype=dtype) krn_coul = coulomb.kernel_from_k_sq(k_grid2) krn_spline = spline.kernel_from_k_sq(k_grid2) @@ -439,8 +444,8 @@ def forward(self, x: torch.Tensor): @pytest.mark.parametrize("smearing", smearinges) def test_combined_potential(smearing): - ipl_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) - ipl_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype) + ipl_1 = InversePowerLawPotential(exponent=1, smearing=smearing) + ipl_2 = InversePowerLawPotential(exponent=2, smearing=smearing) ipl_1_from_dist = ipl_1.from_dist(dists) ipl_1_sr_from_dist = ipl_1.sr_from_dist(dists) @@ -461,7 +466,6 @@ def test_combined_potential(smearing): potentials=[ipl_1, ipl_2], initial_weights=weights, learnable_weights=False, - dtype=dtype, smearing=1.0, ) combined_from_dist = combined.from_dist(dists) @@ -516,53 +520,53 @@ def test_combined_potential(smearing): def test_combined_potentials_jit(smearing): # make a separate test as pytest.mark.parametrize does not work with # torch.jit.script for combined potentials - coulomb = CoulombPotential(smearing=smearing, dtype=dtype) + coulomb = CoulombPotential(smearing=smearing) + coulomb.to(dtype=dtype) x_grid = torch.logspace(-2, 2, 100, dtype=dtype) y_grid = coulomb.lr_from_dist(x_grid) # create a spline potential spline = SplinePotential( - r_grid=x_grid, y_grid=y_grid, reciprocal=True, dtype=dtype, smearing=1.0 + r_grid=x_grid, y_grid=y_grid, reciprocal=True, smearing=1.0 ) - - combo = CombinedPotential(potentials=[spline, coulomb], dtype=dtype, smearing=1.0) - mypme = PMECalculator(combo, mesh_spacing=1.0, dtype=dtype) + spline.to(dtype=dtype) + combo = CombinedPotential(potentials=[spline, coulomb], smearing=1.0) + combo.to(dtype=dtype) + mypme = PMECalculator(combo, mesh_spacing=1.0) _ = torch.jit.script(mypme) def test_combined_potential_incompatability(): - coulomb1 = CoulombPotential(smearing=1.0, dtype=dtype) - coulomb2 = CoulombPotential(dtype=dtype) + coulomb1 = CoulombPotential(smearing=1.0) + coulomb2 = CoulombPotential() with pytest.raises( ValueError, match="Cannot combine direct \\(`smearing=None`\\) and range-separated \\(`smearing=float`\\) potentials.", ): - _ = CombinedPotential(potentials=[coulomb1, coulomb2], dtype=dtype) + _ = CombinedPotential(potentials=[coulomb1, coulomb2]) with pytest.raises( ValueError, match="You should specify a `smearing` when combining range-separated \\(`smearing=float`\\) potentials.", ): - _ = CombinedPotential(potentials=[coulomb1, coulomb1], dtype=dtype) + _ = CombinedPotential(potentials=[coulomb1, coulomb1]) with pytest.raises( ValueError, match="Cannot specify `smearing` when combining direct \\(`smearing=None`\\) potentials.", ): - _ = CombinedPotential( - potentials=[coulomb2, coulomb2], smearing=1.0, dtype=dtype - ) + _ = CombinedPotential(potentials=[coulomb2, coulomb2], smearing=1.0) def test_combined_potential_learnable_weights(): weights = torch.randn(2, dtype=dtype) - coulomb1 = CoulombPotential(smearing=2.0, dtype=dtype) - coulomb2 = CoulombPotential(smearing=1.0, dtype=dtype) + coulomb1 = CoulombPotential(smearing=2.0) + coulomb2 = CoulombPotential(smearing=1.0) combined = CombinedPotential( potentials=[coulomb1, coulomb2], smearing=1.0, - dtype=dtype, initial_weights=weights.clone(), learnable_weights=True, ) + combined.to(dtype=dtype) assert combined.weights.requires_grad # make a small optimization step @@ -587,17 +591,16 @@ def test_potential_device_dtype(potential_class, device, dtype): exponent = 2 if potential_class is InversePowerLawPotential: - potential = potential_class( - exponent=exponent, smearing=smearing, dtype=dtype, device=device - ) + potential = potential_class(exponent=exponent, smearing=smearing) + potential.to(device=device, dtype=dtype) elif potential_class is SplinePotential: x_grid = torch.linspace(0, 20, 100, device=device, dtype=dtype) y_grid = torch.exp(-(x_grid**2) * 0.5) - potential = potential_class( - r_grid=x_grid, y_grid=y_grid, reciprocal=False, dtype=dtype, device=device - ) + potential = potential_class(r_grid=x_grid, y_grid=y_grid, reciprocal=False) + potential.to(device=device, dtype=dtype) else: - potential = potential_class(smearing=smearing, dtype=dtype, device=device) + potential = potential_class(smearing=smearing) + potential.to(device=device, dtype=dtype) dists = torch.linspace(0.1, 10.0, 100, device=device, dtype=dtype) potential_lr = potential.lr_from_dist(dists) @@ -616,13 +619,15 @@ def test_inverserp_vs_spline(exponent, smearing): ks_sq_grad1 = ks_sq.clone().requires_grad_(True) ks_sq_grad2 = ks_sq.clone().requires_grad_(True) # Create InversePowerLawPotential - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) ipl_fourier = ipl.lr_from_k_sq(ks_sq_grad1) # Create PotentialSpline r_grid = torch.logspace(-5, 2, 1000, dtype=dtype) y_grid = ipl.lr_from_dist(r_grid) - spline = SplinePotential(r_grid=r_grid, y_grid=y_grid, dtype=dtype) + spline = SplinePotential(r_grid=r_grid, y_grid=y_grid) + spline.to(dtype=dtype) spline_fourier = spline.lr_from_k_sq(ks_sq_grad2) # Test agreement between InversePowerLawPotential and SplinePotential diff --git a/tests/tuning/test_timer.py b/tests/tuning/test_timer.py index 4481d2fa2..2100e877f 100644 --- a/tests/tuning/test_timer.py +++ b/tests/tuning/test_timer.py @@ -41,7 +41,6 @@ def test_timer(): calculator = EwaldCalculator( potential=CoulombPotential(smearing=1.0), lr_wavelength=0.25, - dtype=DTYPE, ) timing_1 = TuningTimings( @@ -50,7 +49,6 @@ def test_timer(): positions=pos, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, - dtype=DTYPE, n_repeat=n_repeat_1, ) @@ -60,7 +58,6 @@ def test_timer(): positions=pos, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, - dtype=DTYPE, n_repeat=n_repeat_2, ) diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py index cabea6d98..07cce237c 100644 --- a/tests/tuning/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -10,7 +10,6 @@ P3MCalculator, PMECalculator, ) -from torchpme._utils import _get_device, _get_dtype from torchpme.tuning import tune_ewald, tune_p3m, tune_pme from torchpme.tuning.tuner import TunerBase @@ -23,9 +22,6 @@ def system(device=None, dtype=None): - device = _get_device(device) - dtype = _get_dtype(dtype) - charges = torch.ones((4, 1), dtype=dtype, device=device) cell = torch.eye(3, dtype=dtype, device=device) positions = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3)) @@ -50,8 +46,6 @@ def test_TunerBase_init(device, dtype): cutoff=DEFAULT_CUTOFF, calculator=1.0, exponent=1, - dtype=dtype, - device=device, ) @@ -89,19 +83,16 @@ def test_parameter_choose(device, dtype, calculator, tune, param_length, accurac neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, accuracy=accuracy, - dtype=dtype, - device=device, ) assert len(params) == param_length # Compute potential and compare against target value using default hypers calc = calculator( - potential=(CoulombPotential(smearing=smearing, dtype=dtype, device=device)), - dtype=dtype, - device=device, + potential=(CoulombPotential(smearing=smearing)), **params, ) + calc.to(device=device, dtype=dtype) potentials = calc.forward( positions=pos, charges=charges, @@ -212,9 +203,7 @@ def test_invalid_cell(tune): @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_dtype_cell(tune): charges, _, positions = system() - match = ( - r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)" - ) + match = r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)" with pytest.raises(TypeError, match=match): tune( charges=charges,