diff --git a/src/meshlode/calculators.py b/src/meshlode/calculators.py index 4181d07a..89ec3b4f 100644 --- a/src/meshlode/calculators.py +++ b/src/meshlode/calculators.py @@ -9,37 +9,42 @@ Our calculator API follows the `rascaline `_ API and coding guidelines to promote usability and interoperability with existing workflows. """ -from typing import List, Union, Tuple, Dict +from typing import Dict, List, Union -# import numpy as np import torch -from torch import Tensor from metatensor.torch import Labels, TensorBlock, TensorMap -from meshlode.mesh_interpolator import MeshInterpolator from meshlode.fourier_convolution import FourierSpaceConvolution -from .system import System +from meshlode.mesh_interpolator import MeshInterpolator +from meshlode.system import System + def my_1d_tolist(x: torch.Tensor): + """Auxilary function to convert torch tensor to list of integers""" result: List[int] = [] for i in x: result.append(i.item()) return result + class MeshPotential(torch.nn.Module): """A species wise long range potential. :param atomic_gaussian_width: Width of the atom-centered gaussian used to create the atomic density. + :type atomic_gaussian_width: float :param mesh_spacing: Value that determines the umber of Fourier-space grid points that will be used along each axis. - :param interpolation_order: Interpolation order for mapping onto the grid. - ``4`` equals cubic interpolation. + :type mesh_spacing: float + :param interpolation_order: Interpolation order for mapping onto the grid, where an + interpolation order of p corresponds to interpolation by a polynomial of degree + p-1 (e.g. p=4 for cubic interpolation). + :type interpolation_order: int Example ------- - >>> calculator = MeshPotential(atomic_gaussian_width=1) + >>> calculator = MeshPotential(atomic_gaussian_width=1.0) """ @@ -61,37 +66,40 @@ def __init__( # infrastructure, which require a "forward" function. We name this function # "compute" instead, for compatibility with other COSMO software. def forward( - self, - systems: Union[List[System],System], - ) -> TensorMap: - """forward just calls :py:meth:`CalculatorModule.compute`""" - res = self.compute(frames=systems) - return res - # return 0. - - #@torch.jit.export + self, + systems: Union[List[System], System], + ) -> TensorMap: + """forward just calls :py:meth:`CalculatorModule.compute`""" + res = self.compute(frames=systems) + return res + # return 0. + + # @torch.jit.export def compute( self, - frames: Union[List[System],System], + frames: Union[List[System], System], subtract_self: bool = False, - ) -> TensorMap: - """Runs a calculation with this calculator on the given ``systems``. - TODO update this - :param systems: single system or list of systems on which to run the + ) -> TensorMap: + """Compute the potential at the position of each atom for all Systems provided + in "frames". + + :param frames: single System or list of Systems on which to run the calculation. If any of the systems' ``positions`` or ``cell`` has ``requires_grad`` set to :py:obj:`True`, then the corresponding gradients are computed and registered as a custom node in the computational graph, to allow backward propagation of the gradients later. - :param gradients: List of forward gradients to keep in the output. If this is - :py:obj:`None` or an empty list ``[]``, no gradients are kept in the output. - Some gradients might still be computed at runtime to allow for backward - propagation. + :param subtract_self: bool. 
If set to true, subtract from the features of an + atom i the contributions to the potential arising from the "center" atom itself + (but not the periodic images). + + :return: TensorMap containing the potential of all atoms. The keys of the + tensormap are "species_center" and "species_neighbor". """ # Make sure that the compute function also works if only a single frame is # provided as input (for convenience of users testing out the code) if not isinstance(frames, list): frames = [frames] - + # Generate a dictionary to map atomic species to array indices # In general, the species are sorted according to atomic number # and assigned the array indices 0, 1, 2,... @@ -107,7 +115,7 @@ def compute( # Initialize dictionary for sparse storage of the features n_species_sq = n_species * n_species - feat_dic: Dict[int, List[torch.Tensor]] = {a:[] for a in range(n_species_sq)} + feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_species_sq)} for frame in frames: # One-hot encoding of charge information @@ -115,17 +123,20 @@ def compute( species = frame.species charges = torch.zeros((n_atoms, n_species), dtype=torch.float) for i_specie, atomic_number in enumerate(atomic_numbers): - charges[species == atomic_number, i_specie] = 1. + charges[species == atomic_number, i_specie] = 1.0 # Compute the potentials - potential = self._compute_single_frame(frame.cell, frame.positions, - charges, subtract_self) - + potential = self._compute_single_frame( + frame.cell, frame.positions, charges, subtract_self + ) + # Reorder data into Metatensor format for spec_center, at_num_center in enumerate(atomic_numbers): for spec_neighbor in range(len(atomic_numbers)): a_pair = spec_center * n_species + spec_neighbor - feat_dic[a_pair] += [potential[species==at_num_center,spec_neighbor]] + feat_dic[a_pair] += [ + potential[species == at_num_center, spec_neighbor] + ] # Assemble all computed potential values into TensorBlocks for each combination # of species_center and species_neighbor @@ -142,18 +153,18 @@ def compute( samples_vals.append([i_frame, i_atom]) samples_vals_tensor = torch.tensor((samples_vals), dtype=torch.int32) labels_samples = Labels(["structure", "center"], samples_vals_tensor) - + labels_properties = Labels(["potential"], torch.tensor([[0]])) - + block = TensorBlock( - samples=labels_samples, - components=[], - properties=labels_properties, - values=torch.hstack(values).reshape((-1,1)), + samples=labels_samples, + components=[], + properties=labels_properties, + values=torch.hstack(values).reshape((-1, 1)), ) blocks.append(block) - + # Generate TensorMap from TensorBlocks by defining suitable keys key_values: List[torch.Tensor] = [] for spec_center in atomic_numbers: @@ -162,12 +173,15 @@ def compute( key_values = torch.vstack(key_values) labels_keys = Labels(["species_center", "species_neighbor"], key_values) - return TensorMap(keys=labels_keys, blocks=blocks) + return TensorMap(keys=labels_keys, blocks=blocks) - - def _compute_single_frame(self, cell: torch.Tensor, - positions: torch.Tensor, charges: torch.Tensor, - subtract_self: bool = False ) -> torch.Tensor: + def _compute_single_frame( + self, + cell: torch.Tensor, + positions: torch.Tensor, + charges: torch.Tensor, + subtract_self: bool = False, + ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a structure. @@ -177,6 +191,7 @@ def _compute_single_frame(self, cell: torch.Tensor, :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian coordinates of the atoms. 
The implementation also works if the positions are not contained within the unit cell. + :param charges: torch.tensor of shape (n_atoms, n_channels). In the simplest case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the charge of atom i. More generally, the potential for the same atom positions @@ -190,6 +205,7 @@ def _compute_single_frame(self, cell: torch.Tensor, "Cl" potential. Subtracting these from each other, one could recover the more standard electrostatic potential in which Na and Cl have charges of +1 and -1, respectively. + :param subtract_self: bool. If set to true, the contribution to the potential of the center atom itself is subtracted, meaning that only the potential generated by the remaining atoms + periodic images of the center atom is taken into @@ -205,7 +221,7 @@ def _compute_single_frame(self, cell: torch.Tensor, # Initializations n_atoms = len(positions) n_channels = charges.shape[1] - assert positions.shape == (n_atoms,3) + assert positions.shape == (n_atoms, 3) assert charges.shape[0] == n_atoms # Define k-vectors @@ -219,17 +235,17 @@ def _compute_single_frame(self, cell: torch.Tensor, # is reached basis_norms = torch.linalg.norm(cell, dim=1) ns_approx = k_cutoff * basis_norms / 2 / torch.pi - ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points - ns = 2**torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] + ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points + ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] # Step 1: Smear particles onto mesh MI = MeshInterpolator(cell, ns, interpolation_order=interpolation_order) MI.compute_interpolation_weights(positions) rho_mesh = MI.points_to_mesh(particle_weights=charges) - + # Step 2: Perform Fourier space convolution (FSC) FSC = FourierSpaceConvolution(cell) - value_at_origin = 0. # charge neutrality + value_at_origin = 0.0 # charge neutrality potential_mesh = FSC.compute(rho_mesh, potential_exponent=1, smearing=smearing) # Step 3: Back interpolation @@ -237,7 +253,7 @@ def _compute_single_frame(self, cell: torch.Tensor, # Remove self contribution if subtract_self: - self_contrib = torch.sqrt(torch.tensor(2./torch.pi)) / smearing + self_contrib = torch.sqrt(torch.tensor(2.0 / torch.pi)) / smearing interpolated_potential -= charges * self_contrib - return interpolated_potential \ No newline at end of file + return interpolated_potential diff --git a/src/meshlode/fourier_convolution.py b/src/meshlode/fourier_convolution.py index feb83a78..14fde2a0 100644 --- a/src/meshlode/fourier_convolution.py +++ b/src/meshlode/fourier_convolution.py @@ -2,9 +2,9 @@ Fourier Convolution =================== """ -from typing import Callable import torch + class FourierSpaceConvolution: """ Class for handling all the steps necessary to compute the convolution f*G between @@ -13,6 +13,7 @@ class FourierSpaceConvolution: :param cell: torch.tensor of shape (3,3) Tensor specifying the real space unit cell of a structure, where cell[i] is the i-th basis vector """ + def __init__(self, cell: torch.Tensor): self.cell: torch.Tensor = cell @@ -20,7 +21,7 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor: """ For a given unit cell, compute all reciprocal space vectors that are used to perform sums in the Fourier transformed space. 
- + :param cell: torch.tensor of shape (3,3) Tensor specifying the real space unit cell of a structure, where cell[i] is the i-th basis vector @@ -29,71 +30,76 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor: z-direction, respectively. For faster performance during the Fast Fourier Transform (FFT) it is recommended to use values of nx, ny and nz that are powers of 2. - - :returns: torch.tensor of shape [N_k,3] Contains all reciprocal space vectors + + :return: torch.tensor of shape [N_k,3] Contains all reciprocal space vectors that will be used during Ewald summation (or related approaches). The number N_k of such vectors is given by N_k = nx * ny * nz. k_vectors[i] contains the i-th vector, where the order has no special significance. """ # Define basis vectors of the reciprocal cell - reciprocal_cell = 2*torch.pi*self.cell.inverse().T + reciprocal_cell = 2 * torch.pi * self.cell.inverse().T bx = reciprocal_cell[0] by = reciprocal_cell[1] bz = reciprocal_cell[2] - + # Generate all reciprocal space vectors nxs_1d = torch.fft.fftfreq(ns[0]) * ns[0] nys_1d = torch.fft.fftfreq(ns[1]) * ns[1] - nzs_1d = torch.fft.rfftfreq(ns[2]) * ns[2] # real FFT - nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing='ij') - nxs = nxs.reshape((int(ns[0]),int(ns[1]),len(nzs_1d),1)) - nys = nys.reshape((int(ns[0]),int(ns[1]),len(nzs_1d),1)) - nzs = nzs.reshape((int(ns[0]),int(ns[1]),len(nzs_1d),1)) + nzs_1d = torch.fft.rfftfreq(ns[2]) * ns[2] # real FFT + nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") + nxs = nxs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) + nys = nys.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) + nzs = nzs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) k_vectors = nxs * bx + nys * by + nzs * bz - + return k_vectors - def kernel_func(self, ksq : torch.Tensor, - potential_exponent: int = 1, - smearing: float = 0.2) -> torch.Tensor: + def kernel_func( + self, ksq: torch.Tensor, potential_exponent: int = 1, smearing: float = 0.2 + ) -> torch.Tensor: """ Fourier transform of the Coulomb potential or more general effective 1/r**p potentials with additional smearing to remove the singularity at the origin. - + :param ksq: torch.tensor of shape (N_k,) Squared norm of the k-vectors :param potential_exponent: Exponent of the effective 1/r**p decay :param smearing: Broadening of the 1/r**p decay close to the origin - :returns: torch.tensor of shape (N_k,) with the values of the kernel function - G(k) evaluated at the provided (squared norms of the) k-vectors + :return: torch.tensor of shape (N_k,) with the values of the kernel function + G(k) evaluated at the provided (squared norms of the) k-vectors """ if potential_exponent == 1: - return 4*torch.pi / ksq * torch.exp(-0.5*smearing**2*ksq) + return 4 * torch.pi / ksq * torch.exp(-0.5 * smearing**2 * ksq) elif potential_exponent == 0: return torch.ones_like(ksq) else: - raise ValueError('Only potential exponents 0 and 1 are supported') + raise ValueError("Only potential exponents 0 and 1 are supported") - def value_at_origin(self, potential_exponent: int = 1, smearing: float = 0.2) -> float: + def value_at_origin( + self, potential_exponent: int = 1, smearing: float = 0.2 + ) -> float: """ Since the kernel function in reciprocal space typically has a (removable) singularity at k=0, the value at that point needs to be specified explicitly. 
- + :param potential_exponent: Exponent of the effective 1/r**p decay :param smearing: Broadening of the 1/r**p decay close to the origin - :returns: float of G(k=0), the value of the kernel function at the origin. + :return: float of G(k=0), the value of the kernel function at the origin. """ - if potential_exponent in [1,2,3]: - return 0. + if potential_exponent in [1, 2, 3]: + return 0.0 elif potential_exponent == 0: - return 1. + return 1.0 else: - raise ValueError('Only potential exponents 0 and 1 are supported') + raise ValueError("Only potential exponents 0 and 1 are supported") - def compute(self, mesh_values: torch.Tensor, - potential_exponent: int = 1, - smearing: float = 0.2) -> torch.Tensor: + def compute( + self, + mesh_values: torch.Tensor, + potential_exponent: int = 1, + smearing: float = 0.2, + ) -> torch.Tensor: """ Compute the "electrostatic potential" from the density defined on a discrete mesh. @@ -113,21 +119,25 @@ def compute(self, mesh_values: torch.Tensor, # Get shape information from mesh n_channels, nx, ny, nz = mesh_values.shape ns = torch.tensor([nx, ny, nz]) - + # Get the relevant reciprocal space vectors (k-vectors) # and compute their norm. kvectors = self.generate_kvectors(ns) knorm_sq = torch.sum(kvectors**2, dim=3) - + # G(k) is the Fourier transform of the Coulomb potential # generated by a Gaussian charge density # We remove the singularity at k=0 by explicitly setting its # value to be equal to zero. This mathematically corresponds # to the requirement that the net charge of the cell is zero. - #G = kernel_func(knorm_sq) - G = self.kernel_func(knorm_sq, potential_exponent=potential_exponent, smearing=smearing) - G[0,0,0] = self.value_at_origin(potential_exponent=potential_exponent, smearing=smearing) - + # G = kernel_func(knorm_sq) + G = self.kernel_func( + knorm_sq, potential_exponent=potential_exponent, smearing=smearing + ) + G[0, 0, 0] = self.value_at_origin( + potential_exponent=potential_exponent, smearing=smearing + ) + # Fourier transforms consisting of the following substeps: # 1. Fourier transform the density # 2. multiply by kernel in k-space @@ -136,12 +146,14 @@ def compute(self, mesh_values: torch.Tensor, # that do not introduce any extra factors of 1/n_mesh. # This is why the forward transform (fft) is called with the # normalization option 'backward' (the convention in which 1/n_mesh - # is in the backward transformation) and vice versa for the + # is in the backward transformation) and vice versa for the # inverse transform (irfft). volume = self.cell.det() - dims = (1,2,3) # dimensions along which to Fourier transform - mesh_hat = torch.fft.rfftn(mesh_values, norm='backward', dim=dims) + dims = (1, 2, 3) # dimensions along which to Fourier transform + mesh_hat = torch.fft.rfftn(mesh_values, norm="backward", dim=dims) potential_hat = mesh_hat * G - potential_mesh = torch.fft.irfftn(potential_hat, norm='forward', dim=dims) / volume + potential_mesh = ( + torch.fft.irfftn(potential_hat, norm="forward", dim=dims) / volume + ) - return potential_mesh \ No newline at end of file + return potential_mesh diff --git a/src/meshlode/mesh_interpolator.py b/src/meshlode/mesh_interpolator.py index 2d866c98..e9daeed7 100644 --- a/src/meshlode/mesh_interpolator.py +++ b/src/meshlode/mesh_interpolator.py @@ -2,7 +2,6 @@ Mesh Interpolator ================= """ - import torch @@ -31,27 +30,24 @@ class MeshInterpolator: to smoother interpolation, at a computational cost that grows cubically with the interpolation order (once one moves to the 3D case). 
""" + def __init__( self, cell: torch.Tensor, ns_mesh: torch.Tensor, interpolation_order: int ): - self.cell = cell self.ns_mesh = ns_mesh self.interpolation_order = interpolation_order # Initialize the variables in which to store the intermediate # interpolation nodes and weights - self.interpolation_weights: torch.Tensor = torch.tensor(0.) + self.interpolation_weights: torch.Tensor = torch.tensor(0.0) self.x_shifts: torch.Tensor = torch.tensor(0) self.y_shifts: torch.Tensor = torch.tensor(0) self.z_shifts: torch.Tensor = torch.tensor(0) self.x_indices: torch.Tensor = torch.tensor(0) self.y_indices: torch.Tensor = torch.tensor(0) self.z_indices: torch.Tensor = torch.tensor(0) - - - def compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor: """ Generate the smooth interpolation weights used to smear the particles onto a @@ -64,7 +60,7 @@ def compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor: :param x: torch.tensor of shape (n,) Set of relative positions in the interval [-1/2, 1/2]. - :returns: torch.tensor of shape (interpolation_order, n) + :return: torch.tensor of shape (interpolation_order, n) Interpolation weights """ # Compute weights based on the given order @@ -179,7 +175,7 @@ def points_to_mesh(self, particle_weights: torch.Tensor) -> torch.Tensor: the Na and Cl contributions to the potential separately by using a one-hot encoding of the species. - :returns: torch.tensor of shape (n_channels, n_mesh, n_mesh, n_mesh) + :return: torch.tensor of shape (n_channels, n_mesh, n_mesh, n_mesh) Discrete density """ # Update mesh values by combining particle weights and interpolation weights @@ -187,7 +183,7 @@ def points_to_mesh(self, particle_weights: torch.Tensor) -> torch.Tensor: nx = int(self.ns_mesh[0]) ny = int(self.ns_mesh[1]) nz = int(self.ns_mesh[2]) - rho_mesh = torch.zeros((n_channels,nx,ny,nz)) + rho_mesh = torch.zeros((n_channels, nx, ny, nz)) for a in range(n_channels): rho_mesh[a].index_put_( (self.x_indices, self.y_indices, self.z_indices), @@ -216,7 +212,7 @@ def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor: Absolute positions of particles in Cartesian coordinates, onto whose locations we wish to interpolate the mesh values. - :returns: interpolated_values: torch.tensor of shape (n_points, n_channels) + :return: interpolated_values: torch.tensor of shape (n_points, n_channels) Values of the interpolated function. 
""" interpolated_values = ( diff --git a/src/meshlode/system.py b/src/meshlode/system.py index 201c96f1..78677318 100644 --- a/src/meshlode/system.py +++ b/src/meshlode/system.py @@ -46,7 +46,7 @@ def cell(self) -> torch.Tensor: """ return self._cell - + def __len__(self) -> int: """ Return the number of atoms diff --git a/tests/test_calculators.py b/tests/test_calculators.py index 8dc46a21..83929f0e 100644 --- a/tests/test_calculators.py +++ b/tests/test_calculators.py @@ -1,27 +1,30 @@ +from typing import List + +import pytest import torch +from metatensor.torch import Labels, TensorBlock, TensorMap from packaging import version -from typing import List -from meshlode import calculators -from meshlode import MeshPotential + +from meshlode import MeshPotential, calculators from meshlode.system import System -import pytest -from metatensor.torch import TensorMap, TensorBlock, Labels # Define toy system consisting of a single structure for testing def toy_system_single_frame() -> System: return System( species=torch.tensor([1, 1, 8, 8]), positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]), - cell=torch.tensor([[10., 0, 0], [0, 10, 0], [0, 0, 10]]), + cell=torch.tensor([[10.0, 0, 0], [0, 10, 0], [0, 0, 10]]), ) + # Initialize the calculators. For now, only the MeshPotential is implemented. def descriptor() -> MeshPotential: return MeshPotential( - atomic_gaussian_width=1., + atomic_gaussian_width=1.0, ) + # Make sure that the calculators are computing the features without raising errors, # and returns the correct output format (TensorMap) def check_operation(calculator): @@ -30,10 +33,12 @@ def check_operation(calculator): if version.parse(torch.__version__) >= version.parse("2.1"): assert descriptor._type().name() == "TensorMap" + # Run the above test as a normal python script def test_operation_as_python(): check_operation(descriptor()) + # Similar to the above, but also testing that the code can be compiled as a torch script def test_operation_as_torch_script(): scripted = torch.jit.script(descriptor()) @@ -43,19 +48,40 @@ def test_operation_as_torch_script(): # Define a more complex toy system consisting of multiple frames, mixing three species. def toy_system_2() -> List[System]: # First few frames containing Nitrogen - L = 2. 
+ L = 2.0 frames = [] - frames.append(System(species=torch.tensor([7]), positions=torch.zeros((1,3)), cell=L*2*torch.eye(3))) - frames.append(System(species=torch.tensor([7,7]), positions=torch.zeros((2,3)), cell=L*2*torch.eye(3))) - frames.append(System(species=torch.tensor([7,7,7]), positions=torch.zeros((3,3)), cell=L*2*torch.eye(3))) + frames.append( + System( + species=torch.tensor([7]), + positions=torch.zeros((1, 3)), + cell=L * 2 * torch.eye(3), + ) + ) + frames.append( + System( + species=torch.tensor([7, 7]), + positions=torch.zeros((2, 3)), + cell=L * 2 * torch.eye(3), + ) + ) + frames.append( + System( + species=torch.tensor([7, 7, 7]), + positions=torch.zeros((3, 3)), + cell=L * 2 * torch.eye(3), + ) + ) # One more frame containing Na and Cl - positions = torch.tensor([[0, 0, 0], [1., 0, 0]]) - cell = torch.tensor([[0, 1., 1], [1, 0, 1], [1, 1, 0]]) - frames.append(System(species=torch.tensor([11,17]), positions=positions, cell=cell)) + positions = torch.tensor([[0, 0, 0], [1.0, 0, 0]]) + cell = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) + frames.append( + System(species=torch.tensor([11, 17]), positions=positions, cell=cell) + ) return frames + class TestMultiFrameToySystem: # Compute TensorMap containing features for various hyperparameters, including more # extreme values. @@ -63,22 +89,35 @@ class TestMultiFrameToySystem: frames = toy_system_2() for atomic_gaussian_width in [0.01, 0.3, 3.7]: for mesh_spacing in [15.3, 0.19]: - for interpolation_order in [1,2,3,4,5]: - MP = MeshPotential(atomic_gaussian_width=atomic_gaussian_width, - mesh_spacing=mesh_spacing, - interpolation_order = interpolation_order) + for interpolation_order in [1, 2, 3, 4, 5]: + MP = MeshPotential( + atomic_gaussian_width=atomic_gaussian_width, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + ) tensormaps_list.append(MP.compute(frames, subtract_self=False)) @pytest.mark.parametrize("features", tensormaps_list) def test_tensormap_labels(self, features): # Test that the keys of the TensorMap for the toy system are correct - label_values = torch.tensor([[7,7],[7,11],[7,17],[11,7],[11,11],[11,17], - [17,7],[17,11],[17,17]]) + label_values = torch.tensor( + [ + [7, 7], + [7, 11], + [7, 17], + [11, 7], + [11, 11], + [11, 17], + [17, 7], + [17, 11], + [17, 17], + ] + ) label_names = ["species_center", "species_neighbor"] - labels_ref = Labels(names = label_names, values = label_values) + labels_ref = Labels(names=label_names, values=label_values) assert labels_ref == features.keys - + @pytest.mark.parametrize("features", tensormaps_list) def test_zero_blocks(self, features): # Since the first 3 frames contain Nitrogen only, while the last frame @@ -86,13 +125,13 @@ def test_zero_blocks(self, features): for i in [11, 17]: # For structures in which Nitrogen is present, there will be no Na or Cl # neighbors. There are six such center atoms in total. - block = features.block({"species_center":7, "species_neighbor":i}) - assert torch.equal(block.values, torch.zeros((6,1))) + block = features.block({"species_center": 7, "species_neighbor": i}) + assert torch.equal(block.values, torch.zeros((6, 1))) # For structures in which Na or Cl are present, there will be no Nitrogen # neighbors. 
- block = features.block({"species_center":i, "species_neighbor":7}) - assert torch.equal(block.values, torch.zeros((1,1))) + block = features.block({"species_center": i, "species_neighbor": 7}) + assert torch.equal(block.values, torch.zeros((1, 1))) @pytest.mark.parametrize("features", tensormaps_list) def test_nitrogen_blocks(self, features): @@ -102,22 +141,26 @@ def test_nitrogen_blocks(self, features): # - the third frame contains three atoms at the origin # Thus, the features should almost be identical, up to a global factor # that is the number of atoms (that are exactly on the same position). - block = features.block({"species_center":7, "species_neighbor":7}) - values = block.values[:,0] # flatten to 1d - values_ref = torch.tensor([1.,2,2,3,3,3]) + block = features.block({"species_center": 7, "species_neighbor": 7}) + values = block.values[:, 0] # flatten to 1d + values_ref = torch.tensor([1.0, 2, 2, 3, 3, 3]) # We use a slightly higher relative tolerance due to numerical errors - torch.testing.assert_close(values / values[0], values_ref, rtol=1e-6, atol=0.) - + torch.testing.assert_close(values / values[0], values_ref, rtol=1e-6, atol=0.0) + @pytest.mark.parametrize("features", tensormaps_list) def test_nacl_blocks(self, features): # In the NaCl structure, swapping the positions of all Na and Cl atoms leads to # an equivalent structure (up to global translation). This leads to symmetry # in the features: the Na-density around Cl is the same as the Cl-density around # Na and so on. - block_nana = features.block({"species_center":11, "species_neighbor":11}) - block_nacl = features.block({"species_center":11, "species_neighbor":17}) - block_clna = features.block({"species_center":17, "species_neighbor":11}) - block_clcl = features.block({"species_center":17, "species_neighbor":17}) - torch.testing.assert_close(block_nacl.values, block_clna.values, rtol=1e-15, atol=0.) - torch.testing.assert_close(block_nana.values, block_clcl.values, rtol=1e-15, atol=0.) \ No newline at end of file + block_nana = features.block({"species_center": 11, "species_neighbor": 11}) + block_nacl = features.block({"species_center": 11, "species_neighbor": 17}) + block_clna = features.block({"species_center": 17, "species_neighbor": 11}) + block_clcl = features.block({"species_center": 17, "species_neighbor": 17}) + torch.testing.assert_close( + block_nacl.values, block_clna.values, rtol=1e-15, atol=0.0 + ) + torch.testing.assert_close( + block_nana.values, block_clcl.values, rtol=1e-15, atol=0.0 + ) diff --git a/tests/test_fourier_convolution.py b/tests/test_fourier_convolution.py index 1d1fb440..3c174500 100644 --- a/tests/test_fourier_convolution.py +++ b/tests/test_fourier_convolution.py @@ -1,26 +1,29 @@ """ Tests for Fourier space convolution class """ +import pytest import torch from torch.testing import assert_close -import pytest + from meshlode.fourier_convolution import FourierSpaceConvolution + class TestKvectorGeneration: """ Tests for the subroutine that generates all reciprocal space vectors. """ + cells = [] for i in range(6): - L = torch.rand((1,)) * 20 + 1. 
- cells.append(torch.randn((3,3)) * L) + L = torch.rand((1,)) * 20 + 1.0 + cells.append(torch.randn((3, 3)) * L) ns_list = [] for i in range(6): - ns_list.append(torch.randint(1,20,size=(3,))) + ns_list.append(torch.randint(1, 20, size=(3,))) - @pytest.mark.parametrize('ns', ns_list) - @pytest.mark.parametrize('cell', cells) + @pytest.mark.parametrize("ns", ns_list) + @pytest.mark.parametrize("cell", cells) def test_duality_of_kvectors(self, cell, ns): """ If a_j for j=1,2,3 are the three basis vectors of a unit cell and @@ -35,20 +38,22 @@ def test_duality_of_kvectors(self, cell, ns): # Define frequencies with the same convention as in FFT # This is essentially a manual implementation of torch.fft.fftfreq ix_refs = torch.arange(nx) - ix_refs[ix_refs>=(nx+1)//2] -= nx + ix_refs[ix_refs >= (nx + 1) // 2] -= nx iy_refs = torch.arange(ny) - iy_refs[iy_refs>=(ny+1)//2] -= ny + iy_refs[iy_refs >= (ny + 1) // 2] -= ny - for ix in range(nx): + for ix in range(nx): for iy in range(ny): - for iz in range((nz+1) // 2): - inner_prods = torch.matmul(cell, kvectors[ix,iy,iz]) / 2 / torch.pi + for iz in range((nz + 1) // 2): + inner_prods = ( + torch.matmul(cell, kvectors[ix, iy, iz]) / 2 / torch.pi + ) inner_prods = torch.round(inner_prods) - inner_prods_ref = torch.tensor([ix_refs[ix],iy_refs[iy],iz]) * 1. - assert_close(inner_prods, inner_prods_ref, atol=1e-15, rtol=0.) - - @pytest.mark.parametrize('ns', ns_list) - @pytest.mark.parametrize('cell', cells) + inner_prods_ref = torch.tensor([ix_refs[ix], iy_refs[iy], iz]) * 1.0 + assert_close(inner_prods, inner_prods_ref, atol=1e-15, rtol=0.0) + + @pytest.mark.parametrize("ns", ns_list) + @pytest.mark.parametrize("cell", cells) def test_lenghts_of_kvectors(self, cell, ns): """ Check that the lengths of the obtained kvectors satisfy the triangle @@ -56,7 +61,7 @@ def test_lenghts_of_kvectors(self, cell, ns): """ # Compute an upper bound for the norms of the kvectors # that should be obtained - reciprocal_cell = 2*torch.pi*cell.inverse().T + reciprocal_cell = 2 * torch.pi * cell.inverse().T norms_basisvecs = torch.linalg.norm(reciprocal_cell, dim=1) norm_bound = torch.sum(norms_basisvecs * ns) @@ -66,29 +71,31 @@ def test_lenghts_of_kvectors(self, cell, ns): norms_all = torch.linalg.norm(kvectors, dim=3).flatten() assert torch.all(norms_all < norm_bound) + class TestConvolution: """ Test the subroutine that performs the actual convolution in reciprocal space """ + cells = [] for i in range(6): - L = torch.rand((1,)) * 20 + 1. 
- cells.append(torch.randn((3,3)) * L) + L = torch.rand((1,)) * 20 + 1.0 + cells.append(torch.randn((3, 3)) * L) mesh_vals_list = [] for i in range(6): - ns = torch.randint(1,20,size=(4,)) + ns = torch.randint(1, 20, size=(4,)) n_channels, nx, ny, nz = ns - nz *= 2 # for now, last dimension needs to be even - mesh_vals_list.append(torch.randn(size=(n_channels, nx,ny,nz))) + nz *= 2 # for now, last dimension needs to be even + mesh_vals_list.append(torch.randn(size=(n_channels, nx, ny, nz))) - @pytest.mark.parametrize('mesh_vals', mesh_vals_list) - @pytest.mark.parametrize('cell', cells) + @pytest.mark.parametrize("mesh_vals", mesh_vals_list) + @pytest.mark.parametrize("cell", cells) def test_convolution_for_delta(self, cell, mesh_vals): volume = cell.det() n_channels, nx, ny, nz = mesh_vals.shape - n_fft = nx*ny*nz + n_fft = nx * ny * nz FSC = FourierSpaceConvolution(cell) mesh_vals_new = FSC.compute(mesh_vals, potential_exponent=0) * volume / n_fft - - assert_close(mesh_vals, mesh_vals_new, rtol=1e-4, atol=1e-6) \ No newline at end of file + + assert_close(mesh_vals, mesh_vals_new, rtol=1e-4, atol=1e-6) diff --git a/tests/test_madelung.py b/tests/test_madelung.py index e247741a..1377cb8d 100644 --- a/tests/test_madelung.py +++ b/tests/test_madelung.py @@ -1,12 +1,13 @@ """ Madelung tests """ -import torch import pytest +import torch from torch.testing import assert_close -from meshlode.system import System from meshlode import MeshPotential +from meshlode.system import System + class TestMadelung: """ @@ -14,9 +15,11 @@ class TestMadelung: of the structures. We thus compare the computed potential against the known exact values for some simple crystal structures. """ + scaling_factors = torch.tensor([0.5, 1.2, 3.3]) crystal_list = ["NaCl", "CsCl", "ZnS", "ZnSO4"] crystal_list_powers_of_2 = ["NaCl", "CsCl", "ZnS"] + @pytest.fixture def crystal_dictionary(self): """ @@ -26,22 +29,22 @@ def crystal_dictionary(self): by Ashcroft and Mermin. Note: Symbols and charges keys have to be sorted according to their - atomic number in ascending alternating order! For an example see + atomic number in ascending alternating order! For an example see ZnS04 in the wurtzite structure. """ # Initialize dictionary for crystal paramaters d = {k: {} for k in self.crystal_list} SQRT3 = torch.sqrt(torch.tensor(3)) - + # NaCl structure # Using a primitive unit cell, the distance between the # closest Na-Cl pair is exactly 1. The cubic unit cell # in these units would have a length of 2. - d["NaCl"]["symbols"] = ['Na', 'Cl'] + d["NaCl"]["symbols"] = ["Na", "Cl"] d["NaCl"]["atomic_numbers"] = torch.tensor([11, 17]) - d["NaCl"]["charges"] = torch.tensor([[1., -1]]).T - d["NaCl"]["positions"] = torch.tensor([[0, 0, 0], [1., 0, 0]]) - d["NaCl"]["cell"] = torch.tensor([[0, 1., 1], [1, 0, 1], [1, 1, 0]]) + d["NaCl"]["charges"] = torch.tensor([[1.0, -1]]).T + d["NaCl"]["positions"] = torch.tensor([[0, 0, 0], [1.0, 0, 0]]) + d["NaCl"]["cell"] = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) d["NaCl"]["madelung"] = 1.7476 # CsCl structure @@ -51,12 +54,11 @@ def crystal_dictionary(self): # the Madelung constant by this value to match the reference. 
d["CsCl"]["symbols"] = ["Cs", "Cl"] d["CsCl"]["atomic_numbers"] = torch.tensor([55, 17]) - d["CsCl"]["charges"] = torch.tensor([[1., -1]]).T - d["CsCl"]["positions"] = torch.tensor([[0, 0, 0], [.5, .5, .5]]) + d["CsCl"]["charges"] = torch.tensor([[1.0, -1]]).T + d["CsCl"]["positions"] = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) d["CsCl"]["cell"] = torch.eye(3) d["CsCl"]["madelung"] = 2 * 1.7626 / SQRT3 - # ZnS (zincblende) structure # As for NaCl, a primitive unit cell is used which makes # the lattice parameter of the cubic cell equal to 2. @@ -66,9 +68,9 @@ def crystal_dictionary(self): # the cubic cell equal to 1, the Zn-S distance is sqrt(3)/4. d["ZnS"]["symbols"] = ["S", "Zn"] d["ZnS"]["atomic_numbers"] = torch.tensor([16, 30]) - d["ZnS"]["charges"] = torch.tensor([[1., -1]]).T - d["ZnS"]["positions"] = torch.tensor([[0, 0, 0], [.5, .5, .5]]) - d["ZnS"]["cell"] = torch.tensor([[0, 1., 1], [1, 0, 1], [1, 1, 0]]) + d["ZnS"]["charges"] = torch.tensor([[1.0, -1]]).T + d["ZnS"]["positions"] = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + d["ZnS"]["cell"] = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) d["ZnS"]["madelung"] = 2 * 1.6381 / SQRT3 # ZnS (O4) in wurtzite structure (triclinic cell) @@ -76,109 +78,133 @@ def crystal_dictionary(self): c = torch.sqrt(1 / u) d["ZnSO4"]["symbols"] = ["S", "Zn", "S", "Zn"] d["ZnSO4"]["atomic_numbers"] = torch.tensor([16, 30, 16, 30]) - d["ZnSO4"]["charges"] = torch.tensor([[1., -1, 1, -1]]).T - d["ZnSO4"]["positions"] = torch.tensor([[.5, .5 / SQRT3, 0.], - [.5, .5 / SQRT3, u * c], - [.5, -.5 / SQRT3, 0.5 * c], - [.5, -.5 / SQRT3, (.5 + u) * c]]) - d["ZnSO4"]["cell"] = torch.tensor([[.5, -0.5 * SQRT3, 0], - [.5, .5 * SQRT3, 0], - [0, 0, c]]) - + d["ZnSO4"]["charges"] = torch.tensor([[1.0, -1, 1, -1]]).T + d["ZnSO4"]["positions"] = torch.tensor( + [ + [0.5, 0.5 / SQRT3, 0.0], + [0.5, 0.5 / SQRT3, u * c], + [0.5, -0.5 / SQRT3, 0.5 * c], + [0.5, -0.5 / SQRT3, (0.5 + u) * c], + ] + ) + d["ZnSO4"]["cell"] = torch.tensor( + [[0.5, -0.5 * SQRT3, 0], [0.5, 0.5 * SQRT3, 0], [0, 0, c]] + ) + d["ZnSO4"]["madelung"] = 1.6413 / (u * c) return d @pytest.mark.parametrize("crystal_name", crystal_list_powers_of_2) @pytest.mark.parametrize("smearing", [0.1, 0.05]) - @pytest.mark.parametrize("interpolation_order", [1,2]) + @pytest.mark.parametrize("interpolation_order", [1, 2]) @pytest.mark.parametrize("scaling_factor", scaling_factors) - def test_madelung_low_order(self, crystal_dictionary, crystal_name, smearing, - scaling_factor, interpolation_order): + def test_madelung_low_order( + self, + crystal_dictionary, + crystal_name, + smearing, + scaling_factor, + interpolation_order, + ): """ For low interpolation orders, if the atoms already lie exactly on a mesh point, there are no additional errors due to smearing the charges. Thus, we can reach a relatively high accuracy. 
""" dic = crystal_dictionary[crystal_name] - positions = dic['positions'] * scaling_factor - cell = dic['cell'] * scaling_factor - charges = dic['charges'] - madelung = dic['madelung'] / scaling_factor + positions = dic["positions"] * scaling_factor + cell = dic["cell"] * scaling_factor + charges = dic["charges"] + madelung = dic["madelung"] / scaling_factor mesh_spacing = smearing / 2 * scaling_factor smearing_eff = smearing * scaling_factor n_atoms = len(positions) MP = MeshPotential(smearing_eff, mesh_spacing, interpolation_order) - potentials_mesh = MP._compute_single_frame(cell, positions, charges, - subtract_self=True) + potentials_mesh = MP._compute_single_frame( + cell, positions, charges, subtract_self=True + ) energies = potentials_mesh * charges - energies_target = -torch.ones_like(energies)*madelung + energies_target = -torch.ones_like(energies) * madelung assert_close(energies, energies_target, rtol=1e-4, atol=1e-6) - @pytest.mark.parametrize("crystal_name", crystal_list) @pytest.mark.parametrize("smearing", [0.2, 0.12]) - @pytest.mark.parametrize("interpolation_order", [3,4,5]) + @pytest.mark.parametrize("interpolation_order", [3, 4, 5]) @pytest.mark.parametrize("scaling_factor", scaling_factors) - def test_madelung_high_order(self, crystal_dictionary, crystal_name, smearing, - scaling_factor, interpolation_order): + def test_madelung_high_order( + self, + crystal_dictionary, + crystal_name, + smearing, + scaling_factor, + interpolation_order, + ): """ For high interpolation order, the current naive implementation used to subtract - the center contribution introduces additional errors since an atom is smeared + the center contribution introduces additional errors since an atom is smeared onto multiple mesh points, turning the short-range correction into a more complicated expression that has not yet been implemented. Thus, we use a much larger tolerance of 1e-2 for the precision needed in the calculation. 
""" dic = crystal_dictionary[crystal_name] - positions = dic['positions'] * scaling_factor - cell = dic['cell'] * scaling_factor - charges = dic['charges'] - madelung = dic['madelung'] / scaling_factor + positions = dic["positions"] * scaling_factor + cell = dic["cell"] * scaling_factor + charges = dic["charges"] + madelung = dic["madelung"] / scaling_factor mesh_spacing = smearing / 10 * scaling_factor smearing_eff = smearing * scaling_factor n_atoms = len(positions) MP = MeshPotential(smearing_eff, mesh_spacing, interpolation_order) - potentials_mesh = MP._compute_single_frame(cell, positions, charges, - subtract_self=True) + potentials_mesh = MP._compute_single_frame( + cell, positions, charges, subtract_self=True + ) energies = potentials_mesh * charges - energies_target = -torch.ones_like(energies)*madelung + energies_target = -torch.ones_like(energies) * madelung assert_close(energies, energies_target, rtol=1e-2, atol=1e-3) - @pytest.mark.parametrize("crystal_name", crystal_list_powers_of_2) @pytest.mark.parametrize("smearing", [0.1, 0.05]) - @pytest.mark.parametrize("interpolation_order", [1,2]) + @pytest.mark.parametrize("interpolation_order", [1, 2]) @pytest.mark.parametrize("scaling_factor", scaling_factors) - def test_madelung_low_order_metatensor(self, crystal_dictionary, crystal_name, - smearing, scaling_factor, - interpolation_order): + def test_madelung_low_order_metatensor( + self, + crystal_dictionary, + crystal_name, + smearing, + scaling_factor, + interpolation_order, + ): """ Same test as above but now using the main compute function of the class that is actually facing the user and outputting in metatensor format. """ dic = crystal_dictionary[crystal_name] - positions = dic['positions'] * scaling_factor - cell = dic['cell'] * scaling_factor - atomic_numbers = dic['atomic_numbers'] - charges = dic['charges'] - madelung = dic['madelung'] / scaling_factor + positions = dic["positions"] * scaling_factor + cell = dic["cell"] * scaling_factor + atomic_numbers = dic["atomic_numbers"] + charges = dic["charges"] + madelung = dic["madelung"] / scaling_factor mesh_spacing = smearing / 2 * scaling_factor smearing_eff = smearing * scaling_factor n_atoms = len(positions) frame = System(species=atomic_numbers, positions=positions, cell=cell) - MP = MeshPotential(atomic_gaussian_width=smearing_eff, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order) + MP = MeshPotential( + atomic_gaussian_width=smearing_eff, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + ) potentials_mesh = MP.compute(frame, subtract_self=True) # Compute the actual potential from the features n_species = charges.shape[1] - energies = torch.zeros((n_atoms,1)) + energies = torch.zeros((n_atoms, 1)) for idx_c, c in enumerate(atomic_numbers): for idx_n, n in enumerate(atomic_numbers): - block = potentials_mesh.block({'species_center':int(c), - 'species_neighbor':int(n)}) - energies[idx_c] += charges[idx_c] * charges[idx_n] * block.values[0,0] - - energies_ref = -madelung * torch.ones((n_atoms,1)) - assert_close(energies, energies_ref, rtol=1e-4, atol=1e-6) \ No newline at end of file + block = potentials_mesh.block( + {"species_center": int(c), "species_neighbor": int(n)} + ) + energies[idx_c] += charges[idx_c] * charges[idx_n] * block.values[0, 0] + + energies_ref = -madelung * torch.ones((n_atoms, 1)) + assert_close(energies, energies_ref, rtol=1e-4, atol=1e-6) diff --git a/tests/test_mesh_interpolator.py b/tests/test_mesh_interpolator.py index 640a222a..a01ae98a 100644 --- 
a/tests/test_mesh_interpolator.py +++ b/tests/test_mesh_interpolator.py @@ -1,20 +1,23 @@ """ Tests for mesh interpolator class """ +import pytest import torch from torch.testing import assert_close -import pytest + from meshlode.mesh_interpolator import MeshInterpolator + class TestMeshInterpolatorForward: """ Tests for the "points_to_mesh" function of the MeshInterpolator class """ + # Define parameters that are common to all tests - interpolation_order = [1,2,3,4,5] + interpolation_order = [1, 2, 3, 4, 5] - @pytest.mark.parametrize('interpolation_order', interpolation_order) - @pytest.mark.parametrize('n_mesh', torch.arange(19,26)) + @pytest.mark.parametrize("interpolation_order", interpolation_order) + @pytest.mark.parametrize("n_mesh", torch.arange(19, 26)) def test_charge_conservation_cubic(self, interpolation_order, n_mesh): """ Test that the total "charge" on the grid after the smearing the particles onto @@ -25,27 +28,28 @@ def test_charge_conservation_cubic(self, interpolation_order, n_mesh): # parameters across the various tests to reduce the number of calls n_particles = 8 n_channels = 5 - L = torch.tensor(6.28318530717) # tau + L = torch.tensor(6.28318530717) # tau # Generate inputs for interpolator class cell = torch.eye(3) * L - positions = torch.rand((n_particles,3)) * L - particle_weights = 3*torch.randn((n_particles,n_channels)) + positions = torch.rand((n_particles, 3)) * L + particle_weights = 3 * torch.randn((n_particles, n_channels)) ns_mesh = torch.tensor([n_mesh, n_mesh, n_mesh]) # Run interpolation - MI = MeshInterpolator(cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order) + MI = MeshInterpolator( + cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order + ) MI.compute_interpolation_weights(positions) mesh_values = MI.points_to_mesh(particle_weights) - + # Compare total "weight (charge)" on the mesh with the sum of the particle # contributions total_weight_target = torch.sum(particle_weights, axis=0) - total_weight = torch.sum(mesh_values, dim=(1,2,3)) + total_weight = torch.sum(mesh_values, dim=(1, 2, 3)) assert_close(total_weight, total_weight_target, rtol=3e-6, atol=3e-6) - - @pytest.mark.parametrize('interpolation_order', interpolation_order) + @pytest.mark.parametrize("interpolation_order", interpolation_order) def test_charge_conservation_general(self, interpolation_order): """ Test that the total "charge" on the grid after the smearing the particles onto @@ -58,30 +62,32 @@ def test_charge_conservation_general(self, interpolation_order): # parameters across the various tests to reduce the number of calls n_particles = 11 n_channels = 2 - L = torch.tensor(2.718281828) # e + L = torch.tensor(2.718281828) # e # Generate inputs for interpolator class - cell = torch.randn((3,3)) * L - positions = torch.rand((n_particles,3)) * L - particle_weights = 3*torch.randn((n_particles,n_channels)) + cell = torch.randn((3, 3)) * L + positions = torch.rand((n_particles, 3)) * L + particle_weights = 3 * torch.randn((n_particles, n_channels)) ns_mesh = torch.randint(11, 18, size=(3,)) # Run interpolation - MI = MeshInterpolator(cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order) + MI = MeshInterpolator( + cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order + ) MI.compute_interpolation_weights(positions) mesh_values = MI.points_to_mesh(particle_weights) - + # Compare total "weight (charge)" on the mesh with the sum of the particle # contributions total_weight_target = torch.sum(particle_weights, axis=0) - 
total_weight = torch.sum(mesh_values, dim=(1,2,3)) + total_weight = torch.sum(mesh_values, dim=(1, 2, 3)) assert_close(total_weight, total_weight_target, rtol=3e-6, atol=3e-6) # Since the results of the next test fail if two randomly placed atoms are # too close to one another to share the identical nearest mesh point, # we fix the seed of the random number generator - @pytest.mark.parametrize('interpolation_order', [1,2]) - @pytest.mark.parametrize('n_mesh', torch.arange(7, 13)) + @pytest.mark.parametrize("interpolation_order", [1, 2]) + @pytest.mark.parametrize("n_mesh", torch.arange(7, 13)) def test_exact_agreement(self, interpolation_order, n_mesh): """ Test that for interpolation order = 1, 2, if atoms start exactly on the mesh, @@ -93,19 +99,21 @@ def test_exact_agreement(self, interpolation_order, n_mesh): # parameters across the various tests to reduce the number of calls n_particles = 10 n_channels = 3 - L = torch.tensor(0.28209478) # 1/sqrt(4pi) + L = torch.tensor(0.28209478) # 1/sqrt(4pi) # Define all relevant quantities using random numbers # The implementation also works if the particle positions # are not contained within the unit cell - cell = torch.randn((3,3)) * L - indices = torch.randint(low=0, high=n_mesh, size=(3,n_particles)) - positions = torch.matmul(cell.T, indices / n_mesh).T - particle_weights = 3*torch.randn((n_particles, n_channels)) + cell = torch.randn((3, 3)) * L + indices = torch.randint(low=0, high=n_mesh, size=(3, n_particles)) + positions = torch.matmul(cell.T, indices / n_mesh).T + particle_weights = 3 * torch.randn((n_particles, n_channels)) ns_mesh = torch.tensor([n_mesh, n_mesh, n_mesh]) # Perform interpolation - MI = MeshInterpolator(cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order) + MI = MeshInterpolator( + cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order + ) MI.compute_interpolation_weights(positions) mesh_values = MI.points_to_mesh(particle_weights) @@ -113,7 +121,7 @@ def test_exact_agreement(self, interpolation_order, n_mesh): indices_x = indices[0] indices_y = indices[1] indices_z = indices[2] - recovered_weights = mesh_values[:,indices_x,indices_y,indices_z].T + recovered_weights = mesh_values[:, indices_x, indices_y, indices_z].T # !!! WARNING for debugging !!! 
# If two particles are so close to one another that @@ -128,13 +136,14 @@ class TestMeshInterpolatorBackward: """ Tests for the "mesh_to_points" function of the MeshInterpolator class """ + # Define parameters that are common to all tests - interpolation_orders = [1,2,3,4,5] + interpolation_orders = [1, 2, 3, 4, 5] random_runs = torch.arange(10) - torch.random.manual_seed(3482389) - @pytest.mark.parametrize('random_runs', random_runs) + + @pytest.mark.parametrize("random_runs", random_runs) def test_exact_invertibility_for_order_one(self, random_runs): """ For interpolation order = 1, interpolating forwards and backwards with no @@ -145,14 +154,14 @@ def test_exact_invertibility_for_order_one(self, random_runs): # parameters across the various tests to reduce the number of calls n_particles = 7 n_channels = 4 - L = torch.tensor(2.5066282) # sqrt(tau) + L = torch.tensor(2.5066282) # sqrt(tau) # Define all relevant quantities using random numbers # The implementation also works if the particle positions # are not contained within the unit cell - cell = torch.randn((3,3)) * L - positions = torch.rand((n_particles,3)) * L - particle_weights = 3*torch.randn((n_particles,n_channels)) + cell = torch.randn((3, 3)) * L + positions = torch.rand((n_particles, 3)) * L + particle_weights = 3 * torch.randn((n_particles, n_channels)) ns_mesh = torch.randint(17, 25, size=(3,)) # Smear particles onto mesh and interpolate back onto @@ -168,10 +177,9 @@ def test_exact_invertibility_for_order_one(self, random_runs): # two particles will essentially get merged into a single particle. # With the current seed of the random number generator, however, # this should not be an issue. - assert_close(particle_weights, interpolated_values, rtol=0., atol=0.) - + assert_close(particle_weights, interpolated_values, rtol=0.0, atol=0.0) - @pytest.mark.parametrize('n_mesh', torch.arange(18,31)) + @pytest.mark.parametrize("n_mesh", torch.arange(18, 31)) def test_exact_invertibility_for_order_two(self, n_mesh): """ Test for interpolation order = 2 @@ -182,15 +190,15 @@ def test_exact_invertibility_for_order_two(self, n_mesh): # parameters across the various tests to reduce the number of calls n_particles = 5 n_channels = 1 - L = torch.tensor(1.4142135) # sqrt(2) + L = torch.tensor(1.4142135) # sqrt(2) # Define all relevant quantities using random numbers # The implementation also works if the particle positions # are not contained within the unit cell - cell = torch.randn((3,3)) * L - indices = torch.randint(low=0, high=n_mesh, size=(3,n_particles)) + cell = torch.randn((3, 3)) * L + indices = torch.randint(low=0, high=n_mesh, size=(3, n_particles)) positions = torch.matmul(cell.T, indices / n_mesh).T - particle_weights = 10*torch.randn((n_particles, n_channels)) + particle_weights = 10 * torch.randn((n_particles, n_channels)) ns_mesh = torch.tensor([n_mesh, n_mesh, n_mesh]) # Smear particles onto mesh and interpolate back onto @@ -208,8 +216,8 @@ def test_exact_invertibility_for_order_two(self, n_mesh): # this should not be an issue. 
assert_close(particle_weights, interpolated_values, rtol=3e-4, atol=1e-6) - @pytest.mark.parametrize('random_runs', random_runs) - @pytest.mark.parametrize('interpolation_order', interpolation_orders) + @pytest.mark.parametrize("random_runs", random_runs) + @pytest.mark.parametrize("interpolation_order", interpolation_orders) def test_total_mass(self, interpolation_order, random_runs): """ interpolate on all mesh points: should yield same total mass @@ -219,40 +227,37 @@ def test_total_mass(self, interpolation_order, random_runs): # parameters across the various tests to reduce the number of calls n_particles = 13 n_channels = 3 - L = torch.tensor(1.7320508) # sqrt(3) + L = torch.tensor(1.7320508) # sqrt(3) # Define random cell and its three basis vectors # The reshaping is to make more efficient use of # broadcasting - cell = torch.randn((3,3)) * L - ax = cell[0].reshape((3,1)) - ay = cell[1].reshape((3,1)) - az = cell[2].reshape((3,1)) - + cell = torch.randn((3, 3)) * L + ax = cell[0].reshape((3, 1)) + ay = cell[1].reshape((3, 1)) + az = cell[2].reshape((3, 1)) + # Generate the vector positions of ns_mesh = torch.randint(11, 27, size=(3,)) nx, ny, nz = ns_mesh nxs_1d = torch.arange(nx) / nx nys_1d = torch.arange(ny) / ny nzs_1d = torch.arange(nz) / nz - nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing='ij') + nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") nxs = torch.flatten(nxs) nys = torch.flatten(nys) nzs = torch.flatten(nzs) positions = (ax * nxs + ay * nys + az * nzs).T # Generate mesh with random values and interpolate - MI = MeshInterpolator(cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order) + MI = MeshInterpolator( + cell=cell, ns_mesh=ns_mesh, interpolation_order=interpolation_order + ) MI.compute_interpolation_weights(positions) - mesh_values = torch.randn(size=(n_channels,nx,ny,nz)) * 3. + 9.3 + mesh_values = torch.randn(size=(n_channels, nx, ny, nz)) * 3.0 + 9.3 interpolated_values = MI.mesh_to_points(mesh_values) - + # Sum and test - weight_before = torch.sum(mesh_values, dim=(1,2,3)) + weight_before = torch.sum(mesh_values, dim=(1, 2, 3)) weight_after = torch.sum(interpolated_values, dim=0) torch.testing.assert_close(weight_before, weight_after, rtol=1e-5, atol=1e-6) - - - - - \ No newline at end of file
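
For orientation, the user-facing workflow exposed by the refactored API in this patch can be sketched end to end as below. The NaCl geometry is the same primitive cell used in the tests above; the hyperparameter values are illustrative assumptions (chosen with mesh_spacing about half the Gaussian width, as in the Madelung tests), not defaults prescribed by this diff.

>>> import torch
>>> from meshlode import MeshPotential
>>> from meshlode.system import System
>>> # Primitive NaCl cell with a nearest-neighbour Na-Cl distance of 1
>>> frame = System(
...     species=torch.tensor([11, 17]),
...     positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
...     cell=torch.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]),
... )
>>> # Illustrative hyperparameters; see the Madelung tests for accuracy trade-offs
>>> calculator = MeshPotential(
...     atomic_gaussian_width=0.2,
...     mesh_spacing=0.1,
...     interpolation_order=4,
... )
>>> # One TensorBlock per (species_center, species_neighbor) pair
>>> potentials = calculator.compute(frame, subtract_self=True)
>>> block = potentials.block({"species_center": 11, "species_neighbor": 17})
>>> cl_potential_at_na = block.values[0, 0]

The same compute call also accepts a list of System objects, in which case the "structure" entry of the samples labels distinguishes the frames.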
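
The three mesh steps that MeshPotential._compute_single_frame chains together (smearing onto the mesh, Fourier-space convolution, back-interpolation) can also be driven directly through the lower-level classes touched by this patch. A minimal sketch follows; the cell, positions, charges, smearing and mesh size are arbitrary illustration values, not taken from this diff.

>>> import torch
>>> from meshlode.mesh_interpolator import MeshInterpolator
>>> from meshlode.fourier_convolution import FourierSpaceConvolution
>>> cell = torch.eye(3) * 10.0
>>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
>>> charges = torch.tensor([[1.0], [-1.0]])  # single charge channel
>>> ns = torch.tensor([16, 16, 16])  # last dimension kept even (see the note in test_fourier_convolution.py)
>>> # Step 1: smear the point charges onto the mesh
>>> MI = MeshInterpolator(cell, ns, interpolation_order=4)
>>> MI.compute_interpolation_weights(positions)
>>> rho_mesh = MI.points_to_mesh(particle_weights=charges)
>>> # Step 2: convolve with the smeared Coulomb kernel in reciprocal space
>>> FSC = FourierSpaceConvolution(cell)
>>> potential_mesh = FSC.compute(rho_mesh, potential_exponent=1, smearing=0.5)
>>> # Step 3: interpolate the mesh potential back onto the atomic positions
>>> potentials = MI.mesh_to_points(potential_mesh)
>>> # potentials has shape (n_points, n_channels) == (2, 1), per the docstring above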