Skip to content

Commit

Permalink
Clean up formatting and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Kazuki Huguenin-Dumittan committed Dec 1, 2023
1 parent c585908 commit d814c53
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 293 deletions.
120 changes: 68 additions & 52 deletions src/meshlode/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,42 @@
Our calculator API follows the `rascaline <https://luthaf.fr/rascaline>`_ API and coding
guidelines to promote usability and interoperability with existing workflows.
"""
from typing import List, Union, Tuple, Dict
from typing import Dict, List, Union

# import numpy as np
import torch
from torch import Tensor
from metatensor.torch import Labels, TensorBlock, TensorMap

from meshlode.mesh_interpolator import MeshInterpolator
from meshlode.fourier_convolution import FourierSpaceConvolution
from .system import System
from meshlode.mesh_interpolator import MeshInterpolator
from meshlode.system import System


def my_1d_tolist(x: torch.Tensor):
"""Auxilary function to convert torch tensor to list of integers"""
result: List[int] = []
for i in x:
result.append(i.item())
return result


class MeshPotential(torch.nn.Module):
"""A species wise long range potential.
:param atomic_gaussian_width: Width of the atom-centered gaussian used to create the
atomic density.
:type atomic_gaussian_width: float
:param mesh_spacing: Value that determines the umber of Fourier-space grid points
that will be used along each axis.
:param interpolation_order: Interpolation order for mapping onto the grid.
``4`` equals cubic interpolation.
:type mesh_spacing: float
:param interpolation_order: Interpolation order for mapping onto the grid, where an
interpolation order of p corresponds to interpolation by a polynomial of degree
p-1 (e.g. p=4 for cubic interpolation).
:type interpolation_order: int
Example
-------
>>> calculator = MeshPotential(atomic_gaussian_width=1)
>>> calculator = MeshPotential(atomic_gaussian_width=1.0)
"""

Expand All @@ -61,37 +66,40 @@ def __init__(
# infrastructure, which require a "forward" function. We name this function
# "compute" instead, for compatibility with other COSMO software.
def forward(
self,
systems: Union[List[System],System],
) -> TensorMap:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
res = self.compute(frames=systems)
return res
# return 0.

#@torch.jit.export
self,
systems: Union[List[System], System],
) -> TensorMap:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
res = self.compute(frames=systems)
return res
# return 0.

# @torch.jit.export
def compute(
self,
frames: Union[List[System],System],
frames: Union[List[System], System],
subtract_self: bool = False,
) -> TensorMap:
"""Runs a calculation with this calculator on the given ``systems``.
TODO update this
:param systems: single system or list of systems on which to run the
) -> TensorMap:
"""Compute the potential at the position of each atom for all Systems provided
in "frames".
:param frames: single System or list of Systems on which to run the
calculation. If any of the systems' ``positions`` or ``cell`` has
``requires_grad`` set to :py:obj:`True`, then the corresponding gradients
are computed and registered as a custom node in the computational graph, to
allow backward propagation of the gradients later.
:param gradients: List of forward gradients to keep in the output. If this is
:py:obj:`None` or an empty list ``[]``, no gradients are kept in the output.
Some gradients might still be computed at runtime to allow for backward
propagation.
:param subtract_self: bool. If set to true, subtract from the features of an
atom i the contributions to the potential arising from the "center" atom itself
(but not the periodic images).
:return: TensorMap containing the potential of all atoms. The keys of the
tensormap are "species_center" and "species_neighbor".
"""
# Make sure that the compute function also works if only a single frame is
# provided as input (for convenience of users testing out the code)
if not isinstance(frames, list):
frames = [frames]

# Generate a dictionary to map atomic species to array indices
# In general, the species are sorted according to atomic number
# and assigned the array indices 0, 1, 2,...
Expand All @@ -107,25 +115,28 @@ def compute(

# Initialize dictionary for sparse storage of the features
n_species_sq = n_species * n_species
feat_dic: Dict[int, List[torch.Tensor]] = {a:[] for a in range(n_species_sq)}
feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_species_sq)}

for frame in frames:
# One-hot encoding of charge information
n_atoms = len(frame)
species = frame.species
charges = torch.zeros((n_atoms, n_species), dtype=torch.float)
for i_specie, atomic_number in enumerate(atomic_numbers):
charges[species == atomic_number, i_specie] = 1.
charges[species == atomic_number, i_specie] = 1.0

# Compute the potentials
potential = self._compute_single_frame(frame.cell, frame.positions,
charges, subtract_self)

potential = self._compute_single_frame(
frame.cell, frame.positions, charges, subtract_self
)

# Reorder data into Metatensor format
for spec_center, at_num_center in enumerate(atomic_numbers):
for spec_neighbor in range(len(atomic_numbers)):
a_pair = spec_center * n_species + spec_neighbor
feat_dic[a_pair] += [potential[species==at_num_center,spec_neighbor]]
feat_dic[a_pair] += [
potential[species == at_num_center, spec_neighbor]
]

# Assemble all computed potential values into TensorBlocks for each combination
# of species_center and species_neighbor
Expand All @@ -142,18 +153,18 @@ def compute(
samples_vals.append([i_frame, i_atom])
samples_vals_tensor = torch.tensor((samples_vals), dtype=torch.int32)
labels_samples = Labels(["structure", "center"], samples_vals_tensor)

labels_properties = Labels(["potential"], torch.tensor([[0]]))

block = TensorBlock(
samples=labels_samples,
components=[],
properties=labels_properties,
values=torch.hstack(values).reshape((-1,1)),
samples=labels_samples,
components=[],
properties=labels_properties,
values=torch.hstack(values).reshape((-1, 1)),
)

blocks.append(block)

# Generate TensorMap from TensorBlocks by defining suitable keys
key_values: List[torch.Tensor] = []
for spec_center in atomic_numbers:
Expand All @@ -162,12 +173,15 @@ def compute(
key_values = torch.vstack(key_values)
labels_keys = Labels(["species_center", "species_neighbor"], key_values)

return TensorMap(keys=labels_keys, blocks=blocks)
return TensorMap(keys=labels_keys, blocks=blocks)


def _compute_single_frame(self, cell: torch.Tensor,
positions: torch.Tensor, charges: torch.Tensor,
subtract_self: bool = False ) -> torch.Tensor:
def _compute_single_frame(
self,
cell: torch.Tensor,
positions: torch.Tensor,
charges: torch.Tensor,
subtract_self: bool = False,
) -> torch.Tensor:
"""
Compute the "electrostatic" potential at the position of all atoms in a
structure.
Expand All @@ -177,6 +191,7 @@ def _compute_single_frame(self, cell: torch.Tensor,
:param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian
coordinates of the atoms. The implementation also works if the positions are
not contained within the unit cell.
:param charges: torch.tensor of shape (n_atoms, n_channels). In the simplest
case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the
charge of atom i. More generally, the potential for the same atom positions
Expand All @@ -190,6 +205,7 @@ def _compute_single_frame(self, cell: torch.Tensor,
"Cl" potential. Subtracting these from each other, one could recover the more
standard electrostatic potential in which Na and Cl have charges of +1 and -1,
respectively.
:param subtract_self: bool. If set to true, the contribution to the potential of
the center atom itself is subtracted, meaning that only the potential generated
by the remaining atoms + periodic images of the center atom is taken into
Expand All @@ -205,7 +221,7 @@ def _compute_single_frame(self, cell: torch.Tensor,
# Initializations
n_atoms = len(positions)
n_channels = charges.shape[1]
assert positions.shape == (n_atoms,3)
assert positions.shape == (n_atoms, 3)
assert charges.shape[0] == n_atoms

# Define k-vectors
Expand All @@ -219,25 +235,25 @@ def _compute_single_frame(self, cell: torch.Tensor,
# is reached
basis_norms = torch.linalg.norm(cell, dim=1)
ns_approx = k_cutoff * basis_norms / 2 / torch.pi
ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points
ns = 2**torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz]
ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points
ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz]

# Step 1: Smear particles onto mesh
MI = MeshInterpolator(cell, ns, interpolation_order=interpolation_order)
MI.compute_interpolation_weights(positions)
rho_mesh = MI.points_to_mesh(particle_weights=charges)

# Step 2: Perform Fourier space convolution (FSC)
FSC = FourierSpaceConvolution(cell)
value_at_origin = 0. # charge neutrality
value_at_origin = 0.0 # charge neutrality
potential_mesh = FSC.compute(rho_mesh, potential_exponent=1, smearing=smearing)

# Step 3: Back interpolation
interpolated_potential = MI.mesh_to_points(potential_mesh)

# Remove self contribution
if subtract_self:
self_contrib = torch.sqrt(torch.tensor(2./torch.pi)) / smearing
self_contrib = torch.sqrt(torch.tensor(2.0 / torch.pi)) / smearing
interpolated_potential -= charges * self_contrib

return interpolated_potential
return interpolated_potential
Loading

0 comments on commit d814c53

Please sign in to comment.