diff --git a/src/gemdat/trajectory.py b/src/gemdat/trajectory.py index 653123eb..82ec555e 100644 --- a/src/gemdat/trajectory.py +++ b/src/gemdat/trajectory.py @@ -507,7 +507,7 @@ def transitions_between_sites( self, sites: Structure, floating_specie: str, - site_radius: Optional[float] = None, + site_radius: float | dict[str, float] | None = None, site_inner_fraction: float = 1.0, ) -> Transitions: """Compute transitions between given sites for floating specie. @@ -518,9 +518,11 @@ def transitions_between_sites( Input structure with known sites floating_specie : str Name of the floating specie to calculate transitions for - site_radius: Optional[float] + site_radius: Optional[float, dict[str, float]] A custom site radius in Ã…ngstrom to determine - if an atom is at a site + if an atom is at a site. A dict keyed by the site label can + be used to have a site per atom type, e.g. + `site_radius = {'Li1': 1.0, 'Li2': 1.2}. site_inner_fraction: A fraction of the site radius which is determined to be the `inner site` which is used in jump calculations diff --git a/src/gemdat/transitions.py b/src/gemdat/transitions.py index d7b63705..dd9202c0 100644 --- a/src/gemdat/transitions.py +++ b/src/gemdat/transitions.py @@ -6,7 +6,6 @@ import typing from collections import defaultdict from itertools import pairwise -from typing import Optional import numpy as np import pandas as pd @@ -15,7 +14,7 @@ from .caching import weak_lru_cache from .simulation_metrics import SimulationMetrics -from .utils import bfill, ffill +from .utils import bfill, ffill, integer_remap if typing.TYPE_CHECKING: from gemdat.jumps import Jumps @@ -80,22 +79,29 @@ def from_trajectory( trajectory: Trajectory, sites: Structure, floating_specie: str, - site_radius: Optional[float] = None, + site_radius: float | dict[str, float] | None = None, site_inner_fraction: float = 1., ) -> Transitions: - """Compute transitions for floating specie from trajectory and - structure with known sites. + """Compute transitions between given sites for floating specie. Parameters ---------- - trajectory : Trajectory - Input trajectory sites : pymatgen.core.structure.Structure - Input sites with known sites + Input structure with known sites floating_specie : str Name of the floating specie to calculate transitions for - site_radius: Optional[float] - A custom site size to use for determining if an atom is at a site + site_radius: float | dict[str, float] | None + A custom site radius in Ã…ngstrom to determine + if an atom is at a site. A dict keyed by the site label can + be used to have a site per atom type, e.g. + `site_radius = {'Li1': 1.0, 'Li2': 1.2}. + site_inner_fraction: + A fraction of the site radius which is determined to be the `inner site` + which is used in jump calculations + + Returns + ------- + transitions: Transitions """ diff_trajectory = trajectory.filter(floating_specie) @@ -108,6 +114,9 @@ def from_trajectory( sites=sites, vibration_amplitude=vibration_amplitude) + if isinstance(site_radius, float): + site_radius = {'': site_radius} + states = _calculate_atom_states( sites=sites, trajectory=diff_trajectory, @@ -420,7 +429,7 @@ def _compute_site_radius(trajectory: Trajectory, sites: Structure, def _calculate_atom_states( sites: Structure, trajectory: Trajectory, - site_radius: float, + site_radius: dict[str, float], site_inner_fraction: float = 1., ) -> np.ndarray: """Calculate nearest site for each atom coordinate in the trajectory. @@ -435,8 +444,9 @@ def _calculate_atom_states( Input sites with pre-defined sites trajectory : Trajectory Input trajectory for floating atoms - site_radius : float - Atoms within this distance (in Angstrom) are considered to be close to a site + site_radius : dict[str, float] + Atoms within this distance (in Angstrom) are considered to be close to a site. + Can also be a dict keyed by the site label to specify the radius by atom type. site_inner_fraction: float Atoms that are closer than (site_radius*site_inner_fraction) to a site, are considered to be in the inner site @@ -448,34 +458,49 @@ def _calculate_atom_states( The value corresponds to the index in the `site_coords`. -1 indicates that atom is not at any site. """ - # Unit cell parameters + + def _site_radius_iterator(): + for label, radius in site_radius.items(): + if label: + grouped = ((k, site) for k, site in enumerate(sites) + if site.label == label) + key, site_group = zip(*grouped) + frac_coords = np.array( + [site.frac_coords for site in site_group]) + yield frac_coords, np.array(key), radius + else: + yield sites.frac_coords, None, radius + lattice = trajectory.get_lattice() - site_coords = sites.frac_coords + cutoff = max(list(site_radius.values())) - # Input array with site coordinates [site, (x, y, z)] - site_cart_coords = np.dot(site_coords, lattice.matrix) - site_coords_tree: PeriodicKDTree = PeriodicKDTree( + traj_frac_coords = trajectory.positions.reshape(-1, 3) + traj_cart_coords = lattice.get_cartesian_coords(traj_frac_coords) + + periodic_tree: PeriodicKDTree = PeriodicKDTree( box=np.array(lattice.parameters, dtype=np.float32)) - site_coords_tree.set_coords(site_cart_coords, cutoff=site_radius) + periodic_tree.set_coords(traj_cart_coords, cutoff=cutoff) + + shape = trajectory.positions.shape[0:2] + + atom_sites = np.full((traj_cart_coords.shape[0]), NOSITE) - atom_sites = [] + for coords, key, radius in _site_radius_iterator(): + cart_coords = lattice.get_cartesian_coords(coords) + site_index = periodic_tree.search_tree(cart_coords, + radius * site_inner_fraction) - for atom_index, atom_coords in enumerate( - trajectory.positions.swapaxes(0, 1)): + siteno, index = site_index.T - # index and distance of nearest site - atom_cart_coords = np.dot(atom_coords, lattice.matrix) - site_index = site_coords_tree.search_tree( - atom_cart_coords, site_radius * site_inner_fraction) + if key is not None: + siteno = integer_remap(a=siteno, + key=key, + palette=np.unique(siteno)) - # construct mapping - atom_site = np.full((atom_coords.shape[0], 1), NOSITE) - for index, site in site_index: - atom_site[index] = site - atom_sites.append(atom_site) + atom_sites[index] = siteno - return np.hstack(atom_sites) + return atom_sites.reshape(shape) def _calculate_transitions_matrix(events: pd.DataFrame, diff --git a/src/gemdat/utils.py b/src/gemdat/utils.py index 3f5fa0cf..a0eefdd2 100644 --- a/src/gemdat/utils.py +++ b/src/gemdat/utils.py @@ -93,6 +93,33 @@ def bfill(arr: np.ndarray, fill_val: int = -1, axis=-1) -> np.ndarray: return np.fliplr(ffill(np.fliplr(arr), fill_val=fill_val)) +def integer_remap(a: np.ndarray, + key: np.ndarray, + palette: np.ndarray | None = None) -> np.ndarray: + """Map integers in array `a` from `palette` -> `key` + + Parameters + ---------- + a : np.ndarray + Input array with values to be + key : np.ndarray + The key gives the new values that the palette will be mapped to + palette : np.ndarray | None + Input values, must be given in sorted order. + If None, use sorted unique values in `a` + + Returns + ------- + np.ndarray + """ + if palette is None: + palette = np.unique(a) + + index = np.digitize(a, palette, right=True) + + return key[index].reshape(a.shape) + + def meanfreq(x: np.ndarray, fs: float = 1.0) -> np.ndarray: """Estimates the mean frequency in terms of the sample rate, fs. diff --git a/tests/integration/transitions_test.py b/tests/integration/transitions_test.py index a169b6e6..12fb2356 100644 --- a/tests/integration/transitions_test.py +++ b/tests/integration/transitions_test.py @@ -28,8 +28,24 @@ def test_site_radius(self, vasp_traj, structure): ) assert isclose(site_radius, 0.9284961123176741) - def test_atom_sites(self, vasp_traj, vasp_transitions): - n_steps = len(vasp_traj) + def test_site_radius_dict(self, vasp_traj, structure): + sites = structure.copy() + + for site in sites[0::3]: + site.label = 'A' + for site in sites[1::3]: + site.label = 'B' + for site in sites[2::3]: + site.label = 'C' + + site_radius = {'A': 0.5, 'B': 0.6, 'C': 0.7} + transitions = vasp_traj.transitions_between_sites( + sites=sites, floating_specie='Li', site_radius=site_radius) + + assert transitions.states.sum() == 3445344 + + def test_atom_sites(self, vasp_transitions): + n_steps = 3750 n_diffusing = 48 slice_ = np.s_[::1000, ::24] diff --git a/tests/utils_test.py b/tests/utils_test.py index 02dfc620..c4502ebc 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -2,8 +2,9 @@ import numpy as np import pytest +from numpy.testing import assert_allclose, assert_equal -from gemdat.utils import bfill, ffill, meanfreq +from gemdat.utils import bfill, ffill, integer_remap, meanfreq @pytest.fixture @@ -23,7 +24,7 @@ def test_ffill(arr): [4, 9, 6, 6, 6], ]) - np.testing.assert_equal(ret, expected) + assert_equal(ret, expected) def test_bfill(arr): @@ -34,7 +35,7 @@ def test_bfill(arr): [4, 9, 6, -1, -1], ]) - np.testing.assert_equal(ret, expected) + assert_equal(ret, expected) def test_ffill_axis0(arr): @@ -45,7 +46,7 @@ def test_ffill_axis0(arr): [4, 9, 6, 8, 2], ]) - np.testing.assert_equal(ret, expected) + assert_equal(ret, expected) def test_bfill_axis0(arr): @@ -56,7 +57,14 @@ def test_bfill_axis0(arr): [4, 9, 6, -1, -1], ]) - np.testing.assert_equal(ret, expected) + assert_equal(ret, expected) + + +def test_integer_remap(): + a = np.array([4, 2, 1, 3]) + key = np.array([10, 20, 30, 40]) + ret = integer_remap(a, key=key) + assert_equal(ret, a * 10) def test_meanfreq_single_timestep(): @@ -65,7 +73,7 @@ def test_meanfreq_single_timestep(): expected = np.array([[0.2303359]]) - np.testing.assert_allclose(ret, expected) + assert_allclose(ret, expected) def test_meanfreq(): @@ -78,4 +86,4 @@ def test_meanfreq(): expected = np.array([[0.2303359], [0.21308077], [0.17074241]]) - np.testing.assert_allclose(ret, expected) + assert_allclose(ret, expected)