Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stefsmeets committed Nov 18, 2024
1 parent 4338047 commit 244df45
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/gemdat/plots/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
import pandas as pd
from pymatgen.core import Species
from pymatgen.core import Element, Species
from scipy.optimize import curve_fit
from scipy.stats import skewnorm

Expand All @@ -26,7 +26,7 @@ def _mean_displacements_per_element(

grouped = defaultdict(list)
for sp, distances in zip(species, trajectory.distances_from_base_position()):
assert isinstance(sp, Species)
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

grouped[sp.symbol].append(distances)

Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/plots/matplotlib/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import matplotlib.pyplot as plt
import numpy as np
from pymatgen.core import Species
from pymatgen.core import Element, Species

if TYPE_CHECKING:
import matplotlib.figure
Expand Down Expand Up @@ -42,7 +42,7 @@ def msd_per_element(
t_values = np.arange(len(trajectory)) * time_ps

for sp in species:
assert isinstance(sp, Species)
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

traj = trajectory.filter(sp.symbol)
msd = traj.mean_squared_displacement()
Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/plots/plotly/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import plotly.graph_objects as go
from pymatgen.core import Species
from pymatgen.core import Element, Species

from gemdat.plots._shared import hex2rgba

Expand Down Expand Up @@ -32,7 +32,7 @@ def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
species = list(set(trajectory.species))

for i, sp in enumerate(species):
assert isinstance(sp, Species)
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

color_hex = fig.layout['template']['layout']['colorway'][i]
color_rgba = hex2rgba(color_hex, opacity=0.3)
Expand Down
23 changes: 11 additions & 12 deletions src/gemdat/plots/plotly/_plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,35 +246,34 @@ def plot_jumps(jumps: Jumps, *, fig: go.Figure):
fig : plotly.graph_objects.Figure
Plotly figure to add traces too
"""
coords = jumps.sites.frac_coords
site_coords = jumps.sites.frac_coords
lattice = jumps.trajectory.get_lattice()

for i, j in zip(*np.triu_indices(len(coords), k=1)):
for i, j in zip(*np.triu_indices(len(site_coords), k=1)):
count = jumps.matrix()[i, j] + jumps.matrix()[j, i]
if count == 0:
continue

coord_i = tuple(coords[i].tolist())
coord_j = tuple(coords[j].tolist())
site_coord_i = tuple(site_coords[i].tolist())
site_coord_j = tuple(site_coords[j].tolist())

lw = 1 + np.log(count)

length, image = lattice.get_distance_and_image(coord_i, coord_j)
length, image = lattice.get_distance_and_image(site_coord_i, site_coord_j)

if np.any(image != 0):
lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)]
lines = [(site_coord_i, site_coord_j + image), (site_coord_i - image, site_coord_j)]
else:
lines = [(coord_i, coord_j)]
lines = [(site_coord_i, site_coord_j)]

for line in lines:
coords = lattice.get_cartesian_coords(line)
coords_t = list(zip(*coords)) # transpose, but pythonic
x, y, z = lattice.get_cartesian_coords(line).T

fig.add_trace(
go.Scatter3d(
x=coords_t[0],
y=coords_t[1],
z=coords_t[2],
x=x,
y=y,
z=z,
mode='lines',
showlegend=False,
line_dash='dashdot' if any(image) != 0 else 'solid',
Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def center_of_mass(self) -> Trajectory:
"""Return trajectory with center of mass for positions."""
weights = []
for s in self.species:
assert isinstance(s, Species)
assert isinstance(s, (Species, Element)), f'got {type(s)=}'
weights.append(s.atomic_mass)

positions_no_pbc = self.base_positions + self.cumulative_displacements
Expand Down Expand Up @@ -555,7 +555,7 @@ def drift(
elif floating_species:
species = set()
for sp in self.species:
assert isinstance(sp, Species)
assert isinstance(sp, Species), f'got {type(sp)=}'
if sp.symbol not in floating_species:
species.add(sp)

Expand Down

0 comments on commit 244df45

Please sign in to comment.