
Commit

Force push to solve branch conflicts
SCiarella committed Feb 5, 2024
1 parent 0aa231e commit fdb06a7
Showing 11 changed files with 201 additions and 204 deletions.
35 changes: 20 additions & 15 deletions src/gemdat/jumps.py
@@ -237,13 +237,16 @@ def collective(self, max_dist: float = 1) -> Collective:
 
     @weak_lru_cache()
     def activation_energies(
-            self) -> dict[tuple[str, str], tuple[float, float]]:
+            self,
+            n_parts: int = 10) -> dict[tuple[str, str], tuple[float, float]]:
         """Calculate activation energies for jumps (UNITS?).
 
         Returns
         -------
         e_act : dict[tuple[str, str], tuple[float, float]]
             Dictionary with jump activation energies and standard deviations between site pairs.
+        n_parts : int
+            Number of parts to split transitions/jumps into for statistics
         """
         trajectory = self.transitions.trajectory
         attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
@@ -252,16 +255,21 @@ def activation_energies(
 
         temperature = trajectory.metadata['temperature']
 
-        atom_locations_parts = self.sites.atom_locations_parts(
-            self.transitions)
-        parts = self.jumps_counter_parts(self.sites.n_parts)
+        atom_locations_parts = [
+            self.sites.atom_locations(part)
+            for part in self.transitions.split(n_parts)
+        ]
+        jumps_counter_parts = [
+            part.jumps_counter() for part in self.split(n_parts)
+        ]
 
         for site_pair in self.sites.site_pairs:
             site_start, site_stop = site_pair
 
-            n_jumps = np.array([part[site_pair] for part in parts])
+            n_jumps = np.array(
+                [part[site_pair] for part in jumps_counter_parts])
 
-            part_time = trajectory.total_time / self.sites.n_parts
+            part_time = trajectory.total_time / n_parts
 
             atom_percentage = np.array(
                 [part[site_start] for part in atom_locations_parts])
@@ -310,28 +318,25 @@ def split(self, n_parts) -> list[Jumps]:
         ]
 
     @weak_lru_cache()
-    def jumps_counter_parts(self, n_parts) -> list[Counter]:
-        """Return [jump counters][gemdat.sites.SitesData.jumps] per part."""
-
-        return [part.jumps_counter() for part in self.split(n_parts)]
-
-    @weak_lru_cache()
-    def rates(self) -> dict[tuple[str, str], tuple[float, float]]:
+    def rates(self,
+              n_parts: int = 10) -> dict[tuple[str, str], tuple[float, float]]:
         """Calculate jump rates (total jumps / second).
 
         Returns
         -------
         rates : dict[tuple[str, str], tuple[float, float]]
             Dictionary with jump rates and standard deviations between site pairs
+        n_parts : int
+            Number of parts to split jumps into for statistics
         """
         rates: dict[tuple[str, str], tuple[float, float]] = {}
 
-        parts = self.jumps_counter_parts(self.sites.n_parts)
+        parts = [part.jumps_counter() for part in self.split(n_parts)]
 
         for site_pair in self.sites.site_pairs:
             n_jumps = [part[site_pair] for part in parts]
 
-            part_time = self.transitions.trajectory.total_time / self.sites.n_parts
+            part_time = self.transitions.trajectory.total_time / n_parts
            denom = self.n_floating * part_time
 
             jump_freq_mean = np.mean(n_jumps) / denom
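
Taken together, the jumps.py changes drop the stored SitesData.n_parts and make the split an explicit, per-call argument of the statistics methods. A minimal usage sketch under that reading; `jumps` stands for an already constructed Jumps instance, and the keyword value shown is simply the new default:

# Hypothetical call sites after this commit; `jumps` is assumed to be an
# existing gemdat Jumps object built from a trajectory, sites and transitions.
rates = jumps.rates(n_parts=10)                # {(start, stop): (mean, std)} per site pair
e_act = jumps.activation_energies(n_parts=10)  # activation energies with the same keys
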
1 change: 0 additions & 1 deletion src/gemdat/legacy.py
@@ -92,7 +92,6 @@ def analyse_md(
         structure=sites_structure,
         trajectory=trajectory,
         floating_specie=diff_elem,
-        n_parts=nr_parts,
     )
 
     transitions = Transitions.from_trajectory(trajectory=trajectory,
16 changes: 6 additions & 10 deletions src/gemdat/plots/matplotlib/_displacements.py
@@ -81,19 +81,15 @@ def msd_per_element(*, trajectory: Trajectory) -> plt.Figure:
     fig : matplotlib.figure.Figure
         Output figure
     """
-    grouped = defaultdict(list)
-
-    species = trajectory.species
-
-    for sp, distances in zip(species,
-                             trajectory.distances_from_base_position()):
-        grouped[sp.symbol].append(distances**2)
+    species = list(set(trajectory.species))
 
     fig, ax = plt.subplots()
 
-    for symbol, sq_distances in grouped.items():
-        mean_sq_disp = np.mean(sq_distances, axis=0)
-        ax.plot(mean_sq_disp, lw=0.3, label=symbol)
+    for sp in species:
+        traj = trajectory.filter(sp.symbol)
+        ax.plot(traj.mean_squared_displacement().mean(axis=0),
+                lw=0.5,
+                label=sp.symbol)
 
     ax.legend()
     ax.set(title='Mean squared displacement per element',
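
For context, the rewritten msd_per_element no longer accumulates squared distances by hand; it filters the trajectory per species and asks the Trajectory for the mean squared displacement. A sketch of that pattern, assuming a loaded gemdat Trajectory (from_vasprun and the file path are illustrative, the filter and mean_squared_displacement calls are the ones used in the diff):

from gemdat import Trajectory

trajectory = Trajectory.from_vasprun('vasprun.xml')  # illustrative input file

for symbol in {sp.symbol for sp in trajectory.species}:
    # filter() keeps only atoms of one species; averaging over the first
    # (atom) axis gives a single MSD curve per element, as plotted above
    msd = trajectory.filter(symbol).mean_squared_displacement().mean(axis=0)
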
17 changes: 6 additions & 11 deletions src/gemdat/plots/plotly/_displacements.py
@@ -82,22 +82,17 @@ def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
 
     fig = go.Figure()
 
-    grouped = defaultdict(list)
-
-    species = trajectory.species
+    species = list(set(trajectory.species))
 
-    for sp, distances in zip(species,
-                             trajectory.distances_from_base_position()):
-        grouped[sp.symbol].append(distances**2)
+    for sp in species:
+        traj = trajectory.filter(sp.symbol)
 
-    for symbol, sq_distances in grouped.items():
-        mean_sq_disp = np.mean(sq_distances, axis=0)
         fig.add_trace(
-            go.Scatter(y=mean_sq_disp,
-                       name=symbol,
+            go.Scatter(y=traj.mean_squared_displacement().mean(axis=0),
+                       name=sp.symbol,
                        mode='lines',
                        line={'width': 3},
-                       legendgroup=symbol))
+                       legendgroup=sp.symbol))
 
     fig.update_layout(title='Mean squared displacement per element',
                       xaxis_title='Time step',
112 changes: 14 additions & 98 deletions src/gemdat/sites.py
@@ -9,14 +9,12 @@
 import numpy as np
 from pymatgen.core import Structure
 
-from .caching import weak_lru_cache
 from .simulation_metrics import SimulationMetrics
-from .transitions import Transitions
 from .utils import is_lattice_similar
 
 if typing.TYPE_CHECKING:
-
     from gemdat.trajectory import Trajectory
+    from gemdat.transitions import Transitions
 
 NOSITE = -1
 
@@ -29,7 +27,6 @@ def __init__(
         structure: Structure,
         trajectory: Trajectory,
         floating_specie: str,
-        n_parts: int = 10,
         site_radius: Optional[float] = None,
     ):
         """Contain sites and jumps data.
@@ -42,8 +39,6 @@ def __init__(
             Input trajectory
         floating_specie : str
             Name of the floating or diffusing specie
-        n_parts : int, optional
-            Number of parts to divide transitions into for statistics
         site_radius: Optional[float]
             if set it fixes the site_radius instead of determining it
             dynamically
@@ -52,8 +47,6 @@
             raise ValueError(
                 'Trajectory must have constant lattice for site analysis.')
 
-        self.n_parts = n_parts
-
         self.floating_specie = floating_specie
         self.structure = structure
 
@@ -101,54 +94,7 @@ def site_pairs(self) -> list[tuple[str, str]]:
         site_pairs = product(labels, repeat=2)
         return [pair for pair in site_pairs]  # type: ignore
 
-    def occupancy_parts(self,
-                        transitions: Transitions) -> list[dict[int, int]]:
-        """Return [occupancy arrays][gemdat.transitions.Transitions.occupancy]
-        from parts."""
-        return [part.occupancy() for part in transitions.split(self.n_parts)]
-
-    def site_occupancy_parts(
-            self, transitions: Transitions) -> list[dict[str, float]]:
-        """Return [site occupancy][gemdat.sites.SitesData.site_occupancy] dicts
-        per part."""
-        labels = self.site_labels
-        n_steps = len(self.trajectory)
-
-        parts = transitions.split(self.n_parts)
-
-        return [
-            _calculate_site_occupancy(occupancy=part.occupancy(),
-                                      labels=labels,
-                                      n_steps=int(n_steps / self.n_parts))
-            for part in parts
-        ]
-
-    @weak_lru_cache()
-    def atom_locations_parts(
-            self, transitions: Transitions) -> list[dict[str, float]]:
-        """Return [atom locations][gemdat.sites.SitesData.atom_locations] dicts
-        per part."""
-        multiplier = self.n_sites / self.n_floating
-        return [{
-            k: v * multiplier
-            for k, v in part.items()
-        } for part in self.site_occupancy_parts(transitions)]
-
-    def site_occupancy(self, transitions: Transitions):
-        """Calculate percentage occupancy per unique site.
-
-        Returns
-        -------
-        site_occopancy : dict[str, float]
-            Percentage occupancy per unique site
-        """
-        labels = self.site_labels
-        n_steps = len(self.trajectory)
-        return _calculate_site_occupancy(occupancy=transitions.occupancy(),
-                                         labels=labels,
-                                         n_steps=n_steps)
-
-    def atom_locations(self, transitions):
+    def atom_locations(self, transitions: Transitions):
         """Calculate fraction of time atoms spent at a type of site.
 
         Returns
Return dict with the fraction of time atoms spent at a site
"""
multiplier = self.n_sites / self.n_floating
return {
k: v * multiplier
for k, v in self.site_occupancy(transitions).items()
}


def _calculate_site_occupancy(
*,
occupancy: dict[int, int],
labels: list[str],
n_steps: int,
) -> dict[str, float]:
"""Calculate percentage occupancy per unique site.
Parameters
----------
occupancy : dict[int, int]
Occupancy dict
labels : list[str]
Site labels
n_steps : int
Number of steps in time series
Returns
-------
dict[str, float]
Percentage occupancy per unique site
"""
counts = defaultdict(list)

assert all(v >= 0 for v in occupancy)

for k, v in occupancy.items():
label = labels[k]
counts[label].append(v)

site_occupancies = {
k: sum(v) / (n_steps * labels.count(k))
for k, v in counts.items()
}

return site_occupancies

compositions_by_label = defaultdict(list)

for site in transitions.occupancy():
compositions_by_label[site.label].append(site.species.num_atoms)

ret = {}

for k, v in compositions_by_label.items():
ret[k] = (sum(v) / len(v)) * multiplier

return ret
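
The reworked atom_locations reads the time-averaged occupancy per site label straight from Transitions.occupancy() and rescales it by n_sites / n_floating. A worked toy example of that formula; `sites` and `transitions` stand for existing SitesData and Transitions objects, and the labels and numbers are invented:

# Invented numbers, only to illustrate the new computation:
#   mean site.species.num_atoms over all '48h' sites = 0.40
#   multiplier = n_sites / n_floating = 1.5
#   atom_locations['48h'] = 0.40 * 1.5 = 0.60
locations = sites.atom_locations(transitions)  # e.g. {'48h': 0.60, ...}
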
