Skip to content

Commit

Permalink
Refactor radial_distribution_between_species() (#340)
Browse files Browse the repository at this point in the history
* Refactor radial_distribution_between_species()

* Add shortcut to radial distribution on trajectory

* Fix test fail
  • Loading branch information
stefsmeets authored Nov 14, 2024
1 parent ea6b841 commit c1888d8
Show file tree
Hide file tree
Showing 15 changed files with 102 additions and 199 deletions.
2 changes: 0 additions & 2 deletions src/gemdat/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
msd_per_element,
plot_3d,
radial_distribution,
radial_distribution_between_species,
rectilinear,
shape,
vibrational_amplitudes,
Expand All @@ -45,7 +44,6 @@
'msd_per_element',
'plot_3d',
'radial_distribution',
'radial_distribution_between_species',
'rectilinear',
'shape',
'vibrational_amplitudes',
Expand Down
44 changes: 0 additions & 44 deletions src/gemdat/plots/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from scipy.stats import skewnorm

if TYPE_CHECKING:
from typing import Collection

from gemdat.orientations import Orientations
from gemdat.trajectory import Trajectory

Expand Down Expand Up @@ -149,45 +147,3 @@ def _get_vibrational_amplitudes_hist(
std = np.std(data, axis=0)

return VibrationalAmplitudeHist(amplitudes=amplitudes, counts=mean, std=std)


def _get_radial_distribution_between_species(
*,
trajectory: Trajectory,
specie_1: str | Collection[str],
specie_2: str | Collection[str],
max_dist: float = 5.0,
resolution: float = 0.1,
) -> tuple[np.ndarray, np.ndarray]:
coords_1 = trajectory.filter(specie_1).coords
coords_2 = trajectory.filter(specie_2).coords
lattice = trajectory.get_lattice()

if coords_2.ndim == 2:
num_time_steps = 1
num_atoms, num_dimensions = coords_2.shape
else:
num_time_steps, num_atoms, num_dimensions = coords_2.shape

particle_vol = num_atoms / lattice.volume

all_dists = np.concatenate(
[
lattice.get_all_distances(coords_1[t, :, :], coords_2[t, :, :])
for t in range(num_time_steps)
]
)
distances = all_dists.flatten()

bins = np.arange(0, max_dist + resolution, resolution)
rdf, _ = np.histogram(distances, bins=bins, density=False)

def normalize(radius: np.ndarray) -> np.ndarray:
"""Normalize bin to volume."""
shell = (radius + resolution) ** 3 - radius**3
return particle_vol * (4 / 3) * np.pi * shell

norm = normalize(bins)[:-1]
rdf = rdf / norm

return bins, rdf
2 changes: 0 additions & 2 deletions src/gemdat/plots/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ._jumps_vs_time import jumps_vs_time
from ._msd_per_element import msd_per_element
from ._radial_distribution import radial_distribution
from ._radial_distribution_between_species import radial_distribution_between_species
from ._rectilinear import rectilinear
from ._shape import shape
from ._vibrational_amplitudes import vibrational_amplitudes
Expand All @@ -37,7 +36,6 @@
'jumps_vs_time',
'msd_per_element',
'radial_distribution',
'radial_distribution_between_species',
'rectilinear',
'shape',
'vibrational_amplitudes',
Expand Down
7 changes: 4 additions & 3 deletions src/gemdat/plots/matplotlib/_radial_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ def radial_distribution(rdfs: Iterable[RDFData]) -> matplotlib.figure.Figure:
fig, ax = plt.subplots()

for rdf in rdfs:
ax.plot(rdf.x, rdf.y, label=rdf.symbol)
ax.plot(rdf.x, rdf.y, label=rdf.label)

states = ', '.join({rdf.state for rdf in rdfs})
states = ', '.join({rdf.state for rdf in rdfs if rdf.state})
state_suffix = f' ({states})' if states else ''

ax.legend()
ax.set(
title=f'Radial distribution function ({states})',
title=f'Radial distribution function{state_suffix}',
xlabel='Distance (Å)',
ylabel='Counts',
)
Expand Down

This file was deleted.

2 changes: 0 additions & 2 deletions src/gemdat/plots/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ._msd_per_element import msd_per_element
from ._plot3d import plot_3d
from ._radial_distribution import radial_distribution
from ._radial_distribution_between_species import radial_distribution_between_species
from ._rectilinear import rectilinear
from ._shape import shape
from ._vibrational_amplitudes import vibrational_amplitudes
Expand All @@ -38,7 +37,6 @@
'msd_per_element',
'plot_3d',
'radial_distribution',
'radial_distribution_between_species',
'rectilinear',
'shape',
'vibrational_amplitudes',
Expand Down
8 changes: 5 additions & 3 deletions src/gemdat/plots/plotly/_radial_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ def radial_distribution(rdfs: Iterable[RDFData]) -> go.Figure:
go.Scatter(
x=rdf.x,
y=rdf.y,
name=rdf.symbol,
name=rdf.label,
mode='lines',
# line={'width': 0.25}
)
)

states = ', '.join({rdf.state for rdf in rdfs})
states = ', '.join({rdf.state for rdf in rdfs if rdf.state})
state_suffix = f' ({states})' if states else ''

fig.update_layout(
title=f'Radial distribution function ({states})',
title=f'Radial distribution function{state_suffix}',
xaxis_title='Distance (Å)',
yaxis_title='Counts',
)
Expand Down
66 changes: 0 additions & 66 deletions src/gemdat/plots/plotly/_radial_distribution_between_species.py

This file was deleted.

76 changes: 72 additions & 4 deletions src/gemdat/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from ._plot_backend import plot_backend

if TYPE_CHECKING:
from typing import Collection

from pymatgen.core import Structure

from gemdat import Trajectory
from gemdat.transitions import Transitions


Expand Down Expand Up @@ -79,16 +82,16 @@ class RDFData:
1D array with x data (bins)
y : np.ndarray
1D array with y data (counts)
symbol : str
Distance to species with this symbol
label : str
Distance to species with this symbol label
state : str
State that the floating species is in, e.g.
the jump that it is making.
"""

x: np.ndarray
y: np.ndarray
symbol: str
label: str
state: str

@plot_backend
Expand Down Expand Up @@ -178,10 +181,75 @@ def radial_distribution(
x=bins,
# Drop last element with distance > max_dist
y=values[:-1],
symbol=symbol,
label=symbol,
state=state,
)
ret.setdefault(state, RDFCollection())
ret[state].append(rdf_data)

return ret


def radial_distribution_between_species(
*,
trajectory: Trajectory,
specie_1: str | Collection[str],
specie_2: str | Collection[str],
max_dist: float = 5.0,
resolution: float = 0.1,
) -> RDFData:
"""Calculate RDFs from specie_1 to specie_2.
Parameters
----------
trajectory: Trajectory
Input trajectory.
specie_1: str | list[str]
Name of specie or list of species
specie_2: str | list[str]
Name of specie or list of species
max_dist: float, optional
Max distance for rdf calculation
resolution: float, optional
Width of the bins
Returns
-------
rdf : RDFData
RDF data for the given species.
"""
coords_1 = trajectory.filter(specie_1).coords
coords_2 = trajectory.filter(specie_2).coords
lattice = trajectory.get_lattice()

if coords_2.ndim == 2:
num_time_steps = 1
num_atoms, num_dimensions = coords_2.shape
else:
num_time_steps, num_atoms, num_dimensions = coords_2.shape

particle_vol = num_atoms / lattice.volume

all_dists = np.concatenate(
[
lattice.get_all_distances(coords_1[t, :, :], coords_2[t, :, :])
for t in range(num_time_steps)
]
)
distances = all_dists.flatten()

bins = np.arange(0, max_dist + resolution, resolution)
rdf, _ = np.histogram(distances, bins=bins, density=False)

def normalize(radius: np.ndarray) -> np.ndarray:
"""Normalize bin to volume."""
shell = (radius + resolution) ** 3 - radius**3
return particle_vol * (4 / 3) * np.pi * shell

norm = normalize(bins)[:-1]
counts = rdf / norm

str1 = specie_1 if isinstance(specie_1, str) else '/'.join(specie_1)
str2 = specie_1 if isinstance(specie_2, str) else '/'.join(specie_2)

return RDFData(x=bins[:-1], y=counts, label=f'{str1}-{str2}', state='')
Loading

0 comments on commit c1888d8

Please sign in to comment.