Skip to content

Commit

Permalink
Add tests for main compute function
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Kazuki Huguenin-Dumittan committed Nov 30, 2023
1 parent 42c35b7 commit c585908
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 19 deletions.
6 changes: 2 additions & 4 deletions src/meshlode/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,6 @@ def _compute_single_frame(self, cell: torch.Tensor,
# Remove self contribution
if subtract_self:
self_contrib = torch.sqrt(torch.tensor(2./torch.pi)) / smearing
for i in range(n_atoms):
interpolated_potential[i] -= charges[i,0] * self_contrib

return interpolated_potential
interpolated_potential -= charges * self_contrib

return interpolated_potential
20 changes: 20 additions & 0 deletions src/meshlode/fourier_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor:
def kernel_func(self, ksq : torch.Tensor,
potential_exponent: int = 1,
smearing: float = 0.2) -> torch.Tensor:
"""
Fourier transform of the Coulomb potential or more general effective 1/r**p
potentials with additional smearing to remove the singularity at the origin.
:param ksq: torch.tensor of shape (N_k,) Squared norm of the k-vectors
:param potential_exponent: Exponent of the effective 1/r**p decay
:param smearing: Broadening of the 1/r**p decay close to the origin
:returns: torch.tensor of shape (N_k,) with the values of the kernel function
G(k) evaluated at the provided (squared norms of the) k-vectors
"""
if potential_exponent == 1:
return 4*torch.pi / ksq * torch.exp(-0.5*smearing**2*ksq)
elif potential_exponent == 0:
Expand All @@ -64,6 +75,15 @@ def kernel_func(self, ksq : torch.Tensor,
raise ValueError('Only potential exponents 0 and 1 are supported')

def value_at_origin(self, potential_exponent: int = 1, smearing: float = 0.2) -> float:
"""
Since the kernel function in reciprocal space typically has a (removable)
singularity at k=0, the value at that point needs to be specified explicitly.
:param potential_exponent: Exponent of the effective 1/r**p decay
:param smearing: Broadening of the 1/r**p decay close to the origin
:returns: float of G(k=0), the value of the kernel function at the origin.
"""
if potential_exponent in [1,2,3]:
return 0.
elif potential_exponent == 0:
Expand Down
2 changes: 1 addition & 1 deletion src/meshlode/mesh_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor:
Absolute positions of particles in Cartesian coordinates, onto whose
locations we wish to interpolate the mesh values.
:returns: interpolated_values: torch.tensor of shape (n_channels, n_points)
:returns: interpolated_values: torch.tensor of shape (n_points, n_channels)
Values of the interpolated function.
"""
interpolated_values = (
Expand Down
105 changes: 94 additions & 11 deletions tests/test_calculators.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,123 @@
import torch
from packaging import version

from typing import List
from meshlode import calculators
from meshlode import MeshPotential
from meshlode.system import System
import pytest

from metatensor.torch import TensorMap, TensorBlock, Labels

def system() -> System:
# Define toy system consisting of a single structure for testing
def toy_system_single_frame() -> System:
return System(
species=torch.tensor([1, 1, 8, 8]),
positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]),
cell=torch.tensor([[10., 0, 0], [0, 10, 0], [0, 0, 10]]),
)


# Initialize the calculators. For now, only the MeshPotential is implemented.
def descriptor() -> MeshPotential:
return MeshPotential(
atomic_gaussian_width=1.,
)


# Make sure that the calculators are computing the features without raising errors,
# and returns the correct output format (TensorMap)
def check_operation(calculator):
# this only runs basic checks functionality checks, and that the code produces
# output with the right type

descriptor = calculator.compute(system())

descriptor = calculator.compute(toy_system_single_frame())
assert isinstance(descriptor, torch.ScriptObject)
if version.parse(torch.__version__) >= version.parse("2.1"):
assert descriptor._type().name() == "TensorMap"


# Run the above test as a normal python script
def test_operation_as_python():
check_operation(descriptor())


# Similar to the above, but also testing that the code can be compiled as a torch script
def test_operation_as_torch_script():
scripted = torch.jit.script(descriptor())
check_operation(scripted)


# Define a more complex toy system consisting of multiple frames, mixing three species.
def toy_system_2() -> List[System]:
# First few frames containing Nitrogen
L = 2.
frames = []
frames.append(System(species=torch.tensor([7]), positions=torch.zeros((1,3)), cell=L*2*torch.eye(3)))
frames.append(System(species=torch.tensor([7,7]), positions=torch.zeros((2,3)), cell=L*2*torch.eye(3)))
frames.append(System(species=torch.tensor([7,7,7]), positions=torch.zeros((3,3)), cell=L*2*torch.eye(3)))

# One more frame containing Na and Cl
positions = torch.tensor([[0, 0, 0], [1., 0, 0]])
cell = torch.tensor([[0, 1., 1], [1, 0, 1], [1, 1, 0]])
frames.append(System(species=torch.tensor([11,17]), positions=positions, cell=cell))

return frames

class TestMultiFrameToySystem:
# Compute TensorMap containing features for various hyperparameters, including more
# extreme values.
tensormaps_list = []
frames = toy_system_2()
for atomic_gaussian_width in [0.01, 0.3, 3.7]:
for mesh_spacing in [15.3, 0.19]:
for interpolation_order in [1,2,3,4,5]:
MP = MeshPotential(atomic_gaussian_width=atomic_gaussian_width,
mesh_spacing=mesh_spacing,
interpolation_order = interpolation_order)
tensormaps_list.append(MP.compute(frames, subtract_self=False))

@pytest.mark.parametrize("features", tensormaps_list)
def test_tensormap_labels(self, features):
# Test that the keys of the TensorMap for the toy system are correct
label_values = torch.tensor([[7,7],[7,11],[7,17],[11,7],[11,11],[11,17],
[17,7],[17,11],[17,17]])
label_names = ["species_center", "species_neighbor"]
labels_ref = Labels(names = label_names, values = label_values)

assert labels_ref == features.keys

@pytest.mark.parametrize("features", tensormaps_list)
def test_zero_blocks(self, features):
# Since the first 3 frames contain Nitrogen only, while the last frame
# only contains Na and Cl, the features should be zero
for i in [11, 17]:
# For structures in which Nitrogen is present, there will be no Na or Cl
# neighbors. There are six such center atoms in total.
block = features.block({"species_center":7, "species_neighbor":i})
assert torch.equal(block.values, torch.zeros((6,1)))

# For structures in which Na or Cl are present, there will be no Nitrogen
# neighbors.
block = features.block({"species_center":i, "species_neighbor":7})
assert torch.equal(block.values, torch.zeros((1,1)))

@pytest.mark.parametrize("features", tensormaps_list)
def test_nitrogen_blocks(self, features):
# For this toy data set:
# - the first frame contains a single atom at the origin
# - the second frame contains two atoms at the origin
# - the third frame contains three atoms at the origin
# Thus, the features should almost be identical, up to a global factor
# that is the number of atoms (that are exactly on the same position).
block = features.block({"species_center":7, "species_neighbor":7})
values = block.values[:,0] # flatten to 1d
values_ref = torch.tensor([1.,2,2,3,3,3])

# We use a slightly higher relative tolerance due to numerical errors
torch.testing.assert_close(values / values[0], values_ref, rtol=1e-6, atol=0.)

@pytest.mark.parametrize("features", tensormaps_list)
def test_nacl_blocks(self, features):
# In the NaCl structure, swapping the positions of all Na and Cl atoms leads to
# an equivalent structure (up to global translation). This leads to symmetry
# in the features: the Na-density around Cl is the same as the Cl-density around
# Na and so on.
block_nana = features.block({"species_center":11, "species_neighbor":11})
block_nacl = features.block({"species_center":11, "species_neighbor":17})
block_clna = features.block({"species_center":17, "species_neighbor":11})
block_clcl = features.block({"species_center":17, "species_neighbor":17})
torch.testing.assert_close(block_nacl.values, block_clna.values, rtol=1e-15, atol=0.)
torch.testing.assert_close(block_nana.values, block_clcl.values, rtol=1e-15, atol=0.)
50 changes: 47 additions & 3 deletions tests/test_madelung.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from torch.testing import assert_close

from meshlode.system import System
from meshlode import MeshPotential

class TestMadelung:
Expand Down Expand Up @@ -37,6 +38,7 @@ def crystal_dictionary(self):
# closest Na-Cl pair is exactly 1. The cubic unit cell
# in these units would have a length of 2.
d["NaCl"]["symbols"] = ['Na', 'Cl']
d["NaCl"]["atomic_numbers"] = torch.tensor([11, 17])
d["NaCl"]["charges"] = torch.tensor([[1., -1]]).T
d["NaCl"]["positions"] = torch.tensor([[0, 0, 0], [1., 0, 0]])
d["NaCl"]["cell"] = torch.tensor([[0, 1., 1], [1, 0, 1], [1, 1, 0]])
Expand All @@ -47,7 +49,8 @@ def crystal_dictionary(self):
# is just the usual cubic cell with side length set to one.
# The closest Cs-Cl distance is sqrt(3)/2. We thus divide
# the Madelung constant by this value to match the reference.
d["CsCl"]["symbols"] = ["Cl", "Cs"]
d["CsCl"]["symbols"] = ["Cs", "Cl"]
d["CsCl"]["atomic_numbers"] = torch.tensor([55, 17])
d["CsCl"]["charges"] = torch.tensor([[1., -1]]).T
d["CsCl"]["positions"] = torch.tensor([[0, 0, 0], [.5, .5, .5]])
d["CsCl"]["cell"] = torch.eye(3)
Expand All @@ -62,6 +65,7 @@ def crystal_dictionary(self):
# If, on the other han_pylode_without_centerd, we set the lattice constant of
# the cubic cell equal to 1, the Zn-S distance is sqrt(3)/4.
d["ZnS"]["symbols"] = ["S", "Zn"]
d["ZnS"]["atomic_numbers"] = torch.tensor([16, 30])
d["ZnS"]["charges"] = torch.tensor([[1., -1]]).T
d["ZnS"]["positions"] = torch.tensor([[0, 0, 0], [.5, .5, .5]])
d["ZnS"]["cell"] = torch.tensor([[0, 1., 1], [1, 0, 1], [1, 1, 0]])
Expand All @@ -70,8 +74,8 @@ def crystal_dictionary(self):
# ZnS (O4) in wurtzite structure (triclinic cell)
u = torch.tensor([3 / 8])
c = torch.sqrt(1 / u)

d["ZnSO4"]["symbols"] = ["S", "Zn", "S", "Zn"]
d["ZnSO4"]["atomic_numbers"] = torch.tensor([16, 30, 16, 30])
d["ZnSO4"]["charges"] = torch.tensor([[1., -1, 1, -1]]).T
d["ZnSO4"]["positions"] = torch.tensor([[.5, .5 / SQRT3, 0.],
[.5, .5 / SQRT3, u * c],
Expand Down Expand Up @@ -111,6 +115,7 @@ def test_madelung_low_order(self, crystal_dictionary, crystal_name, smearing,
energies_target = -torch.ones_like(energies)*madelung
assert_close(energies, energies_target, rtol=1e-4, atol=1e-6)


@pytest.mark.parametrize("crystal_name", crystal_list)
@pytest.mark.parametrize("smearing", [0.2, 0.12])
@pytest.mark.parametrize("interpolation_order", [3,4,5])
Expand All @@ -137,4 +142,43 @@ def test_madelung_high_order(self, crystal_dictionary, crystal_name, smearing,
subtract_self=True)
energies = potentials_mesh * charges
energies_target = -torch.ones_like(energies)*madelung
assert_close(energies, energies_target, rtol=1e-2, atol=1e-3)
assert_close(energies, energies_target, rtol=1e-2, atol=1e-3)


@pytest.mark.parametrize("crystal_name", crystal_list_powers_of_2)
@pytest.mark.parametrize("smearing", [0.1, 0.05])
@pytest.mark.parametrize("interpolation_order", [1,2])
@pytest.mark.parametrize("scaling_factor", scaling_factors)
def test_madelung_low_order_metatensor(self, crystal_dictionary, crystal_name,
smearing, scaling_factor,
interpolation_order):
"""
Same test as above but now using the main compute function of the class that is
actually facing the user and outputting in metatensor format.
"""
dic = crystal_dictionary[crystal_name]
positions = dic['positions'] * scaling_factor
cell = dic['cell'] * scaling_factor
atomic_numbers = dic['atomic_numbers']
charges = dic['charges']
madelung = dic['madelung'] / scaling_factor
mesh_spacing = smearing / 2 * scaling_factor
smearing_eff = smearing * scaling_factor
n_atoms = len(positions)
frame = System(species=atomic_numbers, positions=positions, cell=cell)
MP = MeshPotential(atomic_gaussian_width=smearing_eff,
mesh_spacing=mesh_spacing,
interpolation_order=interpolation_order)
potentials_mesh = MP.compute(frame, subtract_self=True)

# Compute the actual potential from the features
n_species = charges.shape[1]
energies = torch.zeros((n_atoms,1))
for idx_c, c in enumerate(atomic_numbers):
for idx_n, n in enumerate(atomic_numbers):
block = potentials_mesh.block({'species_center':int(c),
'species_neighbor':int(n)})
energies[idx_c] += charges[idx_c] * charges[idx_n] * block.values[0,0]

energies_ref = -madelung * torch.ones((n_atoms,1))
assert_close(energies, energies_ref, rtol=1e-4, atol=1e-6)

0 comments on commit c585908

Please sign in to comment.