import numpy as np
import matplotlib.pyplot as plt
import tqdm
import xarray as xr

################ This is the import statement required to reference scripts within the package
import os,sys,glob
ndh_tools_path_opts = [
    '/mnt/data01/Code/',
    '/mnt/l/mnt/data01/Code/',
    '/home/common/HolschuhLab/Code/'
]
for i in ndh_tools_path_opts:
    if os.path.isdir(i): sys.path.append(i)
################################################################################################

import NDH_Tools as ndh

def calculate_flowlines(input_xr,seed_points,uv_varnames=['u','v'],xy_varnames=['x','y'],steps=20000,ds=2,forward0_both1_backward2=1):
    """
    % (C) Nick Holschuh - Amherst College -- 2022 (Nick.Holschuh@gmail.com)
    %
    % This function takes a vector field described in an xarray dataset and an
    % array of points, and calculates flowlines that pass through those points
    % following the vector field.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % The inputs are:
    %
    %    input_xr -- this must be an xarray dataset with two dataarrays, representing
    %                the x components and the y components of a vector field. The data
    %                variables and the coordinate variables that describe them can have
    %                any name, but the defaults are 'u','v','x','y'.
    %    seed_points -- this should be an nx2 array containing x/y pairs for seed points
    %                used to constrain the calculated flowlines
    %    uv_varnames -- default=['u','v'], these are the data variable names for the
    %                vector field components.
    %    xy_varnames -- default=['x','y'], these are the coordinate variable names
    %                describing the columns and rows of the vector field arrays
    %    steps -- default=20000, this is the number of steps to take away from the seed
    %                in either the forward or backward direction
    %    ds -- default=2, this is the step-size to take when propagating the flowline away
    %                from the seed point (in the same units as the coordinate variables)
    %    forward0_both1_backward2 -- default=1, this sets whether you want the flowlines
    %                to extend down-vector, up-vector, or both from the seed point.
    %
    %%%%%%%%%%%%%%%
    % The outputs are:
    %
    %    output -- a list of nx2 arrays containing the flowlines associated with
    %              each seed point
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    """

    ##################### Here, we standardize the naming convention within the xarray object
    input_xr = input_xr.rename({xy_varnames[0]:'x',xy_varnames[1]:'y'})

    ##################### Normalize the vector field to unit length, so each step advances a
    ##################### fixed distance (ds) in map coordinates regardless of flow speed
    uv_scalar = np.sqrt(input_xr[uv_varnames[0]].values**2 + input_xr[uv_varnames[1]].values**2)
    input_xr[uv_varnames[0]] = (('y','x'),input_xr[uv_varnames[0]].values/uv_scalar)
    input_xr[uv_varnames[1]] = (('y','x'),input_xr[uv_varnames[1]].values/uv_scalar)

    #################### We initialize the objects for the flowline calculation
    flowlines = []

    #################### Here is the forward calculation
    if forward0_both1_backward2 <= 1:
        temp_xs = np.expand_dims(seed_points[:,0],0)
        temp_ys = np.expand_dims(seed_points[:,1],0)
        for ind0 in tqdm.tqdm(np.arange(steps)):
            ######### Sample the unit vector field at the current position of every flowline
            x_search = xr.DataArray(temp_xs[-1,:],dims=['vector_index'])
            y_search = xr.DataArray(temp_ys[-1,:],dims=['vector_index'])
            new_u = input_xr[uv_varnames[0]].sel(x=x_search,y=y_search,method='nearest')
            new_v = input_xr[uv_varnames[1]].sel(x=x_search,y=y_search,method='nearest')
            ######### This is an order of magnitude slower
            #new_u = input_xr[uv_varnames[0]].interp(x=x_search,y=y_search)
            #new_v = input_xr[uv_varnames[1]].interp(x=x_search,y=y_search)
            ######### Take an Euler step of length ds down-vector
            temp_xs = np.concatenate([temp_xs,temp_xs[-1:,:]+new_u.values.T*ds])
            temp_ys = np.concatenate([temp_ys,temp_ys[-1:,:]+new_v.values.T*ds])
        xs = temp_xs
        ys = temp_ys
    else:
        xs = np.empty([0,len(seed_points)])
        ys = np.empty([0,len(seed_points)])

    #################### Here is the backward calculation
    if forward0_both1_backward2 >= 1:
        temp_xs = np.expand_dims(seed_points[:,0],0)
        temp_ys = np.expand_dims(seed_points[:,1],0)
        for ind0 in tqdm.tqdm(np.arange(steps)):
            x_search = xr.DataArray(temp_xs[-1,:],dims=['vector_index'])
            y_search = xr.DataArray(temp_ys[-1,:],dims=['vector_index'])
            new_u = input_xr[uv_varnames[0]].sel(x=x_search,y=y_search,method='nearest')
            new_v = input_xr[uv_varnames[1]].sel(x=x_search,y=y_search,method='nearest')
            ######### This is an order of magnitude slower
            #new_u = input_xr[uv_varnames[0]].interp(x=x_search,y=y_search)
            #new_v = input_xr[uv_varnames[1]].interp(x=x_search,y=y_search)
            ######### Take an Euler step of length ds up-vector
            temp_xs = np.concatenate([temp_xs,temp_xs[-1:,:]-new_u.values.T*ds])
            temp_ys = np.concatenate([temp_ys,temp_ys[-1:,:]-new_v.values.T*ds])
        ######### Flip the backward segment so each flowline runs upstream to downstream
        xs = np.concatenate([np.flipud(temp_xs),xs])
        ys = np.concatenate([np.flipud(temp_ys),ys])

    #################### Repackage each column of the step arrays as an nx2 flowline
    for ind0 in np.arange(len(xs[0,:])):
        xy = np.stack([xs[:,ind0],ys[:,ind0]]).T
        flowlines.append(xy)

    return flowlines
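
######################### Example usage ######################################################
##### A minimal sketch of calling calculate_flowlines, wrapped in "if 0:" so it never runs on
##### import. The file name and the variable names 'VX'/'VY' are placeholders for whatever
##### velocity product you actually have; substitute your own names, units, and step sizes.
if 0:
    vel = xr.open_dataset('velocity_grid.nc')                   # hypothetical velocity dataset
    seeds = np.array([[-1.0e6, 2.5e5],                          # x/y seed pairs, in the same
                      [-9.5e5, 3.0e5]])                         # units as the grid coordinates
    flowlines = calculate_flowlines(vel, seeds,
                                    uv_varnames=['VX','VY'],    # data variable names in `vel`
                                    xy_varnames=['x','y'],
                                    steps=5000, ds=100,
                                    forward0_both1_backward2=1) # trace both directions
    plt.figure()
    for fl in flowlines:
        plt.plot(fl[:,0], fl[:,1], c='red')
    plt.plot(seeds[:,0], seeds[:,1], 'ko')
    plt.show()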
##########################################################################################
#### This version of the code doesn't work quite right...
##########################################################################################
##def calculate_flowlines(x,y,u,v,seed_points,max_error=0.00001,retry_count_threshold=10):
##    """
##    % (C) Nick Holschuh - Amherst College -- 2022 (Nick.Holschuh@gmail.com)
##    %
##    % This function calculates flowlines through a set of seed points by running
##    % matplotlib's streamline integrator over the supplied vector field.
##    %
##    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
##    % The inputs are:
##    %
##    %    x,y -- coordinate vectors for the vector field grid
##    %    u,v -- x and y components of the vector field
##    %    seed_points -- an nx2 array of x/y pairs the flowlines must pass through
##    %    max_error -- default=0.00001, error tolerance passed to the modified streamplot
##    %    retry_count_threshold -- default=10, maximum number of recalculation attempts
##    %
##    %%%%%%%%%%%%%%%
##    % The outputs are:
##    %
##    %    output -- a list of nx2 arrays containing the flowline for each seed point
##    %
##    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
##    """
##
##    ################# This uses a modified plt.streamline to pass through a user-editable keyword
##    ################# argument "max_error", which goes into the interpolator to guarantee
##    ################# accurate streamline calculation. Copy the updated version of streamplot
##    ################# into your matplotlib directory to enable the use of streamline
##    ################# calculation from NDH_Tools (in streamline.py, which calls _integrate_rk12)
##
##    if isinstance(seed_points,list):
##        seed_points = np.array(seed_points)
##
##    if len(seed_points.shape) == 1:
##        seed_points = np.expand_dims(seed_points,axis=0)
##
##    ###################### Initialize the returned object
##    final_sls = []
##    for ind0 in np.arange(len(seed_points[:,0])):
##        final_sls.append([])
##
##    retry_count = 0
##    retry_inds = np.arange(0,len(seed_points[:,0]))
##    seed_subset = seed_points
##
##    while len(retry_inds) > 0:
##
##        sls = []
##
##        ################# Calculate the streamlines for all unfound seed points
##        fig = plt.figure()
##        if retry_count == 0:
##            print('The initial streamline calculation -- this can be slow. Finding '+str(len(seed_subset[:,0]))+' streamlines')
##        try:
##            streamlines = plt.streamplot(x,y,u,v,start_points=seed_points, max_error=max_error, density=100)
##        except:
##            streamlines = plt.streamplot(x,y,u,v,start_points=seed_points, density=100)
##            if retry_count == 0:
##                print('Note: You need to update your matplotlib streamline.py and reduce the max error for this to work properly')
##        plt.close(fig)
##
##        ################# Here we extract the coordinate info from the streamlines
##        sl_deconstruct = []
##        for i in streamlines.lines.get_paths():
##            sl_deconstruct.append(i.vertices[1])
##        sl_deconstruct = np.array(sl_deconstruct)
##
##        ################ Here we separate the streamlines based on large breaks in distance
##        sl_dist = ndh.distance_vector(sl_deconstruct[:,0],sl_deconstruct[:,1],1)
##        dist_mean = np.mean(sl_dist)
##        breaks = np.where(sl_dist > (dist_mean+1)*50)[0]
##        if len(breaks) > 0:
##            breaks = np.concatenate([np.array([-1]),breaks,np.array([len(sl_deconstruct[:,0])])])+1
##        else:
##            breaks = np.array([0,len(sl_deconstruct[:,0])+1])
##
##        for ind0 in np.arange(len(breaks)-1):
##            sls.append(sl_deconstruct[breaks[ind0]:breaks[ind0+1],:])
##
##        ################ Here we identify which streamline goes with which seed_point
##        matching = []
##        for ind0 in np.arange(len(seed_subset[:,0])):
##            dists = []
##            for ind1,sl in enumerate(sls):
##                comp_vals = ndh.find_nearest_xy(sl,seed_subset[ind0,:])
##                dists.append(comp_vals['distance'][0])
##            best = np.where(np.array(dists) < 1e-8)[0]
##            try:
##                matching.append(best[0])
##            except:
##                matching.append(-1)
##
##        ################# populate the final object
##        for ind0,i in enumerate(matching):
##            if i != -1:
##                final_sls[retry_inds[ind0]] = sls[i]
##
##        ################# Finally, we identify the new set of streamlines that need to be computed, based on which have no match
##        new_retry_inds = np.where(np.array(matching) == -1)[0]
##        seed_subset = seed_points[retry_inds[new_retry_inds],:]
##        retry_inds = retry_inds[new_retry_inds]
##
##        if len(retry_inds) > 0:
##            retry_count = retry_count+1
##            print('Recalculating for nearly overlapping points -- try '+str(retry_count)+'. Finding '+str(len(seed_subset[:,0]))+' streamlines')
##
##        if retry_count > retry_count_threshold:
##            break
##
##    if 0:
##        plt.figure()
##        plt.plot(test_dist)
##        plt.axhline(dist_median,c='orange')
##
##    if 0:
##        plt.figure()
##        plt.plot(test[:,0],test[:,1],c='blue')
##        for i in final_sls:
##            plt.plot(i[:,0],i[:,1],c='red')
##        plt.plot(seed_points[:,0],seed_points[:,1],'o')
##
##
##    return final_sls
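
######################### Reference sketch for the disabled version above ####################
##### A minimal, stand-alone example (assuming stock matplotlib, not the modified streamplot
##### the commented-out code expects) of recovering integrated streamline vertices from
##### plt.streamplot -- the mechanism that version relied on. The grid and rotational field
##### below are synthetic, for illustration only.
if 0:
    xg = np.linspace(-1, 1, 200)
    yg = np.linspace(-1, 1, 200)
    X, Y = np.meshgrid(xg, yg)
    U, V = -Y, X                                                 # simple rotational field
    seeds = np.array([[0.5, 0.0]])

    fig = plt.figure()
    sp = plt.streamplot(xg, yg, U, V, start_points=seeds, density=10)
    plt.close(fig)

    ##### Each path in the returned LineCollection is a short segment; stacking the vertices
    ##### recovers the integrated line (the code above then split these by jumps in distance).
    verts = np.concatenate([p.vertices for p in sp.lines.get_paths()])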