From 4a56dce4f6979048352578b62dca8951aba5a841 Mon Sep 17 00:00:00 2001 From: Feng Zhu Date: Mon, 15 Jan 2024 18:11:28 -0700 Subject: [PATCH] Improve proxy visualization funcs --- cfr/proxy.py | 13 ++++++++++--- cfr/utils.py | 4 ++++ cfr/visual.py | 11 +++++++++-- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/cfr/proxy.py b/cfr/proxy.py index e884d42..de481e6 100644 --- a/cfr/proxy.py +++ b/cfr/proxy.py @@ -621,7 +621,7 @@ def plotly(self, **kwargs): return fig def plot(self, figsize=[12, 4], legend=False, ms=200, stock_img=True, edge_clr='w', - wspace=0.1, hspace=0.1, plot_map=True, **kwargs): + wspace=0.1, hspace=0.1, plot_map=True, p=visual.STYLE, **kwargs): ''' Visualize the ProxyRecord Args: @@ -635,7 +635,10 @@ def plot(self, figsize=[12, 4], legend=False, ms=200, stock_img=True, edge_clr=' plot_map (bool): if True, plot the record on a map. Defaults to True. ''' if 'color' not in kwargs and 'c' not in kwargs: - kwargs['color'] = visual.STYLE.colors_dict[self.ptype] + if self.ptype in p.colors_dict: + kwargs['color'] = p.colors_dict[self.ptype] + else: + kwargs['color'] = 'tab:blue' fig = plt.figure(figsize=figsize) @@ -671,9 +674,13 @@ def plot(self, figsize=[12, 4], legend=False, ms=200, stock_img=True, edge_clr=' if stock_img: ax['map'].stock_img() + if self.ptype in p.markers_dict: + marker = p.markers_dict[self.ptype] + else: + marker = 'o' transform=ccrs.PlateCarree() ax['map'].scatter( - self.lon, self.lat, marker=visual.STYLE.markers_dict[self.ptype], + self.lon, self.lat, marker=marker, s=ms, c=kwargs['color'], edgecolor=edge_clr, transform=transform, ) diff --git a/cfr/utils.py b/cfr/utils.py index cd7d348..df2fd49 100644 --- a/cfr/utils.py +++ b/cfr/utils.py @@ -611,6 +611,10 @@ def colored_noise_2regimes(alpha1, alpha2, f_break, t, f0=None, m=None, seed=Non return y +def arr_str2np(arr): + arr = np.array([float(s) for s in arr[1:-1].split(',')]) + return arr + def is_numeric(obj): attrs = ['__add__', '__sub__', '__mul__', '__truediv__', '__pow__'] return all(hasattr(obj, attr) for attr in attrs) diff --git a/cfr/visual.py b/cfr/visual.py index 48a2a98..da67a2f 100644 --- a/cfr/visual.py +++ b/cfr/visual.py @@ -555,10 +555,17 @@ def plot_proxies(df, year=np.arange(2001), lon_col='lon', lat_col='lat', type_co type_names.append(f'{ptype} (n={max_count[-1]})') lons = list(df[selector][lon_col]) lats = list(df[selector][lat_col]) + if ptype in markers_dict: + marker = markers_dict[ptype] + color = colors_dict[ptype] + else: + marker = 'o' + color = 'tab:blue' + s_plots.append( ax['map'].scatter( - lons, lats, marker=markers_dict[ptype], - c=colors_dict[ptype], edgecolor='k', s=markersize, transform=ccrs.PlateCarree() + lons, lats, marker=marker, c=color, + edgecolor='k', s=markersize, transform=ccrs.PlateCarree() ) )