@@ -132,70 +132,18 @@ def get_symmetry_operations(file: str):
132
132
133
133
def plot_atomic_drift (initial_file : str , final_file : str , output_file : str = None ):
134
134
'''Plots the atomic drift between two structures'''
135
- from ase .io import read
136
- from ase .data .colors import jmol_colors , cpk_colors
137
- from ase .data import covalent_radii
135
+ import sisl
136
+ import sisl .viz
138
137
import numpy as np
139
- import plotly .graph_objects as go
140
-
141
- initial = read (initial_file )
142
- final = read (final_file )
138
+ intial_atoms = read (initial_file )
139
+ final_atoms = read (final_file )
140
+ drift = final_atoms .get_positions () - intial_atoms .get_positions ()
141
+ geom = sisl .get_sile (initial_file ).read_geometry ()
142
+ plot = geom .plot ()
143
143
144
- #atom types
145
- atomic_numbers = initial .get_atomic_numbers ()
146
- atom_types = initial .get_chemical_symbols ()
147
- atomic_radii = [ covalent_radii [number ] for number in atomic_numbers ]
148
- atomic_colors = [ jmol_colors [number ] for number in atomic_numbers ]
149
- drift = final .get_positions () - initial .get_positions ()
150
-
151
- fig = go .Figure ()
144
+ plot .update_inputs (arrows = {"data" : drift , "name" : "Drift" , "color" : "orange" , "width" : 2 }, axes = 'xyz' )
152
145
153
- for i in range (len (initial )):
154
- atom = initial [i ]
155
- x , y , z = atom .position
156
- dx , dy , dz = drift [i ]
157
- r = atomic_radii [i ]
158
- color = f"rgb({ atomic_colors [i ][0 ]} , { atomic_colors [i ][1 ]} , { atomic_colors [i ][2 ]} )" ,
159
- fig .add_trace (go .Scatter3d (
160
- x = [x + dx , x ],
161
- y = [y + dy , y ],
162
- z = [z + dz , z ],
163
- mode = 'lines+markers' ,
164
- marker = dict (
165
- size = 18 ,
166
- color = color ,
167
- line = dict (
168
- color = 'DarkSlateGrey' ,
169
- width = 2
170
- )
171
- ),
172
- line = dict (
173
- color = color ,
174
- width = 2
175
- ),
176
- name = atom_types [i ]
177
- ))
178
-
179
-
180
- fig .update_layout (
181
- title = 'Atomic drift' ,
182
- scene = dict (
183
- xaxis_title = 'x (Å)' ,
184
- yaxis_title = 'y (Å)' ,
185
- zaxis_title = 'z (Å)'
186
- )
187
- )
146
+ plot .show ()
188
147
189
-
190
- #remove grid
191
- fig .update_xaxes (showgrid = False , zeroline = False )
192
- fig .update_yaxes (showgrid = False , zeroline = False )
193
- fig .update_scenes (xaxis_visible = False , yaxis_visible = False ,zaxis_visible = False )
194
-
195
-
196
- if not output_file :
197
- fig .show ()
198
- else :
199
- fig .write_image (output_file )
200
148
201
149
0 commit comments