From 42c35b7085c70fe82c612bcd56f94f7fddb6a2a7 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Thu, 30 Nov 2023 12:07:06 +0100 Subject: [PATCH] Add multi-frame and torchscript support --- src/meshlode/calculators.py | 150 +++++++++++++----- src/meshlode/fourier_convolution.py | 58 ++++--- src/meshlode/mesh_interpolator.py | 31 +++- src/meshlode/system.py | 6 + tests/{calculators.py => test_calculators.py} | 13 +- tests/test_fourier_convolution.py | 3 +- 6 files changed, 188 insertions(+), 73 deletions(-) rename tests/{calculators.py => test_calculators.py} (74%) diff --git a/src/meshlode/calculators.py b/src/meshlode/calculators.py index 235d5a4f..a67cb423 100644 --- a/src/meshlode/calculators.py +++ b/src/meshlode/calculators.py @@ -9,15 +9,22 @@ Our calculator API follows the `rascaline `_ API and coding guidelines to promote usability and interoperability with existing workflows. """ -from typing import List, Optional, Union +from typing import List, Union, Tuple, Dict +# import numpy as np import torch +from torch import Tensor from metatensor.torch import Labels, TensorBlock, TensorMap from meshlode.mesh_interpolator import MeshInterpolator from meshlode.fourier_convolution import FourierSpaceConvolution from .system import System +def my_1d_tolist(x: torch.Tensor): + result: List[int] = [] + for i in x: + result.append(i.item()) + return result class MeshPotential(torch.nn.Module): """A species wise long range potential. @@ -46,19 +53,30 @@ def __init__( ): super().__init__() - self.parameters = { - "atomic_gaussian_width": atomic_gaussian_width, - "mesh_spacing": mesh_spacing, - "interpolation_order": interpolation_order, - } + self.atomic_gaussian_width = atomic_gaussian_width + self.mesh_spacing = mesh_spacing + self.interpolation_order = interpolation_order + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + systems: Union[List[System],System], + ) -> TensorMap: + """forward just calls :py:meth:`CalculatorModule.compute`""" + res = self.compute(frames=systems) + return res + # return 0. + + #@torch.jit.export def compute( self, - systems: Union[System, List[System]], - gradients: Optional[List[str]] = None, - ) -> TensorMap: + frames: Union[List[System],System], + subtract_self: bool = False, + ) -> TensorMap: """Runs a calculation with this calculator on the given ``systems``. - + TODO update this :param systems: single system or list of systems on which to run the calculation. If any of the systems' ``positions`` or ``cell`` has ``requires_grad`` set to :py:obj:`True`, then the corresponding gradients @@ -69,27 +87,87 @@ def compute( Some gradients might still be computed at runtime to allow for backward propagation. """ - - # Do actual calculations here... - block = TensorBlock( - samples=Labels.single(), + # Make sure that the compute function also works if only a single frame is + # provided as input (for convenience of users testing out the code) + if not isinstance(frames, list): + frames = [frames] + + # Generate a dictionary to map atomic species to array indices + # In general, the species are sorted according to atomic number + # and assigned the array indices 0, 1, 2,... + # Example: for H2O: H is mapped to 0 and O is mapped to 1. + all_species = [] + n_atoms_tot = 0 + for frame in frames: + n_atoms_tot += len(frame) + all_species.append(frame.species) + all_species = torch.hstack(all_species) + atomic_numbers = my_1d_tolist(torch.unique(all_species)) + n_species = len(atomic_numbers) + + # Initialize dictionary for sparse storage of the features + n_species_sq = n_species * n_species + feat_dic: Dict[int, List[torch.Tensor]] = {a:[] for a in range(n_species_sq)} + + for frame in frames: + # One-hot encoding of charge information + n_atoms = len(frame) + species = frame.species + charges = torch.zeros((n_atoms, n_species), dtype=torch.float) + for i_specie, atomic_number in enumerate(atomic_numbers): + charges[species == atomic_number, i_specie] = 1. + + # Compute the potentials + potential = self._compute_single_frame(frame.cell, frame.positions, + charges, subtract_self) + + # Reorder data into Metatensor format + for spec_center, at_num_center in enumerate(atomic_numbers): + for spec_neighbor in range(len(atomic_numbers)): + a_pair = spec_center * n_species + spec_neighbor + feat_dic[a_pair] += [potential[species==at_num_center,spec_neighbor]] + + # Assemble all computed potential values into TensorBlocks for each combination + # of species_center and species_neighbor + blocks: List[TensorBlock] = [] + for keys, values in feat_dic.items(): + spec_center = atomic_numbers[keys // n_species] + + # Generate the Labels objects for the samples and properties of the + # TensorBlock. + samples_vals: List[List[int]] = [] + for i_frame, frame in enumerate(frames): + for i_atom in range(len(frame)): + if frame.species[i_atom] == spec_center: + samples_vals.append([i_frame, i_atom]) + samples_vals_tensor = torch.tensor((samples_vals), dtype=torch.int32) + labels_samples = Labels(["structure", "center"], samples_vals_tensor) + + labels_properties = Labels(["potential"], torch.tensor([[0]])) + + block = TensorBlock( + samples=labels_samples, components=[], - properties=Labels.single(), - values=torch.tensor([[1.0]]), - ) - return TensorMap(keys=Labels.single(), blocks=[block]) + properties=labels_properties, + values=torch.hstack(values).reshape((-1,1)), + ) + + blocks.append(block) + + # Generate TensorMap from TensorBlocks by defining suitable keys + key_values: List[torch.Tensor] = [] + for spec_center in atomic_numbers: + for spec_neighbor in atomic_numbers: + key_values.append(torch.tensor([spec_center, spec_neighbor])) + key_values = torch.vstack(key_values) + labels_keys = Labels(["species_center", "species_neighbor"], key_values) + + return TensorMap(keys=labels_keys, blocks=blocks) - def forward( - self, - systems: List[System], - gradients: Optional[List[str]] = None, - ) -> TensorMap: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute(systems=systems, gradients=gradients) - def _compute_single_frame(self, cell: torch.tensor, - positions: torch.tensor, charges: torch.tensor, - subtract_self=False) -> torch.tensor: + def _compute_single_frame(self, cell: torch.Tensor, + positions: torch.Tensor, charges: torch.Tensor, + subtract_self: bool = False ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a structure. @@ -120,14 +198,15 @@ def _compute_single_frame(self, cell: torch.tensor, :returns: torch.tensor of shape (n_atoms, n_channels) containing the potential at the position of each atom for the n_channels independent meshes separately. """ - smearing = self.parameters['atomic_gaussian_width'] - mesh_resolution = self.parameters['mesh_spacing'] - interpolation_order = self.parameters['interpolation_order'] + smearing = self.atomic_gaussian_width + mesh_resolution = self.mesh_spacing + interpolation_order = self.interpolation_order # Initializations n_atoms = len(positions) + n_channels = charges.shape[1] assert positions.shape == (n_atoms,3) - assert charges.shape == (n_atoms, 1) + assert charges.shape[0] == n_atoms # Define k-vectors if mesh_resolution is None: @@ -138,7 +217,7 @@ def _compute_single_frame(self, cell: torch.tensor, # Compute number of times each basis vector of the # reciprocal space can be scaled until the cutoff # is reached - basis_norms = torch.linalg.norm(cell, axis=1) + basis_norms = torch.linalg.norm(cell, dim=1) ns_approx = k_cutoff * basis_norms / 2 / torch.pi ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points ns = 2**torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] @@ -147,12 +226,11 @@ def _compute_single_frame(self, cell: torch.tensor, MI = MeshInterpolator(cell, ns, interpolation_order=interpolation_order) MI.compute_interpolation_weights(positions) rho_mesh = MI.points_to_mesh(particle_weights=charges) - + # Step 2: Perform Fourier space convolution (FSC) FSC = FourierSpaceConvolution(cell) - kernel_func = lambda ksq: 4*torch.pi / ksq * torch.exp(-0.5*smearing**2*ksq) value_at_origin = 0. # charge neutrality - potential_mesh = FSC.compute(rho_mesh, kernel_func, value_at_origin) + potential_mesh = FSC.compute(rho_mesh, potential_exponent=1, smearing=smearing) # Step 3: Back interpolation interpolated_potential = MI.mesh_to_points(potential_mesh) diff --git a/src/meshlode/fourier_convolution.py b/src/meshlode/fourier_convolution.py index 932d3afc..5db71d1c 100644 --- a/src/meshlode/fourier_convolution.py +++ b/src/meshlode/fourier_convolution.py @@ -2,6 +2,7 @@ Fourier Convolution =================== """ +from typing import Callable import torch class FourierSpaceConvolution: @@ -12,10 +13,10 @@ class FourierSpaceConvolution: :param cell: torch.tensor of shape (3,3) Tensor specifying the real space unit cell of a structure, where cell[i] is the i-th basis vector """ - def __init__(self, cell: torch.tensor): - self.cell = cell + def __init__(self, cell: torch.Tensor): + self.cell: torch.Tensor = cell - def generate_kvectors(self, ns: torch.tensor) -> torch.tensor: + def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor: """ For a given unit cell, compute all reciprocal space vectors that are used to perform sums in the Fourier transformed space. @@ -45,30 +46,45 @@ def generate_kvectors(self, ns: torch.tensor) -> torch.tensor: nys_1d = torch.fft.fftfreq(ns[1]) * ns[1] nzs_1d = torch.fft.rfftfreq(ns[2]) * ns[2] # real FFT nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing='ij') - nxs = nxs.reshape((ns[0],ns[1],len(nzs_1d),1)) - nys = nys.reshape((ns[0],ns[1],len(nzs_1d),1)) - nzs = nzs.reshape((ns[0],ns[1],len(nzs_1d),1)) + nxs = nxs.reshape((int(ns[0]),int(ns[1]),len(nzs_1d),1)) + nys = nys.reshape((int(ns[0]),int(ns[1]),len(nzs_1d),1)) + nzs = nzs.reshape((int(ns[0]),int(ns[1]),len(nzs_1d),1)) k_vectors = nxs * bx + nys * by + nzs * bz return k_vectors - def compute(self, mesh_values: torch.tensor, kernel_func, value_at_origin=None) -> torch.tensor: + def kernel_func(self, ksq : torch.Tensor, + potential_exponent: int = 1, + smearing: float = 0.2) -> torch.Tensor: + if potential_exponent == 1: + return 4*torch.pi / ksq * torch.exp(-0.5*smearing**2*ksq) + elif potential_exponent == 0: + return torch.ones_like(ksq) + else: + raise ValueError('Only potential exponents 0 and 1 are supported') + + def value_at_origin(self, potential_exponent: int = 1, smearing: float = 0.2) -> float: + if potential_exponent in [1,2,3]: + return 0. + elif potential_exponent == 0: + return 1. + else: + raise ValueError('Only potential exponents 0 and 1 are supported') + + def compute(self, mesh_values: torch.Tensor, + potential_exponent: int = 1, + smearing: float = 0.2) -> torch.Tensor: """ Compute the "electrostatic potential" from the density defined on a discrete mesh. :param mesh_values: torch.tensor of shape (n_channels, nx, ny, nz) The values of the density defined on a mesh. - :param kernel_func: function - The kernel function takes k**2 as an argument and should output the Fourier - transform of the potential. For the standard Ewald summation using the - Coulomb potential of a Gaussian charge density: - :math:`G(k) = 4pi / k**2 * exp(-0.5*(sigma*k)**2)` and hence - :math:`kernel_func(x) = 4pi / x * exp(-0.5*sigma**2 * x)` since x=k**2. - :param value_at_origin: float - For some kernel functions like the one above, the value - at k=0 is singular. In such cases, it is possible to - manually specify what G(0) should be set to. + :param potential_exponent: int + The exponent in the 1/r**p decay of the effective potential, where p=1 + corresponds to the Coulomb potential, and p=0 is set as a delta-potential. + :param smearing: float + Width of the Gaussian smearing (for the Coulomb potential). :returns: torch.tensor of shape (n_channels, nx, ny, nz) The potential evaluated on the same mesh points as the provided @@ -81,16 +97,16 @@ def compute(self, mesh_values: torch.tensor, kernel_func, value_at_origin=None) # Get the relevant reciprocal space vectors (k-vectors) # and compute their norm. kvectors = self.generate_kvectors(ns) - knorm_sq = torch.sum(kvectors**2, axis=3) + knorm_sq = torch.sum(kvectors**2, dim=3) # G(k) is the Fourier transform of the Coulomb potential # generated by a Gaussian charge density # We remove the singularity at k=0 by explicitly setting its # value to be equal to zero. This mathematically corresponds # to the requirement that the net charge of the cell is zero. - G = kernel_func(knorm_sq) - if value_at_origin is not None: - G[0,0,0] = value_at_origin + #G = kernel_func(knorm_sq) + G = self.kernel_func(knorm_sq, potential_exponent=potential_exponent, smearing=smearing) + G[0,0,0] = self.value_at_origin(potential_exponent=potential_exponent, smearing=smearing) # Fourier transforms consisting of the following substeps: # 1. Fourier transform the density diff --git a/src/meshlode/mesh_interpolator.py b/src/meshlode/mesh_interpolator.py index 68e77adc..5274ef54 100644 --- a/src/meshlode/mesh_interpolator.py +++ b/src/meshlode/mesh_interpolator.py @@ -32,14 +32,27 @@ class MeshInterpolator: the interpolation order (once one moves to the 3D case). """ def __init__( - self, cell: torch.tensor, ns_mesh: torch.tensor, interpolation_order: int + self, cell: torch.Tensor, ns_mesh: torch.Tensor, interpolation_order: int ): self.cell = cell self.ns_mesh = ns_mesh self.interpolation_order = interpolation_order - def compute_1d_weights(self, x: torch.tensor) -> torch.tensor: + # Initialize the variables in which to store the intermediate + # interpolation nodes and weights + self.interpolation_weights: torch.Tensor = torch.tensor(0.) + self.x_shifts: torch.Tensor = torch.tensor(0) + self.y_shifts: torch.Tensor = torch.tensor(0) + self.z_shifts: torch.Tensor = torch.tensor(0) + self.x_indices: torch.Tensor = torch.tensor(0) + self.y_indices: torch.Tensor = torch.tensor(0) + self.z_indices: torch.Tensor = torch.tensor(0) + + + + + def compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor: """ Generate the smooth interpolation weights used to smear the particles onto a mesh. @@ -95,7 +108,7 @@ def compute_1d_weights(self, x: torch.tensor) -> torch.tensor: else: raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") - def compute_interpolation_weights(self, positions: torch.tensor): + def compute_interpolation_weights(self, positions: torch.Tensor): """ Compute the interpolation weights of each atom for a given cell (specified during initialization of this class). The weights are not returned, but are used @@ -154,7 +167,7 @@ def compute_interpolation_weights(self, positions: torch.tensor): self.y_indices = indices_to_interpolate[self.y_shifts, :, 1] self.z_indices = indices_to_interpolate[self.z_shifts, :, 2] - def points_to_mesh(self, particle_weights: torch.tensor) -> torch.tensor: + def points_to_mesh(self, particle_weights: torch.Tensor) -> torch.Tensor: """ Generate a discretized density from interpolation weights. It assumes that "compute_interpolation_weights" has been called before to compute all the @@ -171,8 +184,10 @@ def points_to_mesh(self, particle_weights: torch.tensor) -> torch.tensor: """ # Update mesh values by combining particle weights and interpolation weights n_channels = particle_weights.shape[1] - nx, ny, nz = self.ns_mesh - rho_mesh = torch.zeros((n_channels, nx, ny, nz)) + nx = int(self.ns_mesh[0]) + ny = int(self.ns_mesh[1]) + nz = int(self.ns_mesh[2]) + rho_mesh = torch.zeros((n_channels,nx,ny,nz)) for a in range(n_channels): rho_mesh[a].index_put_( (self.x_indices, self.y_indices, self.z_indices), @@ -187,7 +202,7 @@ def points_to_mesh(self, particle_weights: torch.tensor) -> torch.tensor: return rho_mesh - def mesh_to_points(self, mesh_vals: torch.tensor) -> torch.tensor: + def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor: """ Take a function defined on a mesh and interpolate its values on arbitrary positions. @@ -211,7 +226,7 @@ def mesh_to_points(self, mesh_vals: torch.tensor) -> torch.tensor: * self.interpolation_weights[self.y_shifts, :, 1] * self.interpolation_weights[self.z_shifts, :, 2] ) - .sum(axis=1) + .sum(dim=1) .T ) diff --git a/src/meshlode/system.py b/src/meshlode/system.py index c5173904..201c96f1 100644 --- a/src/meshlode/system.py +++ b/src/meshlode/system.py @@ -46,3 +46,9 @@ def cell(self) -> torch.Tensor: """ return self._cell + + def __len__(self) -> int: + """ + Return the number of atoms + """ + return len(self._positions) diff --git a/tests/calculators.py b/tests/test_calculators.py similarity index 74% rename from tests/calculators.py rename to tests/test_calculators.py index 92787250..3937f0a2 100644 --- a/tests/calculators.py +++ b/tests/test_calculators.py @@ -2,20 +2,21 @@ from packaging import version from meshlode import calculators +from meshlode import MeshPotential from meshlode.system import System -def system(): +def system() -> System: return System( species=torch.tensor([1, 1, 8, 8]), positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]), - cell=torch.tensor([[10, 0, 0], [0, 10, 0], [0, 0, 10]]), + cell=torch.tensor([[10., 0, 0], [0, 10, 0], [0, 0, 10]]), ) -def descriptor(): - return calculators.MeshPotential( - atomic_gaussian_width=1, +def descriptor() -> MeshPotential: + return MeshPotential( + atomic_gaussian_width=1., ) @@ -23,7 +24,7 @@ def check_operation(calculator): # this only runs basic checks functionality checks, and that the code produces # output with the right type - descriptor = calculator.compute(system(), gradients=["positions"]) + descriptor = calculator.compute(system()) assert isinstance(descriptor, torch.ScriptObject) if version.parse(torch.__version__) >= version.parse("2.1"): diff --git a/tests/test_fourier_convolution.py b/tests/test_fourier_convolution.py index b9b49361..1d1fb440 100644 --- a/tests/test_fourier_convolution.py +++ b/tests/test_fourier_convolution.py @@ -89,7 +89,6 @@ def test_convolution_for_delta(self, cell, mesh_vals): n_channels, nx, ny, nz = mesh_vals.shape n_fft = nx*ny*nz FSC = FourierSpaceConvolution(cell) - kernel_func = lambda ksq: torch.ones_like(ksq) - mesh_vals_new = FSC.compute(mesh_vals, kernel_func) * volume / n_fft + mesh_vals_new = FSC.compute(mesh_vals, potential_exponent=0) * volume / n_fft assert_close(mesh_vals, mesh_vals_new, rtol=1e-4, atol=1e-6) \ No newline at end of file