Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix all mypy issues #341

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ changelog = "https://github.com/GEMDAT-repos/GEMDAT/releases"

[project.optional-dependencies]
develop = [
"kaleido",
"kaleido < 0.4", # 0.4: https://github.com/plotly/Kaleido/issues/223
"bump-my-version",
"coverage[toml]",
"mypy",
Expand Down
2 changes: 1 addition & 1 deletion scripts/analyse_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def analyse_md(
"""
trajectory = Trajectory.from_vasprun(vasp_xml)

equilibration_steps = round(equil_time / trajectory.time_step)
equilibration_steps = round(equil_time / trajectory.time_step) # type: ignore

trajectory = trajectory[equilibration_steps:]

Expand Down
2 changes: 1 addition & 1 deletion src/gemdat/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def write_cif(structure: Structure, filename: Path | str):
filename : Path | str
Filename to write to
"""
filename = Path(filename).with_suffix('.cif')
filename = str(Path(filename).with_suffix('.cif'))
structure.to_file(filename)


Expand Down
2 changes: 1 addition & 1 deletion src/gemdat/jumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def site_pairs(self) -> list[tuple[str, str]]:
"""Return list of all unique site pairs."""
labels = self.sites.labels
site_pairs = product(labels, repeat=2)
return [pair for pair in site_pairs]
return [pair for pair in site_pairs] # type: ignore

@property
def jump_names(self) -> list[str]:
Expand Down
3 changes: 3 additions & 0 deletions src/gemdat/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __post_init__(self, in_vectors: np.ndarray | None = None):
@property
def _time_step(self) -> float:
"""Return the time step of the trajectory."""
assert self.trajectory.time_step
return self.trajectory.time_step

@property
Expand All @@ -75,7 +76,9 @@ def _distances(self) -> np.ndarray:
"""Calculate distances between every central atom and all satellite
atoms."""
central_start_coord = self._trajectory_cent.base_positions
assert central_start_coord is not None
satellite_start_coord = self._trajectory_sat.base_positions
assert satellite_start_coord is not None
lattice = self.trajectory.get_lattice()
distance = np.array(
[
Expand Down
3 changes: 2 additions & 1 deletion src/gemdat/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ def total_length(self, lattice: Lattice) -> FloatWithUnit:
length : FloatWithUnit
Total distance in Ångstrom
"""
length = 0
length = 0.0
for a, b in pairwise(self.frac_sites()):
dist, _ = lattice.get_distance_and_image(a, b)
assert dist
length += dist
return FloatWithUnit(length, 'ang')

Expand Down
3 changes: 3 additions & 0 deletions src/gemdat/plots/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
from pymatgen.core import Element, Species
from scipy.optimize import curve_fit
from scipy.stats import skewnorm

Expand All @@ -25,6 +26,8 @@ def _mean_displacements_per_element(

grouped = defaultdict(list)
for sp, distances in zip(species, trajectory.distances_from_base_position()):
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

grouped[sp.symbol].append(distances)

means = {}
Expand Down
7 changes: 5 additions & 2 deletions src/gemdat/plots/matplotlib/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import matplotlib.pyplot as plt
import numpy as np
from pymatgen.core import Element, Species

if TYPE_CHECKING:
import matplotlib.figure
Expand Down Expand Up @@ -41,6 +42,8 @@ def msd_per_element(
t_values = np.arange(len(trajectory)) * time_ps

for sp in species:
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

traj = trajectory.filter(sp.symbol)
msd = traj.mean_squared_displacement()

Expand All @@ -52,9 +55,9 @@ def msd_per_element(
last_color = ax.lines[-1].get_color()

if show_traces:
for i, traj in enumerate(msd):
for i, y_values in enumerate(msd):
label = f'{sp.symbol} trajectories' if (i == 0) else None
ax.plot(t_values, traj, lw=0.1, c=last_color, label=label)
ax.plot(t_values, y_values, lw=0.1, c=last_color, label=label)

if show_shaded:
ax.fill_between(
Expand Down
3 changes: 3 additions & 0 deletions src/gemdat/plots/plotly/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import plotly.graph_objects as go
from pymatgen.core import Element, Species

from gemdat.plots._shared import hex2rgba

Expand Down Expand Up @@ -31,6 +32,8 @@ def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
species = list(set(trajectory.species))

for i, sp in enumerate(species):
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

color_hex = fig.layout['template']['layout']['colorway'][i]
color_rgba = hex2rgba(color_hex, opacity=0.3)

Expand Down
27 changes: 15 additions & 12 deletions src/gemdat/plots/plotly/_plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,35 +246,34 @@ def plot_jumps(jumps: Jumps, *, fig: go.Figure):
fig : plotly.graph_objects.Figure
Plotly figure to add traces too
"""
coords = jumps.sites.frac_coords
site_coords = jumps.sites.frac_coords
lattice = jumps.trajectory.get_lattice()

for i, j in zip(*np.triu_indices(len(coords), k=1)):
for i, j in zip(*np.triu_indices(len(site_coords), k=1)):
count = jumps.matrix()[i, j] + jumps.matrix()[j, i]
if count == 0:
continue

coord_i = tuple(coords[i].tolist())
coord_j = tuple(coords[j].tolist())
site_coord_i = tuple(site_coords[i].tolist())
site_coord_j = tuple(site_coords[j].tolist())

lw = 1 + np.log(count)

length, image = lattice.get_distance_and_image(coord_i, coord_j)
length, image = lattice.get_distance_and_image(site_coord_i, site_coord_j)

if np.any(image != 0):
lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)]
lines = [(site_coord_i, site_coord_j + image), (site_coord_i - image, site_coord_j)]
else:
lines = [(coord_i, coord_j)]
lines = [(site_coord_i, site_coord_j)]

for line in lines:
line = lattice.get_cartesian_coords(line)
line_t = [_ for _ in zip(*line)] # transpose, but pythonic
x, y, z = lattice.get_cartesian_coords(line).T

fig.add_trace(
go.Scatter3d(
x=line_t[0],
y=line_t[1],
z=line_t[2],
x=x,
y=y,
z=z,
mode='lines',
showlegend=False,
line_dash='dashdot' if any(image) != 0 else 'solid',
Expand Down Expand Up @@ -356,6 +355,10 @@ def plot_3d(
lattice = structure.lattice
elif jumps:
lattice = jumps.trajectory.get_lattice()
else:
raise ValueError(
'Lattice cannot be determined form volume, structure, or jumps object.'
)
else:
raise ValueError('Cannot derive lattice from input.')

Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def radial_distribution(
coords = trajectory.positions
sp_coords = trajectory.filter(floating_specie).positions

states2str = _get_states(sites.labels)
states_array = _get_states_array(transitions, sites.labels)
states2str = _get_states(sites.labels) # type: ignore
states_array = _get_states_array(transitions, sites.labels) # type: ignore
symbol_indices = _get_symbol_indices(base_structure)

bins = np.arange(0, max_dist + resolution, resolution)
Expand Down
5 changes: 3 additions & 2 deletions src/gemdat/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .utils import warn_lattice_not_close

if TYPE_CHECKING:
from pymatgen.symmetry.analyzer import SpacegroupOperations
from pymatgen.symmetry.groups import SpaceGroup
from pymatgen.symmetry.structure import SymmetrizedStructure

Expand Down Expand Up @@ -95,7 +96,7 @@ def __init__(
*,
sites: Collection[PeriodicSite],
lattice: Lattice,
spacegroup: SpaceGroup,
spacegroup: SpaceGroup | SpacegroupOperations,
):
"""Set up shape analyzer from a collection of unique periodic sites,
the lattice, and spacegroup.
Expand Down Expand Up @@ -400,7 +401,7 @@ def to_structure(self) -> Structure:
sg=self.spacegroup.int_number,
lattice=self.lattice,
species=[site.specie for site in self.sites],
coords=[site.frac_coords for site in self.sites],
coords=[site.frac_coords for site in self.sites], # type: ignore
labels=[site.label for site in self.sites],
)
return structure
29 changes: 22 additions & 7 deletions src/gemdat/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import TYPE_CHECKING, Collection, Optional

import numpy as np
from pymatgen.core import Element, Lattice
from pymatgen.core import Element, Lattice, Species
from pymatgen.core.trajectory import Trajectory as PymatgenTrajectory
from pymatgen.io import vasp

Expand Down Expand Up @@ -133,16 +133,19 @@ def to_volume(self, resolution: float = 0.2) -> Volume:
@property
def time_step_ps(self) -> float:
"""Return time step in picoseconds."""
assert self.time_step
return self.time_step * 1e12

@property
def total_time(self) -> float:
"""Return total time for trajectory."""
assert self.time_step
return len(self) * self.time_step

@property
def sampling_frequency(self) -> float:
"""Return number of time steps per second."""
assert self.time_step
return 1 / self.time_step

@property
Expand Down Expand Up @@ -469,9 +472,9 @@ def get_lattice(self, idx: int | None = None) -> Lattice:
Pymatgen Lattice object
"""
if self.constant_lattice:
return Lattice(self.lattice)
return Lattice(self.lattice) # type: ignore

latt = self.lattices[idx]
latt = self.lattices[idx] # type: ignore
return Lattice(latt)

@property
Expand Down Expand Up @@ -503,7 +506,10 @@ def distances_from_base_position(self) -> np.ndarray:

def center_of_mass(self) -> Trajectory:
"""Return trajectory with center of mass for positions."""
weights = [s.atomic_mass for s in self.species]
weights = []
for s in self.species:
assert isinstance(s, (Species, Element)), f'got {type(s)=}'
weights.append(s.atomic_mass)

positions_no_pbc = self.base_positions + self.cumulative_displacements

Expand Down Expand Up @@ -547,8 +553,13 @@ def drift(
if fixed_species:
displacements = self.filter(species=fixed_species).displacements
elif floating_species:
species = {sp.symbol for sp in self.species if sp.symbol not in floating_species}
displacements = self.filter(species=species).displacements
species = set()
for sp in self.species:
assert isinstance(sp, Species), f'got {type(sp)=}'
if sp.symbol not in floating_species:
species.add(sp)

displacements = self.filter(species=species).displacements # type: ignore
else:
displacements = self.displacements

Expand Down Expand Up @@ -609,7 +620,11 @@ def filter(self, species: str | Collection[str]) -> Trajectory:
if isinstance(species, str):
species = [species]

idx = [sp.symbol in species for sp in self.species]
idx = []
for sp in self.species:
assert isinstance(sp, (Species, Element))
idx.append(sp.symbol in species)

new_coords = self.positions[:, idx]
new_species = list(compress(self.species, idx))

Expand Down
3 changes: 1 addition & 2 deletions src/gemdat/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def from_volumetric_data(cls, volume: VolumetricData):
Input volumetric data
"""
return cls(
data=volume.data,
data=volume.data['total'],
lattice=volume.structure.lattice,
)

Expand Down Expand Up @@ -506,5 +506,4 @@ def trajectory_to_volume(
data=data,
lattice=lattice,
label='trajectory',
units=None,
)