Skip to content

Commit

Permalink
Add PR review changes
Browse files Browse the repository at this point in the history
- Update instructions for `template_run_script.py` to run the script as a Python module
- Add `panel_configs` and `border_padding` args to `plot.utils.save_plot()`
- Delete `_save_plot()` in `zonal_mean_xy_plot.py` and replace with `plot.utils.save_plot()`
- Update `_calc_zonal_mean()` to reuse `metrics.spatial_avg()`
  • Loading branch information
tomvothecoder committed Feb 1, 2024
1 parent 1530283 commit b19aec0
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 107 deletions.
12 changes: 7 additions & 5 deletions auxiliary_tools/cdat_regression_testing/template_run_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
"meridional_mean_2d", "annual_cycle_zonal_mean", "enso_diags", "qbo",
"area_mean_time_series", "diurnal_cycle", "streamflow", "arm_diags",
"tc_analysis", "aerosol_aeronet", "aerosol_budget", "mp_partition",
6. Run this script
- Make sure to run this command on NERSC perlmutter cpu:
`salloc --nodes 1 --qos interactive --time 01:00:00 --constraint cpu --account=e3sm
conda activate <NAME-OF-DEV-ENV>`
- python auxiliary_tools/cdat_regression_testing/<ISSUE-<SET_NAME>
6. Run this script as a Python module
- `auxiliary_tools` is not included in `setup.py`, so `-m` is required
to run the script as a Python module
- Command: python -m auxiliary_tools.cdat_regression_testing.<ISSUE-<SET_NAME>/<SCRIPT-NAME>
- Example: python -m auxiliary_tools.cdat_regression_testing.660_cosp_histogram.run_script
7. Make a copy of the CDAT regression testing notebook in the same directory
as this script and follow the instructions there to start testing.
"""
Expand Down
7 changes: 4 additions & 3 deletions e3sm_diags/driver/zonal_mean_xy_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
regrid_z_axis_to_plevs,
)
from e3sm_diags.logger import custom_logger
from e3sm_diags.metrics.metrics import spatial_avg
from e3sm_diags.plot.zonal_mean_xy_plot import plot as plot_func

logger = custom_logger(__name__)
Expand Down Expand Up @@ -249,10 +250,10 @@ def _calc_zonal_mean(
A Tuple containing the zonal mean for the test variable and the ref
variable.
"""
da_test_1d = ds_test.spatial.average(var_key, axis="X", weights="generate")[var_key]
da_ref_1d = ds_ref.spatial.average(var_key, axis="X", weights="generate")[var_key]
da_test_1d = spatial_avg(ds_test, var_key, "X", as_list=False)
da_ref_1d = spatial_avg(ds_ref, var_key, "X", as_list=False)

return da_test_1d, da_ref_1d
return da_test_1d, da_ref_1d # type: ignore


def _get_diff_of_zonal_means(da_a: xr.DataArray, da_b: xr.DataArray) -> xr.DataArray:
Expand Down
52 changes: 34 additions & 18 deletions e3sm_diags/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,23 @@
PLOT_SIDE_TITLE = {"fontsize": 9.5}

# Position and sizes of subplot axes in page coordinates (0 to 1)
PANEL = [
DEFAULT_PANEL_CONFIGS = [
(0.1691, 0.6810, 0.6465, 0.2258),
(0.1691, 0.3961, 0.6465, 0.2258),
(0.1691, 0.1112, 0.6465, 0.2258),
]

# Border padding relative to subplot axes for saving individual panels
# (left, bottom, right, top) in page coordinates
BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03)
DEFAULT_BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03)


def _save_plot(fig: plt.Figure, parameter: CoreParameter):
def _save_plot(
fig: plt.Figure,
parameter: CoreParameter,
panel_configs: List[Tuple[float, float, float, float]] = DEFAULT_PANEL_CONFIGS,
border_padding: Tuple[float, float, float, float] = DEFAULT_BORDER_PADDING,
):
"""Save the plot using the figure object and parameter configs.
This function creates the output filename to save the plot. It also
Expand All @@ -52,6 +57,12 @@ def _save_plot(fig: plt.Figure, parameter: CoreParameter):
The plot figure.
parameter : CoreParameter
The CoreParameter with file configurations.
panel_configs : List[Tuple[float, float, float, float]]
A list of panel configs consiting of positions and sizes, with each
element representing a panel.
border_padding : Tuple[float, float, float, float]
A tuple of border padding configs (left, bottom, right, top) for each
panel relative to the subplot axes.
"""
for f in parameter.output_format:
f = f.lower().split(".")[-1]
Expand All @@ -64,9 +75,9 @@ def _save_plot(fig: plt.Figure, parameter: CoreParameter):

# Save individual subplots
if parameter.ref_name == "":
panels = [PANEL[0]]
panels = [panel_configs[0]]
else:
panels = PANEL
panels = panel_configs

for f in parameter.output_format_subplot:
fnm = os.path.join(
Expand All @@ -79,7 +90,7 @@ def _save_plot(fig: plt.Figure, parameter: CoreParameter):
# Extent of subplot
subpage = np.array(panel).reshape(2, 2)
subpage[1, :] = subpage[0, :] + subpage[1, :]
subpage = subpage + np.array(BORDER_PADDING).reshape(2, 2)
subpage = subpage + np.array(border_padding).reshape(2, 2)
subpage = list(((subpage) * page).flatten()) # type: ignore
extent = Bbox.from_extents(*subpage)

Expand Down Expand Up @@ -174,7 +185,7 @@ def _add_colormap(
if is_global_domain or is_lon_full:
projection = ccrs.PlateCarree(central_longitude=180)

ax = fig.add_axes(PANEL[subplot_num], projection=projection)
ax = fig.add_axes(DEFAULT_PANEL_CONFIGS[subplot_num], projection=projection)
ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=projection)
color_map = get_colormap(color_map, parameter)
p1 = ax.contourf(
Expand Down Expand Up @@ -231,7 +242,12 @@ def _add_colormap(
# Add and configure the color bar.
# --------------------------------------------------------------------------
cbax = fig.add_axes(
(PANEL[subplot_num][0] + 0.6635, PANEL[subplot_num][1] + 0.0215, 0.0326, 0.1792)
(
DEFAULT_PANEL_CONFIGS[subplot_num][0] + 0.6635,
DEFAULT_PANEL_CONFIGS[subplot_num][1] + 0.0215,
0.0326,
0.1792,
)
)
cbar = fig.colorbar(p1, cax=cbax)

Expand All @@ -249,8 +265,8 @@ def _add_colormap(
# --------------------------------------------------------------------------
# Min, Mean, Max
fig.text(
PANEL[subplot_num][0] + 0.6635,
PANEL[subplot_num][1] + 0.2107,
DEFAULT_PANEL_CONFIGS[subplot_num][0] + 0.6635,
DEFAULT_PANEL_CONFIGS[subplot_num][1] + 0.2107,
"Max\nMean\nMin",
ha="left",
fontdict=PLOT_SIDE_TITLE,
Expand All @@ -266,8 +282,8 @@ def _add_colormap(
fmt_metrics = f"%.{fmt_m[0]}\n%.{fmt_m[1]}\n%.{fmt_m[2]}"

fig.text(
PANEL[subplot_num][0] + 0.7635,
PANEL[subplot_num][1] + 0.2107,
DEFAULT_PANEL_CONFIGS[subplot_num][0] + 0.7635,
DEFAULT_PANEL_CONFIGS[subplot_num][1] + 0.2107,
# "%.2f\n%.2f\n%.2f" % stats[0:3],
fmt_metrics % metrics[0:3],
ha="right",
Expand All @@ -277,15 +293,15 @@ def _add_colormap(
# RMSE, CORR
if len(metrics) == 5:
fig.text(
PANEL[subplot_num][0] + 0.6635,
PANEL[subplot_num][1] - 0.0105,
DEFAULT_PANEL_CONFIGS[subplot_num][0] + 0.6635,
DEFAULT_PANEL_CONFIGS[subplot_num][1] - 0.0105,
"RMSE\nCORR",
ha="left",
fontdict=PLOT_SIDE_TITLE,
)
fig.text(
PANEL[subplot_num][0] + 0.7635,
PANEL[subplot_num][1] - 0.0105,
DEFAULT_PANEL_CONFIGS[subplot_num][0] + 0.7635,
DEFAULT_PANEL_CONFIGS[subplot_num][1] - 0.0105,
"%.2f\n%.2f" % metrics[3:5],
ha="right",
fontdict=PLOT_SIDE_TITLE,
Expand All @@ -297,8 +313,8 @@ def _add_colormap(
dlat = lat[2] - lat[1]
dlon = lon[2] - lon[1]
fig.text(
PANEL[subplot_num][0] + 0.4635,
PANEL[subplot_num][1] - 0.04,
DEFAULT_PANEL_CONFIGS[subplot_num][0] + 0.4635,
DEFAULT_PANEL_CONFIGS[subplot_num][1] - 0.04,
"Resolution: {:.2f}x{:.2f}".format(dlat, dlon),
ha="left",
fontdict=PLOT_SIDE_TITLE,
Expand Down
96 changes: 15 additions & 81 deletions e3sm_diags/plot/zonal_mean_xy_plot.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,22 @@
from __future__ import print_function

import os

import matplotlib
import numpy as np
import xarray as xr
import xcdat as xc
from matplotlib.transforms import Bbox

from e3sm_diags.driver.utils.general import get_output_dir
from e3sm_diags.logger import custom_logger
from e3sm_diags.parameter.core_parameter import CoreParameter

# from typing import TYPE_CHECKING


# if TYPE_CHECKING:
# from e3sm_diags.driver.lat_lon_driver import MetricsDict
from e3sm_diags.plot.utils import _save_plot

matplotlib.use("Agg")
import matplotlib.pyplot as plt # isort:skip # noqa: E402

logger = custom_logger(__name__)

plotTitle = {"fontsize": 12.5}
plotSideTitle = {"fontsize": 11.5}
# Plot title and side title configurations.
PLOT_TITLE = {"fontsize": 12.5}
PLOT_SIDE_TITLE = {"fontsize": 11.5}

# Position and sizes of subplot axes in page coordinates (0 to 1)
TWO_PANELS = [
PANEL_CONFIGS = [
(0.1500, 0.5500, 0.7500, 0.3000),
(0.1500, 0.1300, 0.7500, 0.3000),
]
Expand Down Expand Up @@ -64,7 +53,7 @@ def plot(
long_name = parameter.viewer_descr[parameter.var_id]

# Top PANEL
ax1 = fig.add_axes(TWO_PANELS[0])
ax1 = fig.add_axes(PANEL_CONFIGS[0])
ax1.plot(lat_test, da_test.values, "k", linewidth=2)
ax1.plot(
lat_ref,
Expand All @@ -85,27 +74,27 @@ def plot(
)
ref_title += " : {}".format(parameter.ref_name_yrs)
fig.text(
TWO_PANELS[0][0],
TWO_PANELS[0][1] + TWO_PANELS[0][3] + 0.03,
PANEL_CONFIGS[0][0],
PANEL_CONFIGS[0][1] + PANEL_CONFIGS[0][3] + 0.03,
test_title,
ha="left",
fontdict=plotSideTitle,
fontdict=PLOT_SIDE_TITLE,
color="black",
)
fig.text(
TWO_PANELS[0][0],
TWO_PANELS[0][1] + TWO_PANELS[0][3] + 0.01,
PANEL_CONFIGS[0][0],
PANEL_CONFIGS[0][1] + PANEL_CONFIGS[0][3] + 0.01,
ref_title,
ha="left",
fontdict=plotSideTitle,
fontdict=PLOT_SIDE_TITLE,
color="red",
)

# Bottom PANEL
ax2 = fig.add_axes(TWO_PANELS[1])
ax2 = fig.add_axes(PANEL_CONFIGS[1])
ax2.plot(lat_diff, da_diff.values, "k", linewidth=2)
ax2.axhline(y=0, color="0.5")
ax2.set_title(parameter.diff_title, fontdict=plotSideTitle, loc="center")
ax2.set_title(parameter.diff_title, fontdict=PLOT_SIDE_TITLE, loc="center")
ax2.set_xticks([-90, -60, -30, 0, 30, 60, 90])
ax2.set_xlim(-90, 90)
ax2.tick_params(labelsize=11.0, direction="out", width=1)
Expand All @@ -115,61 +104,6 @@ def plot(
# Figure title
fig.suptitle(parameter.main_title, x=0.5, y=0.95, fontsize=18)

_save_plot(fig, parameter)
_save_plot(fig, parameter, PANEL_CONFIGS, BORDER_PADDING)

plt.close()


def _save_plot(fig: plt.figure, parameter: CoreParameter):
"""Save the plot using the figure object and parameter configs.
This function creates the output filename to save the plot. It also
saves each individual subplot if the reference name is an empty string ("").
Parameters
----------
fig : plt.figure
The plot figure.
parameter : CoreParameter
The CoreParameter with file configurations.
"""
for f in parameter.output_format:
f = f.lower().split(".")[-1]
fnm = os.path.join(
get_output_dir(parameter.current_set, parameter),
parameter.output_file + "." + f,
)
plt.savefig(fnm)
logger.info(f"Plot saved in: {fnm}")

# Save individual subplots
if parameter.ref_name == "":
PANELs = [TWO_PANELS[0]]
else:
PANELs = TWO_PANELS

for f in parameter.output_format_subplot:
fnm = os.path.join(
get_output_dir(parameter.current_set, parameter),
parameter.output_file,
)
page = fig.get_size_inches()

for idx, PANEL in enumerate(PANELs):
# Extent of subplot
subpage = np.array(PANEL).reshape(2, 2)
subpage[1, :] = subpage[0, :] + subpage[1, :]
subpage = subpage + np.array(BORDER_PADDING).reshape(2, 2)
subpage = list(((subpage) * page).flatten()) # type: ignore
extent = Bbox.from_extents(*subpage)

# Save subplot
fname = fnm + ".%i." % idx + f
plt.savefig(fname, bbox_inches=extent)

orig_fnm = os.path.join(
get_output_dir(parameter.current_set, parameter),
parameter.output_file,
)
fname = orig_fnm + ".%i." % idx + f
logger.info(f"Sub-plot saved in: {fname}")

0 comments on commit b19aec0

Please sign in to comment.