Skip to content

Commit

Permalink
Calculate transitions with custom radii per site type (#265)
Browse files Browse the repository at this point in the history
* Pass site_radius dict to _calculate_atom_states

* Flip kdtree coords and center

* Do `search_tree()` in one go

* Implement site_radius specified as dict

* Add test for integer remap

* Add test for site_radius dict

* Refactor if statement out of loop

* Improve code clarity
  • Loading branch information
stefsmeets authored Feb 20, 2024
1 parent dab7b54 commit 80652c8
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 44 deletions.
8 changes: 5 additions & 3 deletions src/gemdat/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def transitions_between_sites(
self,
sites: Structure,
floating_specie: str,
site_radius: Optional[float] = None,
site_radius: float | dict[str, float] | None = None,
site_inner_fraction: float = 1.0,
) -> Transitions:
"""Compute transitions between given sites for floating specie.
Expand All @@ -518,9 +518,11 @@ def transitions_between_sites(
Input structure with known sites
floating_specie : str
Name of the floating specie to calculate transitions for
site_radius: Optional[float]
site_radius: Optional[float, dict[str, float]]
A custom site radius in Ångstrom to determine
if an atom is at a site
if an atom is at a site. A dict keyed by the site label can
be used to have a site per atom type, e.g.
`site_radius = {'Li1': 1.0, 'Li2': 1.2}.
site_inner_fraction:
A fraction of the site radius which is determined to be the `inner site`
which is used in jump calculations
Expand Down
89 changes: 57 additions & 32 deletions src/gemdat/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import typing
from collections import defaultdict
from itertools import pairwise
from typing import Optional

import numpy as np
import pandas as pd
Expand All @@ -15,7 +14,7 @@

from .caching import weak_lru_cache
from .simulation_metrics import SimulationMetrics
from .utils import bfill, ffill
from .utils import bfill, ffill, integer_remap

if typing.TYPE_CHECKING:
from gemdat.jumps import Jumps
Expand Down Expand Up @@ -80,22 +79,29 @@ def from_trajectory(
trajectory: Trajectory,
sites: Structure,
floating_specie: str,
site_radius: Optional[float] = None,
site_radius: float | dict[str, float] | None = None,
site_inner_fraction: float = 1.,
) -> Transitions:
"""Compute transitions for floating specie from trajectory and
structure with known sites.
"""Compute transitions between given sites for floating specie.
Parameters
----------
trajectory : Trajectory
Input trajectory
sites : pymatgen.core.structure.Structure
Input sites with known sites
Input structure with known sites
floating_specie : str
Name of the floating specie to calculate transitions for
site_radius: Optional[float]
A custom site size to use for determining if an atom is at a site
site_radius: float | dict[str, float] | None
A custom site radius in Ångstrom to determine
if an atom is at a site. A dict keyed by the site label can
be used to have a site per atom type, e.g.
`site_radius = {'Li1': 1.0, 'Li2': 1.2}.
site_inner_fraction:
A fraction of the site radius which is determined to be the `inner site`
which is used in jump calculations
Returns
-------
transitions: Transitions
"""
diff_trajectory = trajectory.filter(floating_specie)

Expand All @@ -108,6 +114,9 @@ def from_trajectory(
sites=sites,
vibration_amplitude=vibration_amplitude)

if isinstance(site_radius, float):
site_radius = {'': site_radius}

states = _calculate_atom_states(
sites=sites,
trajectory=diff_trajectory,
Expand Down Expand Up @@ -420,7 +429,7 @@ def _compute_site_radius(trajectory: Trajectory, sites: Structure,
def _calculate_atom_states(
sites: Structure,
trajectory: Trajectory,
site_radius: float,
site_radius: dict[str, float],
site_inner_fraction: float = 1.,
) -> np.ndarray:
"""Calculate nearest site for each atom coordinate in the trajectory.
Expand All @@ -435,8 +444,9 @@ def _calculate_atom_states(
Input sites with pre-defined sites
trajectory : Trajectory
Input trajectory for floating atoms
site_radius : float
Atoms within this distance (in Angstrom) are considered to be close to a site
site_radius : dict[str, float]
Atoms within this distance (in Angstrom) are considered to be close to a site.
Can also be a dict keyed by the site label to specify the radius by atom type.
site_inner_fraction: float
Atoms that are closer than (site_radius*site_inner_fraction) to a site, are considered
to be in the inner site
Expand All @@ -448,34 +458,49 @@ def _calculate_atom_states(
The value corresponds to the index in the `site_coords`.
-1 indicates that atom is not at any site.
"""
# Unit cell parameters

def _site_radius_iterator():
for label, radius in site_radius.items():
if label:
grouped = ((k, site) for k, site in enumerate(sites)
if site.label == label)
key, site_group = zip(*grouped)
frac_coords = np.array(
[site.frac_coords for site in site_group])
yield frac_coords, np.array(key), radius
else:
yield sites.frac_coords, None, radius

lattice = trajectory.get_lattice()

site_coords = sites.frac_coords
cutoff = max(list(site_radius.values()))

# Input array with site coordinates [site, (x, y, z)]
site_cart_coords = np.dot(site_coords, lattice.matrix)
site_coords_tree: PeriodicKDTree = PeriodicKDTree(
traj_frac_coords = trajectory.positions.reshape(-1, 3)
traj_cart_coords = lattice.get_cartesian_coords(traj_frac_coords)

periodic_tree: PeriodicKDTree = PeriodicKDTree(
box=np.array(lattice.parameters, dtype=np.float32))
site_coords_tree.set_coords(site_cart_coords, cutoff=site_radius)
periodic_tree.set_coords(traj_cart_coords, cutoff=cutoff)

shape = trajectory.positions.shape[0:2]

atom_sites = np.full((traj_cart_coords.shape[0]), NOSITE)

atom_sites = []
for coords, key, radius in _site_radius_iterator():
cart_coords = lattice.get_cartesian_coords(coords)
site_index = periodic_tree.search_tree(cart_coords,
radius * site_inner_fraction)

for atom_index, atom_coords in enumerate(
trajectory.positions.swapaxes(0, 1)):
siteno, index = site_index.T

# index and distance of nearest site
atom_cart_coords = np.dot(atom_coords, lattice.matrix)
site_index = site_coords_tree.search_tree(
atom_cart_coords, site_radius * site_inner_fraction)
if key is not None:
siteno = integer_remap(a=siteno,
key=key,
palette=np.unique(siteno))

# construct mapping
atom_site = np.full((atom_coords.shape[0], 1), NOSITE)
for index, site in site_index:
atom_site[index] = site
atom_sites.append(atom_site)
atom_sites[index] = siteno

return np.hstack(atom_sites)
return atom_sites.reshape(shape)


def _calculate_transitions_matrix(events: pd.DataFrame,
Expand Down
27 changes: 27 additions & 0 deletions src/gemdat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,33 @@ def bfill(arr: np.ndarray, fill_val: int = -1, axis=-1) -> np.ndarray:
return np.fliplr(ffill(np.fliplr(arr), fill_val=fill_val))


def integer_remap(a: np.ndarray,
key: np.ndarray,
palette: np.ndarray | None = None) -> np.ndarray:
"""Map integers in array `a` from `palette` -> `key`
Parameters
----------
a : np.ndarray
Input array with values to be
key : np.ndarray
The key gives the new values that the palette will be mapped to
palette : np.ndarray | None
Input values, must be given in sorted order.
If None, use sorted unique values in `a`
Returns
-------
np.ndarray
"""
if palette is None:
palette = np.unique(a)

index = np.digitize(a, palette, right=True)

return key[index].reshape(a.shape)


def meanfreq(x: np.ndarray, fs: float = 1.0) -> np.ndarray:
"""Estimates the mean frequency in terms of the sample rate, fs.
Expand Down
20 changes: 18 additions & 2 deletions tests/integration/transitions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,24 @@ def test_site_radius(self, vasp_traj, structure):
)
assert isclose(site_radius, 0.9284961123176741)

def test_atom_sites(self, vasp_traj, vasp_transitions):
n_steps = len(vasp_traj)
def test_site_radius_dict(self, vasp_traj, structure):
sites = structure.copy()

for site in sites[0::3]:
site.label = 'A'
for site in sites[1::3]:
site.label = 'B'
for site in sites[2::3]:
site.label = 'C'

site_radius = {'A': 0.5, 'B': 0.6, 'C': 0.7}
transitions = vasp_traj.transitions_between_sites(
sites=sites, floating_specie='Li', site_radius=site_radius)

assert transitions.states.sum() == 3445344

def test_atom_sites(self, vasp_transitions):
n_steps = 3750
n_diffusing = 48

slice_ = np.s_[::1000, ::24]
Expand Down
22 changes: 15 additions & 7 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_equal

from gemdat.utils import bfill, ffill, meanfreq
from gemdat.utils import bfill, ffill, integer_remap, meanfreq


@pytest.fixture
Expand All @@ -23,7 +24,7 @@ def test_ffill(arr):
[4, 9, 6, 6, 6],
])

np.testing.assert_equal(ret, expected)
assert_equal(ret, expected)


def test_bfill(arr):
Expand All @@ -34,7 +35,7 @@ def test_bfill(arr):
[4, 9, 6, -1, -1],
])

np.testing.assert_equal(ret, expected)
assert_equal(ret, expected)


def test_ffill_axis0(arr):
Expand All @@ -45,7 +46,7 @@ def test_ffill_axis0(arr):
[4, 9, 6, 8, 2],
])

np.testing.assert_equal(ret, expected)
assert_equal(ret, expected)


def test_bfill_axis0(arr):
Expand All @@ -56,7 +57,14 @@ def test_bfill_axis0(arr):
[4, 9, 6, -1, -1],
])

np.testing.assert_equal(ret, expected)
assert_equal(ret, expected)


def test_integer_remap():
a = np.array([4, 2, 1, 3])
key = np.array([10, 20, 30, 40])
ret = integer_remap(a, key=key)
assert_equal(ret, a * 10)


def test_meanfreq_single_timestep():
Expand All @@ -65,7 +73,7 @@ def test_meanfreq_single_timestep():

expected = np.array([[0.2303359]])

np.testing.assert_allclose(ret, expected)
assert_allclose(ret, expected)


def test_meanfreq():
Expand All @@ -78,4 +86,4 @@ def test_meanfreq():

expected = np.array([[0.2303359], [0.21308077], [0.17074241]])

np.testing.assert_allclose(ret, expected)
assert_allclose(ret, expected)

0 comments on commit 80652c8

Please sign in to comment.