Skip to content

Commit

Permalink
CDAT Migration Phase 2: Refactor utilities and CoreParameter methods …
Browse files Browse the repository at this point in the history
…for reusability across diagnostic sets (#746)

- Move driver type annotations to `type_annotations.py`
- Move `lat_lon_driver._save_data_metrics_and_plots()` to `io.py`
- Update `_save_data_metrics_and_plots` args to accept `plot_func` callable
- Update `metrics.spatial_avg` to optionally return an `xr.DataArray` with `as_list=False`
- Move `parameter` arg to the top in `lat_lon_plot.plot`
- Move `_set_param_output_attrs` and `_set_name_yrs_attrs` from `lat_lon_driver` to `CoreParameter` class
  • Loading branch information
tomvothecoder committed Nov 28, 2023
1 parent 50ce827 commit 6c5ed70
Show file tree
Hide file tree
Showing 10 changed files with 407 additions and 161 deletions.
156 changes: 13 additions & 143 deletions e3sm_diags/driver/lat_lon_driver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, List, Tuple

import xarray as xr

from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.io import _get_output_dir, _write_vars_to_netcdf
from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots
from e3sm_diags.driver.utils.regrid import (
_apply_land_sea_mask,
_subset_on_region,
Expand All @@ -16,20 +14,13 @@
has_z_axis,
regrid_z_axis_to_plevs,
)
from e3sm_diags.driver.utils.type_annotations import MetricsDict
from e3sm_diags.logger import custom_logger
from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg, std
from e3sm_diags.plot.lat_lon_plot import plot
from e3sm_diags.plot.lat_lon_plot import plot as plot_func

logger = custom_logger(__name__)

# Type aliases describing the metrics dictionary.
#   - `UnitAttr`: the string stored under the "unit" key, representing the
#     units for the variable.
#   - `MetricsSubDict`: a sub-dictionary of metrics, where the key is the
#     metric type and the value is a float, a list of floats, or None.
#   - `MetricsDict`: the top-level dictionary, where each value is either the
#     units string or a sub-dictionary of metrics.
UnitAttr = str
MetricsSubDict = Dict[str, Union[float, None, List[float]]]
MetricsDict = Dict[str, Union[UnitAttr, MetricsSubDict]]

if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter

Expand Down Expand Up @@ -72,31 +63,15 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
parameter.var_id = var_key

for season in seasons:
parameter.test_name_yrs = test_ds.get_name_yrs_attr(season)
parameter.ref_name_yrs = ref_ds.get_name_yrs_attr(season)
parameter._set_name_yrs_attrs(test_ds, ref_ds, season)

# The land sea mask dataset that is used for masking if the region
# is either land or sea. This variable is instantiated here to get
# it once per season in case it needs to be reused.
ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask(season)

ds_test = test_ds.get_climo_dataset(var_key, season)

# If the reference climatology dataset cannot be retrieved,
# it will be set to the test climatology dataset, which means
# analysis is only performed on the test dataset.
# TODO: This logic was carried over from legacy implementation. It
# can probably be improved on by setting `ds_ref = None` and not
# performing unnecessary operations on `ds_ref` for model-only runs,
# since it is the same as `ds_test`.
try:
ds_ref = ref_ds.get_climo_dataset(var_key, season)
parameter.model_only = False
except (RuntimeError, IOError):
ds_ref = ds_test
parameter.model_only = True

logger.info("Cannot process reference data, analyzing test data only.")
ds_ref = ref_ds.get_ref_climo_dataset(var_key, season, ds_test)

# Store the variable's DataArray objects for reuse.
dv_test = ds_test[var_key]
Expand Down Expand Up @@ -173,9 +148,8 @@ def _run_diags_2d(
The reference name.
"""
for region in regions:
parameter = _set_param_output_attrs(
parameter, var_key, season, region, ref_name, ilev=None
)
parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None)

(
metrics_dict,
ds_test_region,
Expand All @@ -191,11 +165,12 @@ def _run_diags_2d(
)
_save_data_metrics_and_plots(
parameter,
plot_func,
var_key,
metrics_dict,
ds_test_region,
ds_ref_region,
ds_diff_region,
metrics_dict,
)


Expand Down Expand Up @@ -261,64 +236,18 @@ def _run_diags_3d(
region,
)

parameter = _set_param_output_attrs(
parameter, var_key, season, region, ref_name, ilev
)
parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev)
_save_data_metrics_and_plots(
parameter,
plot_func,
var_key,
metrics_dict,
ds_test_region,
ds_ref_region,
ds_diff_region,
metrics_dict,
)


def _set_param_output_attrs(
parameter: CoreParameter,
var_key: str,
season: str,
region: str,
ref_name: str,
ilev: float | None,
) -> CoreParameter:
"""Set the parameter output attributes based on argument values.
Parameters
----------
parameter : CoreParameter
The parameter.
var_key : str
The variable key.
season : str
The season.
region : str
The region.
ref_name : str
The reference name,
ilev : float | None
The pressure level, by default None. This option is only set if the
variable is 3D.
Returns
-------
CoreParameter
The parameter with updated output attributes.
"""
if ilev is None:
output_file = f"{ref_name}-{var_key}-{season}-{region}"
main_title = f"{var_key} {season} {region}"
else:
ilev_str = str(int(ilev))
output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}"
main_title = f"{var_key} {ilev_str} 'mb' {season} {region}"

parameter.output_file = output_file
parameter.main_title = main_title

return parameter


def _get_metrics_by_region(
parameter: CoreParameter,
ds_test: xr.Dataset,
Expand Down Expand Up @@ -516,62 +445,3 @@ def _create_metrics_dict(
}

return metrics_dict


def _save_data_metrics_and_plots(
    parameter: CoreParameter,
    var_key: str,
    metrics_dict: MetricsDict,
    ds_test: xr.Dataset,
    ds_ref: xr.Dataset | None,
    ds_diff: xr.Dataset | None,
):
    """Write the optional netCDF output, the metrics JSON, and the plot.

    Parameters
    ----------
    parameter : CoreParameter
        The parameter for the diagnostic. Its ``save_netcdf`` flag controls
        whether the variables are written to netCDF, and its ``output_file``
        attribute is the stem of the metrics JSON file name.
    var_key : str
        The variable key.
    metrics_dict : MetricsDict
        The dictionary containing metrics for the variable.
    ds_test : xr.Dataset
        The test dataset.
    ds_ref : xr.Dataset | None
        The optional reference dataset. If the diagnostic is a model-only run,
        then it will be None.
    ds_diff : xr.Dataset | None
        The optional difference dataset. If the diagnostic is a model-only run,
        then it will be None.
    """
    if parameter.save_netcdf:
        _write_vars_to_netcdf(parameter, var_key, ds_test, ds_ref, ds_diff)

    # The metrics are stored as "<output_file>.json" in the output directory.
    filepath = os.path.join(
        _get_output_dir(parameter), f"{parameter.output_file}.json"
    )
    with open(filepath, "w") as outfile:
        json.dump(metrics_dict, outfile)

    logger.info(f"Metrics saved in {filepath}")

    # The viewer description is the "long_name" attr of the test variable.
    long_name = ds_test[var_key].attrs.get(
        "long_name", "No long_name attr in test data"
    )
    parameter.viewer_descr[var_key] = long_name

    # The reference and difference variables are only extracted when their
    # datasets are provided.
    da_test = ds_test[var_key]
    da_ref = None if ds_ref is None else ds_ref[var_key]
    da_diff = None if ds_diff is None else ds_diff[var_key]

    plot(da_test, da_ref, da_diff, metrics_dict, parameter)
59 changes: 57 additions & 2 deletions e3sm_diags/driver/utils/dataset_xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import glob
import os
import re
from typing import Callable, Dict, Literal, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Literal, Tuple

import xarray as xr
import xcdat as xc
Expand All @@ -29,7 +29,10 @@
from e3sm_diags.driver import LAND_FRAC_KEY, LAND_OCEAN_MASK_PATH, OCEAN_FRAC_KEY
from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ, CLIMO_FREQS, climo
from e3sm_diags.logger import custom_logger
from e3sm_diags.parameter.core_parameter import CoreParameter

if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter


logger = custom_logger(__name__)

Expand Down Expand Up @@ -256,6 +259,58 @@ def _get_global_attr_from_climo_dataset(
# --------------------------------------------------------------------------
# Climatology related methods
# --------------------------------------------------------------------------
def get_ref_climo_dataset(
    self, var_key: str, season: CLIMO_FREQ, ds_test: xr.Dataset
) -> xr.Dataset:
    """Get the reference climatology dataset for the variable and season.

    If the reference climatology does not exist or could not be found, it
    will be considered a model-only run. For this case a copy of the test
    dataset is returned as a default value and subsequent metrics
    calculations will only be performed on the original test dataset.

    This method also sets ``self.model_only`` to False when the reference
    climatology is retrieved and to True when falling back to the test
    dataset.

    Parameters
    ----------
    var_key : str
        The key of the variable.
    season : CLIMO_FREQ
        The climatology frequency.
    ds_test : xr.Dataset
        The test dataset, a copy of which is returned if the reference
        climatology does not exist or could not be found.

    Returns
    -------
    xr.Dataset
        The reference climatology if it exists or a copy of the test dataset
        if it does not exist.

    Raises
    ------
    RuntimeError
        If `self.data_type` is not "ref".
    """
    # Fail fast when called on a non-reference dataset object.
    if self.data_type != "ref":
        raise RuntimeError(
            "`Dataset.get_ref_climo_dataset` only works with "
            f"`self.data_type == 'ref'`, not {self.data_type}."
        )

    # TODO: This logic was carried over from legacy implementation. It
    # can probably be improved on by setting `ds_ref = None` and not
    # performing unnecessary operations on `ds_ref` for model-only runs,
    # since it is the same as `ds_test`.
    # NOTE(review): the model-only flag is recorded on this object
    # (`self.model_only`); confirm that callers read it from here.
    try:
        ds_ref = self.get_climo_dataset(var_key, season)
        self.model_only = False
    except (RuntimeError, IOError):
        ds_ref = ds_test.copy()
        self.model_only = True

        logger.info("Cannot process reference data, analyzing test data only.")

    return ds_ref

def get_climo_dataset(self, var: str, season: CLIMO_FREQ) -> xr.Dataset:
"""Get the dataset containing the climatology variable.
Expand Down
66 changes: 66 additions & 0 deletions e3sm_diags/driver/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,82 @@
from __future__ import annotations

import errno
import json
import os
from typing import Callable

import xarray as xr

from e3sm_diags.driver.utils.type_annotations import MetricsDict
from e3sm_diags.logger import custom_logger
from e3sm_diags.parameter.core_parameter import CoreParameter

logger = custom_logger(__name__)


def _save_data_metrics_and_plots(
    parameter: CoreParameter,
    plot_func: Callable,
    var_key: str,
    ds_test: xr.Dataset,
    ds_ref: xr.Dataset | None,
    ds_diff: xr.Dataset | None,
    metrics_dict: MetricsDict | None,
):
    """Write the optional netCDF output, the optional metrics JSON, and the plot.

    Parameters
    ----------
    parameter : CoreParameter
        The parameter for the diagnostic. Its ``save_netcdf`` flag controls
        whether the variables are written to netCDF, and its ``output_file``
        attribute is the stem of the metrics JSON file name.
    plot_func : Callable
        The plot function for the diagnostic set. It is called with the
        parameter, the test, reference, and difference variables, and the
        metrics dictionary, in that order.
    var_key : str
        The variable key.
    ds_test : xr.Dataset
        The test dataset.
    ds_ref : xr.Dataset | None
        The optional reference dataset. If the diagnostic is a model-only run,
        then it will be None.
    ds_diff : xr.Dataset | None
        The optional difference dataset. If the diagnostic is a model-only run,
        then it will be None.
    metrics_dict : MetricsDict | None
        The optional dictionary containing metrics for the variable. If None,
        no metrics JSON file is written; the value is still passed on to
        ``plot_func``.
    """
    if parameter.save_netcdf:
        _write_vars_to_netcdf(parameter, var_key, ds_test, ds_ref, ds_diff)

    # The output path is resolved even when there are no metrics to write,
    # so the call to `_get_output_dir` always happens.
    filepath = os.path.join(
        _get_output_dir(parameter), f"{parameter.output_file}.json"
    )

    if metrics_dict is not None:
        with open(filepath, "w") as outfile:
            json.dump(metrics_dict, outfile)

        logger.info(f"Metrics saved in {filepath}")

    # The viewer description is the "long_name" attr of the test variable.
    long_name = ds_test[var_key].attrs.get(
        "long_name", "No long_name attr in test data"
    )
    parameter.viewer_descr[var_key] = long_name

    # The reference and difference variables are only extracted when their
    # datasets are provided.
    da_test = ds_test[var_key]
    da_ref = None if ds_ref is None else ds_ref[var_key]
    da_diff = None if ds_diff is None else ds_diff[var_key]

    plot_func(parameter, da_test, da_ref, da_diff, metrics_dict)


def _write_vars_to_netcdf(
parameter: CoreParameter,
var_key,
Expand Down
Loading

0 comments on commit 6c5ed70

Please sign in to comment.