Skip to content

Commit

Permalink
Remove 2D diags from zonal_mean_2d
Browse files Browse the repository at this point in the history
- Move `utils._add_colormap()` to `lat_lon_plot.py`
  • Loading branch information
tomvothecoder committed Feb 14, 2024
1 parent d0ac18a commit b03163b
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 206 deletions.
72 changes: 7 additions & 65 deletions e3sm_diags/driver/zonal_mean_2d_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import List, Tuple
from typing import Tuple

import xarray as xr
import xcdat as xc # noqa: F401
Expand All @@ -10,7 +10,6 @@
align_grids_to_lower_res,
has_z_axis,
regrid_z_axis_to_plevs,
subset_and_align_datasets,
)
from e3sm_diags.driver.utils.type_annotations import MetricsDict
from e3sm_diags.logger import custom_logger
Expand All @@ -32,7 +31,6 @@ def run_diag(
variables = parameter.variables
seasons = parameter.seasons
ref_name = getattr(parameter, "ref_name", "")
regions = parameter.regions

if not parameter._is_plevs_set():
parameter.plevs = default_plevs
Expand All @@ -57,18 +55,10 @@ def run_diag(
is_vars_3d = has_z_axis(dv_test) and has_z_axis(dv_ref)
is_dims_diff = has_z_axis(dv_test) != has_z_axis(dv_ref)

if not is_vars_3d:
ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask(season)

_run_diags_2d(
parameter,
ds_test,
ds_ref,
ds_land_sea_mask,
season,
regions,
var_key,
ref_name,
if is_dims_diff:
raise RuntimeError(
"The dimensions of the test and reference variables are different, "
f"({dv_test.dims} vs. {dv_ref.dims})."
)
elif is_vars_3d:
_run_diags_3d(
Expand All @@ -79,62 +69,14 @@ def run_diag(
var_key,
ref_name,
)
elif is_dims_diff:
else:
raise RuntimeError(
"Dimensions of the two variables are different. Aborting."
"Only 3-dimensional variables are supported by zonal_mean_2d."
)

return parameter


def _run_diags_2d(
parameter: ZonalMean2dParameter,
ds_test: xr.Dataset,
ds_ref: xr.Dataset,
ds_land_sea_mask: xr.Dataset,
season: str,
regions: List[str],
var_key: str,
ref_name: str,
):
for region in regions:
parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None)

(
ds_test_region,
ds_ref_region,
ds_test_region_regrid,
ds_ref_region_regrid,
ds_diff_region,
) = subset_and_align_datasets(
parameter,
ds_test,
ds_ref,
ds_land_sea_mask,
var_key,
region,
)

metrics_dict = _create_metrics_dict(
var_key,
ds_test_region,
ds_test_region_regrid,
ds_ref_region,
ds_ref_region_regrid,
ds_diff_region,
)

_save_data_metrics_and_plots(
parameter,
plot_func,
var_key,
ds_test_region,
ds_ref_region,
ds_diff_region,
metrics_dict,
)


def _run_diags_3d(
parameter: ZonalMean2dParameter,
ds_test: xr.Dataset,
Expand Down
138 changes: 136 additions & 2 deletions e3sm_diags/plot/lat_lon_plot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Tuple

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib
import xarray as xr
import xcdat as xc

from e3sm_diags.derivations.default_regions_xr import REGION_SPECS
from e3sm_diags.logger import custom_logger
from e3sm_diags.parameter.core_parameter import CoreParameter
from e3sm_diags.plot.utils import _add_colormap, _save_plot
from e3sm_diags.plot.utils import (
DEFAULT_PANEL_CFG,
_add_colorbar,
_add_contour_plot,
_add_grid_res_info,
_add_min_mean_max_text,
_add_rmse_corr_text,
_configure_titles,
_configure_x_and_y_axes,
_get_c_levels_and_norm,
_get_x_ticks,
_get_y_ticks,
_make_lon_cyclic,
_save_plot,
)

if TYPE_CHECKING:
from e3sm_diags.driver.lat_lon_driver import MetricsDict
Expand Down Expand Up @@ -102,3 +120,119 @@ def plot(
_save_plot(fig, parameter)

plt.close()


def _add_colormap(
subplot_num: int,
var: xr.DataArray,
fig: plt.Figure,
parameter: CoreParameter,
color_map: str,
contour_levels: List[float],
title: Tuple[str | None, str, str],
metrics: Tuple[float, ...],
):
"""Adds a colormap containing the variable data and metrics to the figure.
This function is used by:
- `lat_lon_plot.py`
- `aerosol_aeronet_plot.py` (when refactored).
Parameters
----------
subplot_num : int
The subplot number.
var : xr.DataArray
The variable to plot.
fig : plt.Figure
The figure object to add the subplot to.
parameter : CoreParameter
The CoreParameter object containing plot configurations.
color_map : str
The colormap styling to use (e.g., "cet_rainbow.rgb").
contour_levels : List[float]
The map contour levels.
title : Tuple[str | None, str, str]
A tuple of strings to form the title of the colormap, in the format
(<optional> years, title, units).
metrics : Tuple[float, ...]
A tuple of metrics for this subplot.
"""
var = _make_lon_cyclic(var)
lat = xc.get_dim_coords(var, axis="Y")
lon = xc.get_dim_coords(var, axis="X")

var = var.squeeze()

# Configure contour levels and boundary norm.
# --------------------------------------------------------------------------
c_levels, norm = _get_c_levels_and_norm(contour_levels)

# Get region info and X and Y plot ticks.
# --------------------------------------------------------------------------
region_key = parameter.regions[0]
region_specs = REGION_SPECS[region_key]

# Get the region's domain slices for latitude and longitude if set, or
# use the default value. If both are not set, then the region type is
# considered "global".
lat_slice = region_specs.get("lat", (-90, 90)) # type: ignore
lon_slice = region_specs.get("lon", (0, 360)) # type: ignore

# Boolean flags for configuring plots.
is_global_domain = lat_slice == (-90, 90) and lon_slice == (0, 360)
is_lon_full = lon_slice == (0, 360)

# Determine X and Y ticks using longitude and latitude domains respectively.
lon_west, lon_east = lon_slice
x_ticks = _get_x_ticks(lon_west, lon_east, is_global_domain, is_lon_full)

lat_south, lat_north = lat_slice
y_ticks = _get_y_ticks(lat_south, lat_north)

# Get the cartopy projection based on region info.
# --------------------------------------------------------------------------
projection = ccrs.PlateCarree()
if is_global_domain or is_lon_full:
projection = ccrs.PlateCarree(central_longitude=180)

# Get the figure Axes object using the projection above.
# --------------------------------------------------------------------------
ax = fig.add_axes(DEFAULT_PANEL_CFG[subplot_num], projection=projection)
ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=projection)
contour_plot = _add_contour_plot(
ax, parameter, var, lon, lat, color_map, ccrs.PlateCarree(), norm, c_levels
)

# Configure the aspect ratio and coast lines.
# --------------------------------------------------------------------------
# Full world would be aspect 360/(2*180) = 1
ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south)))
ax.coastlines(lw=0.3)

if not is_global_domain and "RRM" in region_key:
ax.coastlines(resolution="50m", color="black", linewidth=1)
state_borders = cfeature.NaturalEarthFeature(
category="cultural",
name="admin_1_states_provinces_lakes",
scale="50m",
facecolor="none",
)
ax.add_feature(state_borders, edgecolor="black")

# Configure the titles, x and y axes, and colorbar.
# --------------------------------------------------------------------------
_configure_titles(ax, title)
_configure_x_and_y_axes(
ax, x_ticks, y_ticks, ccrs.PlateCarree(), parameter.current_set
)
_add_colorbar(fig, subplot_num, DEFAULT_PANEL_CFG, contour_plot, c_levels)

# Add metrics text to the figure.
# --------------------------------------------------------------------------
_add_min_mean_max_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)

if len(metrics) == 5:
_add_rmse_corr_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)

_add_grid_res_info(fig, subplot_num, region_key, lat, lon, DEFAULT_PANEL_CFG)
Loading

0 comments on commit b03163b

Please sign in to comment.