Skip to content

Commit 0d490f2

Browse files
committed
Refactor atomic drift plotting code to use sisl viz
1 parent d76d09f commit 0d490f2

File tree

1 file changed

+9
-61
lines changed

1 file changed

+9
-61
lines changed

vsh/utils/structure_tools.py

+9-61
Original file line numberDiff line numberDiff line change
@@ -132,70 +132,18 @@ def get_symmetry_operations(file: str):
132132

133133
def plot_atomic_drift(initial_file: str, final_file: str, output_file: str = None):
134134
'''Plots the atomic drift between two structures'''
135-
from ase.io import read
136-
from ase.data.colors import jmol_colors, cpk_colors
137-
from ase.data import covalent_radii
135+
import sisl
136+
import sisl.viz
138137
import numpy as np
139-
import plotly.graph_objects as go
140-
141-
initial = read(initial_file)
142-
final = read(final_file)
138+
intial_atoms = read(initial_file)
139+
final_atoms = read(final_file)
140+
drift = final_atoms.get_positions() - intial_atoms.get_positions()
141+
geom = sisl.get_sile(initial_file).read_geometry()
142+
plot = geom.plot()
143143

144-
#atom types
145-
atomic_numbers = initial.get_atomic_numbers()
146-
atom_types = initial.get_chemical_symbols()
147-
atomic_radii = [ covalent_radii[number] for number in atomic_numbers ]
148-
atomic_colors = [ jmol_colors[number] for number in atomic_numbers ]
149-
drift = final.get_positions() - initial.get_positions()
150-
151-
fig = go.Figure()
144+
plot.update_inputs(arrows={"data": drift, "name": "Drift", "color": "orange", "width": 2}, axes='xyz')
152145

153-
for i in range(len(initial)):
154-
atom = initial[i]
155-
x, y, z = atom.position
156-
dx, dy, dz = drift[i]
157-
r = atomic_radii[i]
158-
color = f"rgb({atomic_colors[i][0]}, {atomic_colors[i][1]}, {atomic_colors[i][2]})",
159-
fig.add_trace(go.Scatter3d(
160-
x=[x+dx, x],
161-
y=[y+dy, y],
162-
z=[z+dz, z],
163-
mode='lines+markers',
164-
marker=dict(
165-
size=18,
166-
color=color,
167-
line=dict(
168-
color='DarkSlateGrey',
169-
width=2
170-
)
171-
),
172-
line=dict(
173-
color=color,
174-
width=2
175-
),
176-
name=atom_types[i]
177-
))
178-
179-
180-
fig.update_layout(
181-
title='Atomic drift',
182-
scene=dict(
183-
xaxis_title='x (Å)',
184-
yaxis_title='y (Å)',
185-
zaxis_title='z (Å)'
186-
)
187-
)
146+
plot.show()
188147

189-
190-
#remove grid
191-
fig.update_xaxes(showgrid=False, zeroline=False)
192-
fig.update_yaxes(showgrid=False, zeroline=False)
193-
fig.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )
194-
195-
196-
if not output_file:
197-
fig.show()
198-
else:
199-
fig.write_image(output_file)
200148

201149

0 commit comments

Comments
 (0)