diff --git a/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/672_aerosol_aeronet_run_script.py b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/672_aerosol_aeronet_run_script.py new file mode 100644 index 000000000..6b67b309f --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/672_aerosol_aeronet_run_script.py @@ -0,0 +1,8 @@ +from auxiliary_tools.cdat_regression_testing.base_run_script import run_set + +SET_NAME = "aerosol_aeronet" +SET_DIR = "672-aerosol-aeronet" +CFG_PATH: str | None = None +MULTIPROCESSING = True + +run_set(SET_NAME, SET_DIR, CFG_PATH, MULTIPROCESSING) diff --git a/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/dev-results/AERONET-AODABS-ANN-global.png b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/dev-results/AERONET-AODABS-ANN-global.png new file mode 100644 index 000000000..159db66ff Binary files /dev/null and b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/dev-results/AERONET-AODABS-ANN-global.png differ diff --git a/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/dev-results/AERONET-AODVIS-ANN-global.png b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/dev-results/AERONET-AODVIS-ANN-global.png new file mode 100644 index 000000000..39e8f76c3 Binary files /dev/null and b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/dev-results/AERONET-AODVIS-ANN-global.png differ diff --git a/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/main-results/AERONET-AODABS-ANN-global.png b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/main-results/AERONET-AODABS-ANN-global.png new file mode 100644 index 000000000..c4d24110d Binary files /dev/null and b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/main-results/AERONET-AODABS-ANN-global.png differ diff --git a/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/main-results/AERONET-AODVIS-ANN-global.png b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/main-results/AERONET-AODVIS-ANN-global.png new file mode 100644 index 000000000..f40bb4c2e Binary files /dev/null and b/auxiliary_tools/cdat_regression_testing/672-aerosol-aeronet/main-results/AERONET-AODVIS-ANN-global.png differ diff --git a/e3sm_diags/driver/aerosol_aeronet_driver.py b/e3sm_diags/driver/aerosol_aeronet_driver.py index 4010ed606..f185d979a 100644 --- a/e3sm_diags/driver/aerosol_aeronet_driver.py +++ b/e3sm_diags/driver/aerosol_aeronet_driver.py @@ -1,20 +1,22 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import numpy as np import pandas as pd +import xarray as xr +import xcdat as xc from scipy import interpolate import e3sm_diags from e3sm_diags.driver import utils +from e3sm_diags.driver.utils.dataset_xr import Dataset from e3sm_diags.logger import custom_logger -from e3sm_diags.plot.cartopy import aerosol_aeronet_plot +from e3sm_diags.metrics.metrics import spatial_avg +from e3sm_diags.plot import aerosol_aeronet_plot if TYPE_CHECKING: - from cdms2.tvariable import TransientVariable - from e3sm_diags.parameter.core_parameter import CoreParameter @@ -25,74 +27,110 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: - """Runs the aerosol aeronet diagnostic. - - :param parameter: Parameters for the run - :type parameter: CoreParameter - :raises ValueError: Invalid run type - :return: Parameters for the run - :rtype: CoreParameter + """Run the aerosol aeronet diagnostics. + + Parameters + ---------- + parameter : CoreParameter + The parameter for the diagnostic. + + Returns + ------- + CoreParameter + The parameter for the diagnostic with the result (completed or failed). + + Raises + ------ + ValueError + If the run type is not valid. """ variables = parameter.variables run_type = parameter.run_type seasons = parameter.seasons - for season in seasons: - test_data = utils.dataset.Dataset(parameter, test=True) - parameter.test_name_yrs = utils.general.get_name_and_yrs( - parameter, test_data, season - ) - parameter.ref_name_yrs = "AERONET (2006-2015)" + test_ds = Dataset(parameter, data_type="test") - for var in variables: - logger.info("Variable: {}".format(var)) - parameter.var_id = var + for var_key in variables: + logger.info("Variable: {}".format(var_key)) + parameter.var_id = var_key - test = test_data.get_climo_variable(var, season) - test_site = interpolate_model_output_to_obs_sites(test, var) + for season in seasons: + ds_test = test_ds.get_climo_dataset(var_key, season) + da_test = ds_test[var_key] - if run_type == "model_vs_model": - ref_data = utils.dataset.Dataset(parameter, ref=True) - parameter.ref_name_yrs = utils.general.get_name_and_yrs( - parameter, ref_data, season + test_site_arr = interpolate_model_output_to_obs_sites( + ds_test[var_key], var_key ) - ref = ref_data.get_climo_variable(var, season) - ref_site = interpolate_model_output_to_obs_sites(ref, var) - elif run_type == "model_vs_obs": - ref_site = interpolate_model_output_to_obs_sites(None, var) - else: - raise ValueError("Invalid run_type={}".format(run_type)) + parameter.test_name_yrs = test_ds.get_name_yrs_attr(season) + parameter.ref_name_yrs = "AERONET (2006-2015)" - parameter.output_file = ( - f"{parameter.ref_name}-{parameter.var_id}-{season}-global" - ) - aerosol_aeronet_plot.plot(test, test_site, ref_site, parameter) + if run_type == "model_vs_model": + ref_ds = Dataset(parameter, data_type="ref") + + parameter.ref_name_yrs = utils.general.get_name_and_yrs( + parameter, ref_ds, season + ) + + ds_ref = ref_ds.get_climo_dataset(var_key, season) + ref_site_arr = interpolate_model_output_to_obs_sites( + ds_ref[var_key], var_key + ) + elif run_type == "model_vs_obs": + ref_site_arr = interpolate_model_output_to_obs_sites(None, var_key) + else: + raise ValueError("Invalid run_type={}".format(run_type)) + + parameter.output_file = ( + f"{parameter.ref_name}-{parameter.var_id}-{season}-global" + ) + + metrics_dict = { + "max": da_test.max().item(), + "min": da_test.min().item(), + "mean": spatial_avg(ds_test, var_key, axis=["X", "Y"]), + } + aerosol_aeronet_plot.plot( + parameter, da_test, test_site_arr, ref_site_arr, metrics_dict + ) return parameter def interpolate_model_output_to_obs_sites( - var: Optional[TransientVariable], var_id: str -): + da_var: xr.DataArray | None, var_key: str +) -> np.ndarray: """Interpolate model outputs (on regular lat lon grids) to observational sites - :param var: Input model variable, var_id: name of the variable - :type var: TransientVariable or NoneType, var_id: str - :raises IOError: Invalid variable input - :return: interpolated values over all observational sites - :rtype: 1-D numpy.array + # TODO: Add test coverage for this function. + + Parameters + ---------- + da_var : xr.DataArray | None + An optional input model variable dataarray. + var_key : str + The key of the variable. + Returns + ------- + np.ndarray + The interpolated values over all observational sites. + + Raises + ------ + IOError + If the variable key is invalid. """ logger.info( "Interpolate model outputs (on regular lat lon grids) to observational sites" ) - if var_id == "AODABS": + + if var_key == "AODABS": aeronet_file = os.path.join( e3sm_diags.INSTALL_PATH, "aerosol_aeronet/aaod550_AERONET_2006-2015.txt" ) var_header = "aaod" - elif var_id == "AODVIS": + elif var_key == "AODVIS": aeronet_file = os.path.join( e3sm_diags.INSTALL_PATH, "aerosol_aeronet/aod550_AERONET_2006-2015.txt" ) @@ -102,22 +140,24 @@ def interpolate_model_output_to_obs_sites( data_obs = pd.read_csv(aeronet_file, dtype=object, sep=",") - lonloc = np.array(data_obs["lon"].astype(float)) - latloc = np.array(data_obs["lat"].astype(float)) - obsloc = np.array(data_obs[var_header].astype(float)) - # sitename = np.array(data_obs["site"].astype(str)) - nsite = len(obsloc) + lon_loc = np.array(data_obs["lon"].astype(float)) + lat_loc = np.array(data_obs["lat"].astype(float)) + obs_loc = np.array(data_obs[var_header].astype(float)) - # express lonloc from 0 to 360 - lonloc[lonloc < 0.0] = lonloc[lonloc < 0.0] + 360.0 + num_sites = len(obs_loc) - if var is not None: - f_intp = interpolate.RectBivariateSpline( - var.getLatitude()[:], var.getLongitude()[:], var - ) - var_intp = np.zeros(nsite) - for i in range(nsite): - var_intp[i] = f_intp(latloc[i], lonloc[i]) + # Express lon_loc from 0 to 360. + lon_loc[lon_loc < 0.0] = lon_loc[lon_loc < 0.0] + 360.0 + + if da_var is not None: + lat = xc.get_dim_coords(da_var, axis="Y") + lon = xc.get_dim_coords(da_var, axis="X") + f_intp = interpolate.RectBivariateSpline(lat.values, lon.values, da_var.values) + + var_intp = np.zeros(num_sites) + for i in range(num_sites): + var_intp[i] = f_intp(lat_loc[i], lon_loc[i]) return var_intp - return obsloc + + return obs_loc diff --git a/e3sm_diags/driver/zonal_mean_2d_driver.py b/e3sm_diags/driver/zonal_mean_2d_driver.py index da2aa26c0..c256877e9 100755 --- a/e3sm_diags/driver/zonal_mean_2d_driver.py +++ b/e3sm_diags/driver/zonal_mean_2d_driver.py @@ -18,7 +18,7 @@ DEFAULT_PLEVS, ZonalMean2dParameter, ) -from e3sm_diags.plot.cartopy.zonal_mean_2d_plot import plot as plot_func +from e3sm_diags.plot.zonal_mean_2d_plot import plot as plot_func logger = custom_logger(__name__) diff --git a/e3sm_diags/plot/aerosol_aeronet_plot.py b/e3sm_diags/plot/aerosol_aeronet_plot.py new file mode 100644 index 000000000..14293ba4c --- /dev/null +++ b/e3sm_diags/plot/aerosol_aeronet_plot.py @@ -0,0 +1,114 @@ +import matplotlib +import numpy as np +import xarray as xr + +from e3sm_diags.driver.utils.type_annotations import MetricsDict +from e3sm_diags.logger import custom_logger +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.plot.lat_lon_plot import _add_colormap +from e3sm_diags.plot.utils import _save_plot + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + +# Plot scatter plot +# Position and sizes of subplot axes in page coordinates (0 to 1) +# (left, bottom, width, height) in page coordinates +PANEL_CFG = [ + (0.09, 0.40, 0.72, 0.30), + (0.19, 0.2, 0.62, 0.30), +] +# Border padding relative to subplot axes for saving individual panels +# (left, bottom, right, top) in page coordinates. +BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03) + + +def plot( + parameter: CoreParameter, + da_test: xr.DataArray, + test_site_arr: np.ndarray, + ref_site_arr: np.ndarray, + metrics_dict: MetricsDict, +): + """Plot the test variable's metrics generated for the aerosol_aeronet set. + + Parameters + ---------- + parameter : CoreParameter + The CoreParameter object containing plot configurations. + da_test : xr.DataArray + The test data. + test_site : np.ndarray + The array containing values for the test site. + ref_site : np.ndarray + The array containing values for the ref site. + metrics_dict : MetricsDict + The metrics. + """ + fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) + fig.suptitle(parameter.var_id, x=0.5, y=0.97) + + # Add the colormap subplot for test data. + min = metrics_dict["min"] + mean = metrics_dict["mean"] + max = metrics_dict["max"] + + _add_colormap( + 0, + da_test, + fig, + parameter, + parameter.test_colormap, + parameter.contour_levels, + title=(parameter.test_name_yrs, None, None), # type: ignore + metrics=(max, mean, min), # type: ignore + ) + + # Add the scatter plot. + ax = fig.add_axes(PANEL_CFG[1]) + ax.set_title(f"{parameter.var_id} from AERONET sites") + + # Define 1:1 line, and x, y axis limits. + if parameter.var_id == "AODVIS": + x1 = np.arange(0.01, 3.0, 0.1) + y1 = np.arange(0.01, 3.0, 0.1) + plt.xlim(0.03, 1) + plt.ylim(0.03, 1) + else: + x1 = np.arange(0.0001, 1.0, 0.01) + y1 = np.arange(0.0001, 1.0, 0.01) + plt.xlim(0.001, 0.3) + plt.ylim(0.001, 0.3) + + plt.loglog(x1, y1, "-k", linewidth=0.5) + plt.loglog(x1, y1 * 0.5, "--k", linewidth=0.5) + plt.loglog(x1 * 0.5, y1, "--k", linewidth=0.5) + + corr = np.corrcoef(ref_site_arr, test_site_arr) + xmean = np.mean(ref_site_arr) + ymean = np.mean(test_site_arr) + ax.text( + 0.3, + 0.9, + f"Mean (test): {ymean:.3f} \n Mean (ref): {xmean:.3f}\n Corr: {corr[0, 1]:.2f}", + horizontalalignment="right", + verticalalignment="top", + transform=ax.transAxes, + ) + + # Configure axis ticks. + plt.tick_params(axis="both", which="major") + plt.tick_params(axis="both", which="minor") + + # Configure axis labels + plt.xlabel(f"ref: {parameter.ref_name_yrs}") + plt.ylabel(f"test: {parameter.test_name_yrs}") + + plt.loglog(ref_site_arr, test_site_arr, "kx", markersize=3.0, mfc="none") + + # Configure legend. + plt.legend(frameon=False, prop={"size": 5}) + + _save_plot(fig, parameter, PANEL_CFG, BORDER_PADDING) diff --git a/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py b/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py deleted file mode 100644 index 765235095..000000000 --- a/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py +++ /dev/null @@ -1,132 +0,0 @@ -import os - -import cartopy.crs as ccrs -import matplotlib -import numpy as np - -from e3sm_diags.driver.utils.general import get_output_dir -from e3sm_diags.logger import custom_logger -from e3sm_diags.metrics import mean -from e3sm_diags.plot.cartopy.deprecated_lat_lon_plot import plot_panel - -matplotlib.use("Agg") -import matplotlib.pyplot as plt # isort:skip # noqa: E402 - -logger = custom_logger(__name__) - -plotTitle = {"fontsize": 11.5} -plotSideTitle = {"fontsize": 9.5} - - -def plot(test, test_site, ref_site, parameter): - # Plot scatter plot - # Position and sizes of subplot axes in page coordinates (0 to 1) - # (left, bottom, width, height) in page coordinates - panel = [ - (0.09, 0.40, 0.72, 0.30), - (0.19, 0.2, 0.62, 0.30), - ] - # Border padding relative to subplot axes for saving individual panels - # (left, bottom, right, top) in page coordinates - border = (-0.06, -0.03, 0.13, 0.03) - - fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) - fig.suptitle(parameter.var_id, x=0.5, y=0.97) - proj = ccrs.PlateCarree() - max1 = test.max() - min1 = test.min() - mean1 = mean(test) - # TODO: Replace this function call with `e3sm_diags.plot.utils._add_colormap()`. - plot_panel( - 0, - fig, - proj, - test, - parameter.contour_levels, - parameter.test_colormap, - (parameter.test_name_yrs, None, None), - parameter, - stats=(max1, mean1, min1), - ) - - ax = fig.add_axes(panel[1]) - ax.set_title(f"{parameter.var_id} from AERONET sites") - - # define 1:1 line, and x y axis limits - - if parameter.var_id == "AODVIS": - x1 = np.arange(0.01, 3.0, 0.1) - y1 = np.arange(0.01, 3.0, 0.1) - plt.xlim(0.03, 1) - plt.ylim(0.03, 1) - else: - x1 = np.arange(0.0001, 1.0, 0.01) - y1 = np.arange(0.0001, 1.0, 0.01) - plt.xlim(0.001, 0.3) - plt.ylim(0.001, 0.3) - - plt.loglog(x1, y1, "-k", linewidth=0.5) - plt.loglog(x1, y1 * 0.5, "--k", linewidth=0.5) - plt.loglog(x1 * 0.5, y1, "--k", linewidth=0.5) - - corr = np.corrcoef(ref_site, test_site) - xmean = np.mean(ref_site) - ymean = np.mean(test_site) - ax.text( - 0.3, - 0.9, - f"Mean (test): {ymean:.3f} \n Mean (ref): {xmean:.3f}\n Corr: {corr[0, 1]:.2f}", - horizontalalignment="right", - verticalalignment="top", - transform=ax.transAxes, - ) - - # axis ticks - plt.tick_params(axis="both", which="major") - plt.tick_params(axis="both", which="minor") - - # axis labels - plt.xlabel(f"ref: {parameter.ref_name_yrs}") - plt.ylabel(f"test: {parameter.test_name_yrs}") - - plt.loglog(ref_site, test_site, "kx", markersize=3.0, mfc="none") - - # legend - plt.legend(frameon=False, prop={"size": 5}) - - # TODO: This section can be refactored to use `plot.utils._save_plot()`. - for f in parameter.output_format: - f = f.lower().split(".")[-1] - fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - f"{parameter.output_file}" + "." + f, - ) - plt.savefig(fnm) - logger.info(f"Plot saved in: {fnm}") - - for f in parameter.output_format_subplot: - fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file, - ) - page = fig.get_size_inches() - i = 0 - for p in panel: - # Extent of subplot - subpage = np.array(p).reshape(2, 2) - subpage[1, :] = subpage[0, :] + subpage[1, :] - subpage = subpage + np.array(border).reshape(2, 2) - subpage = list(((subpage) * page).flatten()) # type: ignore - extent = matplotlib.transforms.Bbox.from_extents(*subpage) - # Save subplot - fname = fnm + ".%i." % (i) + f - plt.savefig(fname, bbox_inches=extent) - - orig_fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file, - ) - fname = orig_fnm + ".%i." % (i) + f - logger.info(f"Sub-plot saved in: {fname}") - - i += 1 diff --git a/e3sm_diags/plot/cartopy/deprecated_lat_lon_plot.py b/e3sm_diags/plot/cartopy/deprecated_lat_lon_plot.py deleted file mode 100644 index 4eaebcf80..000000000 --- a/e3sm_diags/plot/cartopy/deprecated_lat_lon_plot.py +++ /dev/null @@ -1,360 +0,0 @@ -""" -WARNING: This module has been deprecated and replaced by -`e3sm_diags.plot.lat_lon_plot.py`. This file temporarily kept because -`e3sm_diags.plot.cartopy.aerosol_aeronet_plot.plot` references the -`plot_panel()` function. Once the aerosol_aeronet set is refactored, this -file can be deleted. -""" -from __future__ import print_function - -import os - -import cartopy.crs as ccrs -import cartopy.feature as cfeature -import cdutil -import matplotlib -import numpy as np -import numpy.ma as ma -from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter - -from e3sm_diags.derivations.default_regions import regions_specs -from e3sm_diags.driver.utils.general import get_output_dir -from e3sm_diags.logger import custom_logger -from e3sm_diags.plot import get_colormap - -matplotlib.use("Agg") -import matplotlib.colors as colors # isort:skip # noqa: E402 -import matplotlib.pyplot as plt # isort:skip # noqa: E402 - -logger = custom_logger(__name__) - -plotTitle = {"fontsize": 11.5} -plotSideTitle = {"fontsize": 9.5} - -# Position and sizes of subplot axes in page coordinates (0 to 1) -panel = [ - (0.1691, 0.6810, 0.6465, 0.2258), - (0.1691, 0.3961, 0.6465, 0.2258), - (0.1691, 0.1112, 0.6465, 0.2258), -] - -# Border padding relative to subplot axes for saving individual panels -# (left, bottom, right, top) in page coordinates -border = (-0.06, -0.03, 0.13, 0.03) - - -def add_cyclic(var): - lon = var.getLongitude() - return var(longitude=(lon[0], lon[0] + 360.0, "coe")) - - -def get_ax_size(fig, ax): - bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) - width, height = bbox.width, bbox.height - width *= fig.dpi - height *= fig.dpi - return width, height - - -def determine_tick_step(degrees_covered): - if degrees_covered > 180: - return 60 - if degrees_covered > 60: - return 30 - elif degrees_covered > 30: - return 10 - elif degrees_covered > 20: - return 5 - else: - return 1 - - -def plot_panel( # noqa: C901 - n, fig, proj, var, clevels, cmap, title, parameters, stats=None -): - var = add_cyclic(var) - lon = var.getLongitude() - lat = var.getLatitude() - var = ma.squeeze(var.asma()) - - # Contour levels - levels = None - norm = None - if len(clevels) > 0: - levels = [-1.0e8] + clevels + [1.0e8] - norm = colors.BoundaryNorm(boundaries=levels, ncolors=256) - - # ax.set_global() - region_str = parameters.regions[0] - region = regions_specs[region_str] - global_domain = True - full_lon = True - if "domain" in region.keys(): # type: ignore - # Get domain to plot - domain = region["domain"] # type: ignore - global_domain = False - else: - # Assume global domain - domain = cdutil.region.domain(latitude=(-90.0, 90, "ccb")) - kargs = domain.components()[0].kargs - lon_west, lon_east, lat_south, lat_north = (0, 360, -90, 90) - if "longitude" in kargs: - full_lon = False - lon_west, lon_east, _ = kargs["longitude"] - # Note cartopy Problem with gridlines across the dateline:https://github.com/SciTools/cartopy/issues/821. Region cross dateline is not supported yet. - if lon_west > 180 and lon_east > 180: - lon_west = lon_west - 360 - lon_east = lon_east - 360 - - if "latitude" in kargs: - lat_south, lat_north, _ = kargs["latitude"] - lon_covered = lon_east - lon_west - lon_step = determine_tick_step(lon_covered) - xticks = np.arange(lon_west, lon_east, lon_step) - # Subtract 0.50 to get 0 W to show up on the right side of the plot. - # If less than 0.50 is subtracted, then 0 W will overlap 0 E on the left side of the plot. - # If a number is added, then the value won't show up at all. - if global_domain or full_lon: - xticks = np.append(xticks, lon_east - 0.50) - proj = ccrs.PlateCarree(central_longitude=180) - else: - xticks = np.append(xticks, lon_east) - lat_covered = lat_north - lat_south - lat_step = determine_tick_step(lat_covered) - yticks = np.arange(lat_south, lat_north, lat_step) - yticks = np.append(yticks, lat_north) - - # Contour plot - ax = fig.add_axes(panel[n], projection=proj) - ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=proj) - cmap = get_colormap(cmap, parameters) - p1 = ax.contourf( - lon, - lat, - var, - transform=ccrs.PlateCarree(), - norm=norm, - levels=levels, - cmap=cmap, - extend="both", - ) - - # ax.set_aspect('auto') - # Full world would be aspect 360/(2*180) = 1 - ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south))) - ax.coastlines(lw=0.3) - if not global_domain and "RRM" in region_str: - ax.coastlines(resolution="50m", color="black", linewidth=1) - state_borders = cfeature.NaturalEarthFeature( - category="cultural", - name="admin_1_states_provinces_lakes", - scale="50m", - facecolor="none", - ) - ax.add_feature(state_borders, edgecolor="black") - if title[0] is not None: - ax.set_title(title[0], loc="left", fontdict=plotSideTitle) - if title[1] is not None: - ax.set_title(title[1], fontdict=plotTitle) - if title[2] is not None: - ax.set_title(title[2], loc="right", fontdict=plotSideTitle) - ax.set_xticks(xticks, crs=ccrs.PlateCarree()) - ax.set_yticks(yticks, crs=ccrs.PlateCarree()) - lon_formatter = LongitudeFormatter(zero_direction_label=True, number_format=".0f") - lat_formatter = LatitudeFormatter() - ax.xaxis.set_major_formatter(lon_formatter) - ax.yaxis.set_major_formatter(lat_formatter) - ax.tick_params(labelsize=8.0, direction="out", width=1) - ax.xaxis.set_ticks_position("bottom") - ax.yaxis.set_ticks_position("left") - - # Color bar - cbax = fig.add_axes((panel[n][0] + 0.6635, panel[n][1] + 0.0215, 0.0326, 0.1792)) - cbar = fig.colorbar(p1, cax=cbax) - w, h = get_ax_size(fig, cbax) - - if levels is None: - cbar.ax.tick_params(labelsize=9.0, length=0) - - else: - maxval = np.amax(np.absolute(levels[1:-1])) - if maxval < 0.2: - fmt = "%5.3f" - pad = 28 - elif maxval < 10.0: - fmt = "%5.2f" - pad = 25 - elif maxval < 100.0: - fmt = "%5.1f" - pad = 25 - elif maxval > 9999.0: - fmt = "%.0f" - pad = 40 - else: - fmt = "%6.1f" - pad = 30 - - cbar.set_ticks(levels[1:-1]) - labels = [fmt % level for level in levels[1:-1]] - cbar.ax.set_yticklabels(labels, ha="right") - cbar.ax.tick_params(labelsize=9.0, pad=pad, length=0) - - # Min, Mean, Max - fig.text( - panel[n][0] + 0.6635, - panel[n][1] + 0.2107, - "Max\nMean\nMin", - ha="left", - fontdict=plotSideTitle, - ) - - fmt_m = [] - # printing in scientific notation if value greater than 10^5 - for i in range(len(stats[0:3])): - fs = "1e" if stats[i] > 100000.0 else "2f" - fmt_m.append(fs) - fmt_metrics = f"%.{fmt_m[0]}\n%.{fmt_m[1]}\n%.{fmt_m[2]}" - - fig.text( - panel[n][0] + 0.7635, - panel[n][1] + 0.2107, - # "%.2f\n%.2f\n%.2f" % stats[0:3], - fmt_metrics % stats[0:3], - ha="right", - fontdict=plotSideTitle, - ) - - # RMSE, CORR - if len(stats) == 5: - fig.text( - panel[n][0] + 0.6635, - panel[n][1] - 0.0105, - "RMSE\nCORR", - ha="left", - fontdict=plotSideTitle, - ) - fig.text( - panel[n][0] + 0.7635, - panel[n][1] - 0.0105, - "%.2f\n%.2f" % stats[3:5], - ha="right", - fontdict=plotSideTitle, - ) - - # grid resolution info: - if n == 2 and "RRM" in region_str: - dlat = lat[2] - lat[1] - dlon = lon[2] - lon[1] - fig.text( - panel[n][0] + 0.4635, - panel[n][1] - 0.04, - "Resolution: {:.2f}x{:.2f}".format(dlat, dlon), - ha="left", - fontdict=plotSideTitle, - ) - - -def plot(reference, test, diff, metrics_dict, parameter): - # Create figure, projection - fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) - proj = ccrs.PlateCarree() - - # Figure title - fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18) - - # First two panels - min1 = metrics_dict["test"]["min"] - mean1 = metrics_dict["test"]["mean"] - max1 = metrics_dict["test"]["max"] - - plot_panel( - 0, - fig, - proj, - test, - parameter.contour_levels, - parameter.test_colormap, - (parameter.test_name_yrs, parameter.test_title, test.units), - parameter, - stats=(max1, mean1, min1), - ) - - if not parameter.model_only: - min2 = metrics_dict["ref"]["min"] - mean2 = metrics_dict["ref"]["mean"] - max2 = metrics_dict["ref"]["max"] - - plot_panel( - 1, - fig, - proj, - reference, - parameter.contour_levels, - parameter.reference_colormap, - (parameter.ref_name_yrs, parameter.reference_title, reference.units), - parameter, - stats=(max2, mean2, min2), - ) - - # Third panel - min3 = metrics_dict["diff"]["min"] - mean3 = metrics_dict["diff"]["mean"] - max3 = metrics_dict["diff"]["max"] - r = metrics_dict["misc"]["rmse"] - c = metrics_dict["misc"]["corr"] - plot_panel( - 2, - fig, - proj, - diff, - parameter.diff_levels, - parameter.diff_colormap, - (None, parameter.diff_title, test.units), - parameter, - stats=(max3, mean3, min3, r, c), - ) - - # Save figure - for f in parameter.output_format: - f = f.lower().split(".")[-1] - fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file + "." + f, - ) - plt.savefig(fnm) - logger.info(f"Plot saved in: {fnm}") - - # Save individual subplots - if parameter.ref_name == "": - panels = [panel[0]] - else: - panels = panel - - for f in parameter.output_format_subplot: - fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file, - ) - page = fig.get_size_inches() - i = 0 - for p in panels: - # Extent of subplot - subpage = np.array(p).reshape(2, 2) - subpage[1, :] = subpage[0, :] + subpage[1, :] - subpage = subpage + np.array(border).reshape(2, 2) - subpage = list(((subpage) * page).flatten()) # type: ignore - extent = matplotlib.transforms.Bbox.from_extents(*subpage) - # Save subplot - fname = fnm + ".%i." % (i) + f - plt.savefig(fname, bbox_inches=extent) - - orig_fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file, - ) - fname = orig_fnm + ".%i." % (i) + f - logger.info(f"Sub-plot saved in: {fname}") - - i += 1 - - plt.close() diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py deleted file mode 100644 index 004f3c93d..000000000 --- a/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py +++ /dev/null @@ -1,15 +0,0 @@ -import xarray as xr - -from e3sm_diags.driver.utils.type_annotations import MetricsDict -from e3sm_diags.parameter.core_parameter import CoreParameter -from e3sm_diags.plot.cartopy.zonal_mean_2d_plot import plot as base_plot - - -def plot( - parameter: CoreParameter, - da_test: xr.DataArray, - da_ref: xr.DataArray, - da_diff: xr.DataArray, - metrics_dict: MetricsDict, -): - return base_plot(parameter, da_test, da_ref, da_diff, metrics_dict) diff --git a/e3sm_diags/plot/cartopy/zonal_mean_xy_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_xy_plot.py deleted file mode 100644 index 938faab48..000000000 --- a/e3sm_diags/plot/cartopy/zonal_mean_xy_plot.py +++ /dev/null @@ -1,138 +0,0 @@ -from __future__ import print_function - -import os - -import matplotlib -import numpy as np -import numpy.ma as ma - -from e3sm_diags.driver.utils.general import get_output_dir -from e3sm_diags.logger import custom_logger - -matplotlib.use("Agg") -import matplotlib.pyplot as plt # isort:skip # noqa: E402 - -logger = custom_logger(__name__) - -plotTitle = {"fontsize": 12.5} -plotSideTitle = {"fontsize": 11.5} - -# Position and sizes of subplot axes in page coordinates (0 to 1) -panel = [ - (0.1500, 0.5500, 0.7500, 0.3000), - (0.1500, 0.1300, 0.7500, 0.3000), -] - -# Border padding relative to subplot axes for saving individual panels -# (left, bottom, right, top) in page coordinates -border = (-0.14, -0.06, 0.04, 0.08) - - -def get_ax_size(fig, ax): - bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) - width, height = bbox.width, bbox.height - width *= fig.dpi - height *= fig.dpi - return width, height - - -def plot(reference, test, diff, metrics_dict, parameter): - # Create figure - fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) - - # Top panel - ax1 = fig.add_axes(panel[0]) - ax1.plot(test.getLatitude()[:], ma.squeeze(test.asma()), "k", linewidth=2) - ax1.plot( - reference.getLatitude()[:], - ma.squeeze(reference.asma()), - "r", - linewidth=2, - ) - ax1.set_xticks([-90, -60, -30, 0, 30, 60, 90]) - ax1.set_xlim(-90, 90) - ax1.tick_params(labelsize=11.0, direction="out", width=1) - ax1.xaxis.set_ticks_position("bottom") - ax1.set_ylabel(test.long_name + " (" + test.units + ")") - - test_title = "Test" if parameter.test_title == "" else parameter.test_title - test_title += " : {}".format(parameter.test_name_yrs) - ref_title = ( - "Reference" if parameter.reference_title == "" else parameter.reference_title - ) - ref_title += " : {}".format(parameter.ref_name_yrs) - fig.text( - panel[0][0], - panel[0][1] + panel[0][3] + 0.03, - test_title, - ha="left", - fontdict=plotSideTitle, - color="black", - ) - fig.text( - panel[0][0], - panel[0][1] + panel[0][3] + 0.01, - ref_title, - ha="left", - fontdict=plotSideTitle, - color="red", - ) - - # Bottom panel - ax2 = fig.add_axes(panel[1]) - ax2.plot(diff.getLatitude()[:], ma.squeeze(diff.asma()), "k", linewidth=2) - ax2.axhline(y=0, color="0.5") - ax2.set_title(parameter.diff_title, fontdict=plotSideTitle, loc="center") - ax2.set_xticks([-90, -60, -30, 0, 30, 60, 90]) - ax2.set_xlim(-90, 90) - ax2.tick_params(labelsize=11.0, direction="out", width=1) - ax2.xaxis.set_ticks_position("bottom") - ax2.set_ylabel(test.long_name + " (" + test.units + ")") - - # Figure title - fig.suptitle(parameter.main_title, x=0.5, y=0.95, fontsize=18) - - # Save figure - for f in parameter.output_format: - f = f.lower().split(".")[-1] - fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file + "." + f, - ) - plt.savefig(fnm) - # Get the filename that the user has passed in and display that. - fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file + "." + f, - ) - logger.info(f"Plot saved in: {fnm}") - - # Save individual subplots - for f in parameter.output_format_subplot: - fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file, - ) - page = fig.get_size_inches() - i = 0 - for p in panel: - # Extent of subplot - subpage = np.array(p).reshape(2, 2) - subpage[1, :] = subpage[0, :] + subpage[1, :] - subpage = subpage + np.array(border).reshape(2, 2) - subpage = list(((subpage) * page).flatten()) # type: ignore - extent = matplotlib.transforms.Bbox.from_extents(*subpage) - # Save subplot - fname = fnm + ".%i." % (i) + f - plt.savefig(fname, bbox_inches=extent) - - orig_fnm = os.path.join( - get_output_dir(parameter.current_set, parameter), - parameter.output_file, - ) - fname = orig_fnm + ".%i." % (i) + f - logger.info(f"Sub-plot saved in: {fname}") - - i += 1 - - plt.close() diff --git a/e3sm_diags/plot/utils.py b/e3sm_diags/plot/utils.py index e365c2995..4c224e66b 100644 --- a/e3sm_diags/plot/utils.py +++ b/e3sm_diags/plot/utils.py @@ -37,7 +37,8 @@ # Border padding relative to subplot axes for saving individual panels # (left, bottom, right, top) in page coordinates -DEFAULT_BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03) +BorderPadding = Tuple[float, float, float, float] +DEFAULT_BORDER_PADDING: BorderPadding = (-0.06, -0.03, 0.13, 0.03) # Sets that use the lat_lon formatter to configure the X and Y axes of the plot. SETS_USING_LAT_LON_FORMATTER = [ @@ -56,7 +57,7 @@ def _save_plot( fig: plt.Figure, parameter: CoreParameter, panel_configs: PanelConfig = DEFAULT_PANEL_CFG, - border_padding: Tuple[float, float, float, float] = DEFAULT_BORDER_PADDING, + border_padding: BorderPadding = DEFAULT_BORDER_PADDING, ): """Save the plot using the figure object and parameter configs. @@ -130,7 +131,6 @@ def _add_grid_res_info(fig, subplot_num, region_key, lat, lon, panel_configs): ha="left", fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE}, ) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _make_lon_cyclic(var: xr.DataArray): diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py b/e3sm_diags/plot/zonal_mean_2d_plot.py similarity index 100% rename from e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py rename to e3sm_diags/plot/zonal_mean_2d_plot.py