Skip to content

Commit

Permalink
Add initial refactored driver
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Mar 14, 2024
1 parent 0247f05 commit c907ca0
Show file tree
Hide file tree
Showing 2 changed files with 379 additions and 169 deletions.
344 changes: 175 additions & 169 deletions e3sm_diags/driver/meridional_mean_2d_driver.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING

import cdms2
import cdutil
import MV2
import numpy
import xarray as xr

from e3sm_diags.driver import utils
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots
from e3sm_diags.driver.utils.regrid import (
align_grids_to_lower_res,
has_z_axis,
regrid_z_axis_to_plevs,
)
from e3sm_diags.driver.utils.type_annotations import MetricsDict
from e3sm_diags.logger import custom_logger
from e3sm_diags.metrics import corr, max_cdms, mean, min_cdms, rmse
from e3sm_diags.parameter.zonal_mean_2d_parameter import ZonalMean2dParameter
from e3sm_diags.plot import plot
from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg
from e3sm_diags.parameter.zonal_mean_2d_parameter import DEFAULT_PLEVS

# TODO: Update this ref after moving the plotter down a dir level.
from e3sm_diags.plot.cartopy.meridional_mean_2d_plot import plot as plot_func

logger = custom_logger(__name__)

Expand All @@ -20,185 +27,184 @@
MeridionalMean2dParameter,
)

DEFAULT_PLEVS = deepcopy(DEFAULT_PLEVS)

def create_metrics(ref, test, ref_regrid, test_regrid, diff):
"""
Creates the mean, max, min, rmse, corr in a dictionary.

def run_diag(parameter: MeridionalMean2dParameter) -> MeridionalMean2dParameter:
"""Run the meridional_mean_2d diagnostics.
Parameters
----------
parameter : MeridionalMean2dParameter
The parameter for the diagnostic.
Returns
-------
MeridionalMean2dParameter
The parameter for the diagnostic with the result (completed or failed)
Raises
------
RuntimeError
If the dimensions of the test and ref variables differ.
RuntimeError
If the test or ref variables are not 3-D (no Z-axis).
"""
orig_bounds = cdms2.getAutoBounds()
cdms2.setAutoBounds(1)
lev = ref.getLevel()
if lev is not None:
lev.setBounds(None)
variables = parameter.variables
seasons = parameter.seasons
ref_name = getattr(parameter, "ref_name", "")

lev = test.getLevel()
if lev is not None:
lev.setBounds(None)
test_ds = Dataset(parameter, data_type="test")
ref_ds = Dataset(parameter, data_type="ref")

lev = test_regrid.getLevel()
if lev is not None:
lev.setBounds(None)
for var_key in variables:
logger.info("Variable: {}".format(var_key))
parameter.var_id = var_key

lev = ref_regrid.getLevel()
if lev is not None:
lev.setBounds(None)
for season in seasons:
parameter._set_name_yrs_attrs(test_ds, ref_ds, season)

lev = diff.getLevel()
if lev is not None:
lev.setBounds(None)
cdms2.setAutoBounds(orig_bounds)
ds_test = test_ds.get_climo_dataset(var_key, season)
ds_ref = ref_ds.get_ref_climo_dataset(var_key, season, ds_test)

metrics_dict = {}
metrics_dict["ref"] = {
"min": min_cdms(ref),
"max": max_cdms(ref),
"mean": mean(ref, axis="xz"),
}
metrics_dict["test"] = {
"min": min_cdms(test),
"max": max_cdms(test),
"mean": mean(test, axis="xz"),
}
dv_test = ds_test[var_key]
dv_ref = ds_ref[var_key]

metrics_dict["diff"] = {
"min": min_cdms(diff),
"max": max_cdms(diff),
"mean": mean(diff, axis="xz"),
}
metrics_dict["misc"] = {
"rmse": rmse(test_regrid, ref_regrid, axis="xz"),
"corr": corr(test_regrid, ref_regrid, axis="xz"),
}
is_vars_3d = has_z_axis(dv_test) and has_z_axis(dv_ref)
is_dims_diff = has_z_axis(dv_test) != has_z_axis(dv_ref)

return metrics_dict
if is_dims_diff:
raise RuntimeError(
"Dimensions of the test and ref variables are different."
)
elif not is_vars_3d:
raise RuntimeError(
"The test and/or ref variables are not 3-D (no Z axis)."
)
elif is_vars_3d:
# Since the default is now stored in MeridionalMean2dParameter,
# we must get it from there if the plevs param is blank.
if not parameter._is_plevs_set():
parameter.plevs = DEFAULT_PLEVS

_run_diags_3d(parameter, ds_test, ds_ref, season, var_key, ref_name)

def run_diag(parameter: MeridionalMean2dParameter) -> MeridionalMean2dParameter:
variables = parameter.variables
seasons = parameter.seasons
ref_name = getattr(parameter, "ref_name", "")
return parameter

test_data = utils.dataset.Dataset(parameter, test=True)
ref_data = utils.dataset.Dataset(parameter, ref=True)

for season in seasons:
# Get the name of the data, appended with the years averaged.
parameter.test_name_yrs = utils.general.get_name_and_yrs(
parameter, test_data, season
)
parameter.ref_name_yrs = utils.general.get_name_and_yrs(
parameter, ref_data, season
)
def _run_diags_3d(
parameter: MeridionalMean2dParameter,
ds_test: xr.Dataset,
ds_ref: xr.Dataset,
season: str,
var_key: str,
ref_name: str,
):
plevs = parameter.plevs

for var in variables:
logger.info("Variable: {}".format(var))
parameter.var_id = var
ds_t_plevs = regrid_z_axis_to_plevs(ds_test, var_key, plevs)
ds_r_plevs = regrid_z_axis_to_plevs(ds_ref, var_key, plevs)

mv1 = test_data.get_climo_variable(var, season)
mv2 = ref_data.get_climo_variable(var, season)
ds_t_plevs_avg = ds_t_plevs.spatial.average(var_key, axis=["Y"])
ds_r_plevs_avg = ds_r_plevs.spatial.average(var_key, axis=["Y"])

parameter.viewer_descr[var] = (
mv1.long_name
if hasattr(mv1, "long_name")
else "No long_name attr in test data."
)
# TODO: Make sure the logic and output match the old CDAT version
# of this code. Can we use xESMF instead here?
ds_t_plevs_rg_avg, ds_r_plevs_rg_avg = align_grids_to_lower_res(
ds_t_plevs_avg, ds_r_plevs_avg, var_key, tool="regrid2", method="conservative"
)

# For variables with a z-axis.
if mv1.getLevel() and mv2.getLevel():
# Since the default is now stored in MeridionalMean2dParameter,
# we must get it from there if the plevs param is blank.
plevs = parameter.plevs
if (isinstance(plevs, numpy.ndarray) and not plevs.all()) or (
not isinstance(plevs, numpy.ndarray) and not plevs
):
plevs = ZonalMean2dParameter().plevs

mv1_p = utils.general.convert_to_pressure_levels(
mv1, plevs, test_data, var, season
)
mv2_p = utils.general.convert_to_pressure_levels(
mv2, plevs, ref_data, var, season
)
# Get the difference between final regridded variables.
with xr.set_options(keep_attrs=True):
ds_diff_plevs_rg_avg = ds_t_plevs_rg_avg.copy()
ds_diff_plevs_rg_avg[var_key] = (
ds_t_plevs_rg_avg[var_key] - ds_r_plevs_rg_avg[var_key]
)

mv1_p = cdutil.averager(mv1_p, axis="y")
mv2_p = cdutil.averager(mv2_p, axis="y")
metrics_dict = _create_metrics_dict(
var_key,
ds_t_plevs_avg,
ds_t_plevs_rg_avg,
ds_r_plevs_avg,
ds_r_plevs_rg_avg,
ds_diff_plevs_rg_avg,
)

parameter.output_file = "-".join(
[ref_name, var, season, parameter.regions[0]]
)
parameter.main_title = str(" ".join([var, season]))

# Regrid towards the lower resolution of the two
# variables for calculating the difference.
if len(mv1_p.getLongitude()) < len(mv2_p.getLongitude()):
mv1_reg = mv1_p
lev_out = mv1_p.getLevel()
lon_out = mv1_p.getLongitude()
# in order to use regrid tool we need to have at least two latitude bands, so generate new grid first
lat = cdms2.createAxis([0])
lat.setBounds(numpy.array([-1, 1]))
lat.designateLatitude()
grid = cdms2.createRectGrid(lat, lon_out)

data_shape = list(mv2_p.shape)
data_shape.append(1)
mv2_reg = MV2.resize(mv2_p, data_shape)
mv2_reg.setAxis(-1, lat)
for i, ax in enumerate(mv2_p.getAxisList()):
mv2_reg.setAxis(i, ax)

mv2_reg = mv2_reg.regrid(grid, regridTool="regrid2")[..., 0]
# Apply the mask back, since crossSectionRegrid
# doesn't preserve the mask.
mv2_reg = MV2.masked_where(mv2_reg == mv2_reg.fill_value, mv2_reg)
elif len(mv1_p.getLongitude()) > len(mv2_p.getLongitude()):
mv2_reg = mv2_p
lev_out = mv2_p.getLevel()
lon_out = mv2_p.getLongitude()
mv1_reg = mv1_p.crossSectionRegrid(lev_out, lon_out)
# In order to use regrid tool we need to have at least two
# latitude bands, so generate new grid first.
lat = cdms2.createAxis([0])
lat.setBounds(numpy.array([-1, 1]))
lat.designateLatitude()
grid = cdms2.createRectGrid(lat, lon_out)

data_shape = list(mv1_p.shape)
data_shape.append(1)
mv1_reg = MV2.resize(mv1_p, data_shape)
mv1_reg.setAxis(-1, lat)
for i, ax in enumerate(mv1_p.getAxisList()):
mv1_reg.setAxis(i, ax)

mv1_reg = mv1_reg.regrid(grid, regridTool="regrid2")[..., 0]
# Apply the mask back, since crossSectionRegrid
# doesn't preserve the mask.
mv1_reg = MV2.masked_where(mv1_reg == mv1_reg.fill_value, mv1_reg)
else:
mv1_reg = mv1_p
mv2_reg = mv2_p

diff = mv1_reg - mv2_reg
metrics_dict = create_metrics(mv2_p, mv1_p, mv2_reg, mv1_reg, diff)

parameter.var_region = "global"

plot(
parameter.current_set,
mv2_p,
mv1_p,
diff,
metrics_dict,
parameter,
)
utils.general.save_ncfiles(
parameter.current_set, mv1_p, mv2_p, diff, parameter
)
# TODO: This section can be turned into a CoreParameter method since it is
# repeated in various drivers.
# Set parameter attributes for output files.
parameter.var_region = "global"
parameter.output_file = "-".join([ref_name, var_key, season, parameter.regions[0]])
parameter.main_title = str(" ".join([var_key, season]))

_save_data_metrics_and_plots(
parameter,
plot_func,
var_key,
ds_t_plevs_avg,
ds_r_plevs_avg,
ds_diff_plevs_rg_avg,
metrics_dict,
)

# For variables without a z-axis.
elif mv1.getLevel() is None and mv2.getLevel() is None:
raise RuntimeError(
"One of or both data doesn't have z dimention. Aborting."
)

return parameter
def _create_metrics_dict(
    var_key: str,
    ds_test: xr.Dataset,
    ds_test_regrid: xr.Dataset,
    ds_ref: xr.Dataset,
    ds_ref_regrid: xr.Dataset,
    ds_diff: xr.Dataset,
) -> MetricsDict:
    """Calculate metrics using the variable in the datasets.

    Metrics include the min value, max value, and spatial average (mean) of
    the test, reference, and difference variables, plus the RMSE and
    correlation between the regridded test and reference variables. The
    mean, RMSE, and correlation are requested over the "X" and "Z" axes.

    Parameters
    ----------
    var_key : str
        The variable key.
    ds_test : xr.Dataset
        The test dataset.
    ds_test_regrid : xr.Dataset
        The regridded test dataset.
    ds_ref : xr.Dataset
        The reference dataset.
    ds_ref_regrid : xr.Dataset
        The regridded reference dataset.
    ds_diff : xr.Dataset
        The difference between ``ds_test_regrid`` and ``ds_ref_regrid``.

    Returns
    -------
    MetricsDict
        A dictionary with the key being a string and the value being either
        a sub-dictionary (key is metric and value is float) or a string
        ("units").

    Raises
    ------
    KeyError
        If the test variable has no "units" attribute.
    """

    def _min_max_mean(ds: xr.Dataset) -> dict:
        # A single helper computes the summary for every dataset so that the
        # min, max, and mean of an entry always come from the SAME dataset.
        # (Previously the "ref" max was mistakenly taken from the test data.)
        return {
            "min": ds[var_key].min().item(),
            "max": ds[var_key].max().item(),
            "mean": spatial_avg(ds, var_key, axis=["X", "Z"]),
        }

    metrics_dict: MetricsDict = {}

    # The units are taken from the test variable's attributes.
    metrics_dict["units"] = ds_test[var_key].attrs["units"]

    metrics_dict["ref"] = _min_max_mean(ds_ref)
    metrics_dict["test"] = _min_max_mean(ds_test)
    metrics_dict["diff"] = _min_max_mean(ds_diff)

    # The comparison metrics use the regridded datasets, which are the pair
    # that ``ds_diff`` was derived from.
    metrics_dict["misc"] = {
        "rmse": rmse(ds_test_regrid, ds_ref_regrid, var_key, axis=["X", "Z"]),
        "corr": correlation(ds_test_regrid, ds_ref_regrid, var_key, axis=["X", "Z"]),
    }

    return metrics_dict
Loading

0 comments on commit c907ca0

Please sign in to comment.