CDAT Migration Phase 2: Regression testing for `lat_lon`, `lat_lon_land`, and `lat_lon_river` (#744)

- Add Makefile that simplifies common development commands (building and installing, testing, etc.)
- Write unit tests to cover all new code for utility functions
  - `dataset_xr.py`, `metrics.py`, `climo_xr.py`, `io.py`, `regrid.py`
- Compare `lat_lon` metrics between the `cdat-migration-fy24` and `main` branches -- `NET_FLUX_SRF` and `RESTOM` have the largest spatial average diffs
- Test run with 3D variables (`_run_diags_3d()`)
  - Fix Python 3.9 bug caused by using the pipe operator (`|`) for `Union` in a type alias -- it still fails even with `from __future__ import annotations` (see the `Union` sketch after this list)
  - Fix subsetting syntax bug when selecting by `ilev`
  - Fix regridding bug where a single plev is passed and xCDAT cannot generate bounds for coordinates of length <= 1 -- add a conditional that skips adding new bounds for such regridded output datasets and fix the related tests (see the bounds-guard sketch after this list)
  - Fix accidental double call to save plots and metrics in `_get_metrics_by_region()`
- Fix failing integration tests so they pass in CI/CD
  - Refactor `test_diags.py` -- replace unittest with pytest (see the pytest sketch after this list)
  - Refactor `test_all_sets.py` -- replace unittest with pytest
  - Test climatology datasets -- tested with 3D variables using `test_all_sets.py`
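
A minimal sketch of the Python 3.9 `Union` fix described above, mirroring the `MetricsDict` alias changed in `lat_lon_driver.py`. Type aliases are evaluated eagerly at import time, so the `|` syntax raises a `TypeError` on 3.9 regardless of `from __future__ import annotations`, and the alias falls back to `typing.Union`:

```python
from __future__ import annotations

from typing import Dict, List, Union

# Broken on Python 3.9: the alias is evaluated at import time, so the `|`
# union syntax raises TypeError even with `from __future__ import annotations`.
# MetricsDict = Dict[str, str | Dict[str, float | None | List[float]]]

# Works on Python 3.9: spell out the unions with `typing.Union` instead.
UnitAttr = str
MetricsSubDict = Dict[str, Union[float, None, List[float]]]
MetricsDict = Dict[str, Union[UnitAttr, MetricsSubDict]]
```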
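
A sketch of the bounds guard for singleton pressure levels, assuming an xCDAT dataset like the regridded output of `regrid_z_axis_to_plevs()` (the helper name below is illustrative). xCDAT cannot generate bounds for a coordinate of length <= 1, so new Z bounds are only added when more than one pressure level is requested:

```python
import xarray as xr
import xcdat as xc


def _add_z_bounds_if_possible(ds_plevs: xr.Dataset, var_key: str) -> xr.Dataset:
    """Add Z-axis bounds to a regridded dataset, skipping singleton Z axes."""
    z_axis = xc.get_dim_coords(ds_plevs[var_key], axis="Z")

    # xCDAT cannot generate bounds for coordinates of length <= 1, so only add
    # bounds when at least two pressure levels are present.
    if len(z_axis) > 1:
        ds_plevs = ds_plevs.bounds.add_bounds("Z")

    return ds_plevs
```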
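
And a hypothetical before/after for the unittest-to-pytest refactor of `test_diags.py` and `test_all_sets.py` (the test names and assertions are illustrative, not the actual test contents):

```python
import unittest

import pytest


# Before: unittest style -- a TestCase subclass with self.assert* methods.
class TestAllSets(unittest.TestCase):
    def test_expected_number_of_diagnostic_files(self):
        num_files = 12
        self.assertEqual(num_files, 12)


# After: pytest style -- plain functions with bare asserts, plus decorators
# such as @pytest.mark.parametrize in place of repetitive setUp boilerplate.
@pytest.mark.parametrize("num_files", [12])
def test_expected_number_of_diagnostic_files(num_files):
    assert num_files == 12
```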
tomvothecoder committed Nov 28, 2023
1 parent 046373e commit 50ce827
Showing 10 changed files with 471 additions and 304 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -79,4 +79,4 @@ docs: ## generate Sphinx HTML documentation, including API docs
# Build
# ----------------------
install: clean ## install the package to the active Python's site-packages
python setup.py install
python -m pip install .
50 changes: 24 additions & 26 deletions e3sm_diags/driver/lat_lon_driver.py
@@ -2,13 +2,12 @@

import json
import os
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

import xarray as xr

from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.general import get_output_dir
from e3sm_diags.driver.utils.io import _write_vars_to_netcdf
from e3sm_diags.driver.utils.io import _get_output_dir, _write_vars_to_netcdf
from e3sm_diags.driver.utils.regrid import (
_apply_land_sea_mask,
_subset_on_region,
@@ -27,7 +26,9 @@
# type of metrics and the value is a sub-dictionary of metrics (key is metrics
# type and value is float). There is also a "unit" key representing the
# units for the variable.
MetricsDict = Dict[str, str | Dict[str, float | None | List[float]]]
UnitAttr = str
MetricsSubDict = Dict[str, Union[float, None, List[float]]]
MetricsDict = Dict[str, Union[UnitAttr, MetricsSubDict]]

if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter
@@ -116,7 +117,6 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
ref_name,
)
elif is_vars_3d:
# TODO: Test this conditional with 3D variables.
_run_diags_3d(
parameter,
ds_test,
@@ -238,14 +238,13 @@ def _run_diags_3d(
plev = parameter.plevs
logger.info("Selected pressure level(s): {}".format(plev))

ds_test = regrid_z_axis_to_plevs(ds_test, var_key, parameter.plevs)
ds_ref = regrid_z_axis_to_plevs(ds_ref, var_key, parameter.plevs)
ds_test_rg = regrid_z_axis_to_plevs(ds_test, var_key, parameter.plevs)
ds_ref_rg = regrid_z_axis_to_plevs(ds_ref, var_key, parameter.plevs)

for ilev, _ in enumerate(plev):
# TODO: Test the subsetting here with 3D variables
z_axis = get_z_axis(ds_test[var_key])
ds_test_ilev = ds_test.isel({f"{z_axis}": ilev})
ds_ref_ilev = ds_ref.isel({f"{z_axis}": ilev})
for ilev in plev:
z_axis_key = get_z_axis(ds_test_rg[var_key]).name
ds_test_ilev = ds_test_rg.sel({z_axis_key: ilev})
ds_ref_ilev = ds_ref_rg.sel({z_axis_key: ilev})

for region in regions:
(
@@ -307,12 +306,15 @@ def _set_param_output_attrs(
The parameter with updated output attributes.
"""
if ilev is None:
parameter.output_file = f"{ref_name}-{var_key}-{season}-{region}"
parameter.main_title = f"{var_key} {season} {region}"
output_file = f"{ref_name}-{var_key}-{season}-{region}"
main_title = f"{var_key} {season} {region}"
else:
ilev_str = str(int(ilev))
parameter.output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}"
parameter.main_title = f"{var_key} {ilev_str} 'mb' {season} {region}"
output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}"
main_title = f"{var_key} {ilev_str} 'mb' {season} {region}"

parameter.output_file = output_file
parameter.main_title = main_title

return parameter

@@ -396,10 +398,6 @@ def _get_metrics_by_region(
var_key, ds_test, ds_test_regrid, ds_ref, ds_ref_regrid, ds_diff
)

_save_data_metrics_and_plots(
parameter, var_key, metrics_dict, ds_test, ds_ref, ds_diff
)

return metrics_dict, ds_test, ds_ref, ds_diff


@@ -556,14 +554,14 @@ def _save_data_metrics_and_plots(
ds_diff,
)

filename = os.path.join(
get_output_dir(parameter.current_set, parameter),
parameter.output_file + ".json",
)
with open(filename, "w") as outfile:
output_dir = _get_output_dir(parameter)
filename = f"{parameter.output_file}.json"
filepath = os.path.join(output_dir, filename)

with open(filepath, "w") as outfile:
json.dump(metrics_dict, outfile)

logger.info(f"Metrics saved in {filename}")
logger.info(f"Metrics saved in {filepath}")

# Set the viewer description to the "long_name" attr of the variable.
parameter.viewer_descr[var_key] = ds_test[var_key].attrs.get(
32 changes: 19 additions & 13 deletions e3sm_diags/driver/utils/dataset_xr.py
@@ -61,18 +61,22 @@ def __init__(
"Valid options include 'ref' or 'test'."
)

# Set the `start_yr` and `end_yr` attrs based on the dataset type.
# Note, these attrs are different for the `area_mean_time_series`
# parameter.
if self.parameter.sets[0] in ["area_mean_time_series"]:
self.start_yr = self.parameter.start_yr # type: ignore
self.end_yr = self.parameter.end_yr # type: ignore
elif self.data_type == "ref":
self.start_yr = self.parameter.ref_start_yr # type: ignore
self.end_yr = self.parameter.ref_end_yr # type: ignore
elif self.data_type == "test":
self.start_yr = self.parameter.test_start_yr # type: ignore
self.end_yr = self.parameter.test_end_yr # type: ignore
# If the underlying data is a time series, set the `start_yr` and
# `end_yr` attrs based on the data type (ref or test). Note, these attrs
# are different for the `area_mean_time_series` parameter.
if self.is_time_series:
# FIXME: This conditional should not assume the first set is
# area_mean_time_series. If area_mean_time_series is at another
# index, this conditional will not be True.
if self.parameter.sets[0] in ["area_mean_time_series"]:
self.start_yr = self.parameter.start_yr # type: ignore
self.end_yr = self.parameter.end_yr # type: ignore
elif self.data_type == "ref":
self.start_yr = self.parameter.ref_start_yr # type: ignore
self.end_yr = self.parameter.ref_end_yr # type: ignore
elif self.data_type == "test":
self.start_yr = self.parameter.test_start_yr # type: ignore
self.end_yr = self.parameter.test_end_yr # type: ignore

# The derived variables defined in E3SM Diags. If the `CoreParameter`
# object contains additional user derived variables, they are added
@@ -969,7 +973,9 @@ def _get_land_sea_mask(self, season: str) -> xr.Dataset:
ds_land_frac = self.get_climo_dataset(LAND_FRAC_KEY, season) # type: ignore
ds_ocean_frac = self.get_climo_dataset(OCEAN_FRAC_KEY, season) # type: ignore
except IOError as e:
logger.warning(e)
logger.info(
f"{e}. Using default land sea mask located at `{LAND_OCEAN_MASK_PATH}`."
)

ds_mask = xr.open_dataset(LAND_OCEAN_MASK_PATH)
ds_mask = self._squeeze_time_dim(ds_mask)
101 changes: 75 additions & 26 deletions e3sm_diags/driver/utils/regrid.py
@@ -359,16 +359,22 @@ def regrid_z_axis_to_plevs(
Replaces `e3sm_diags.driver.utils.general.convert_to_pressure_levels`.
"""
ds = dataset.copy()

# Make sure that the input dataset has Z axis bounds, which are required for
# getting grid positions during vertical regridding.
try:
ds.bounds.get_bounds("Z")
except KeyError:
ds = ds.bounds.add_bounds("Z")

z_axis = get_z_axis(ds[var_key])
z_long_name = z_axis.attrs.get("long_name")

if z_long_name is None:
raise KeyError(
f"The vertical level ({z_axis.name}) for '{var_key}' does "
"not have a 'long_name' attribute to determine whether it is hybrid "
"or pressure."
)

z_long_name = z_long_name.lower()

# Hybrid must be the first conditional statement because the long_name attr
Expand All @@ -385,8 +391,11 @@ def regrid_z_axis_to_plevs(
"'pressure', or 'isobaric'."
)

# Add bounds for the new, regridded Z axis.
ds_plevs = ds_plevs.bounds.add_bounds(axis="Z")
# Add bounds for the new, regridded Z axis if the length is greater than 1.
# xCDAT does not support adding bounds for singleton coordinates.
new_z_axis = get_z_axis(ds_plevs[var_key])
if len(new_z_axis) > 1:
ds_plevs = ds_plevs.bounds.add_bounds("Z")

return ds_plevs

@@ -423,20 +432,12 @@ def _hybrid_to_plevs(
-----
Replaces `e3sm_diags.driver.utils.general.hybrid_to_plevs`.
"""
# TODO: Do we need to convert the Z axis to mb units if it is in PA? Or
# do we always expect units to be in mb?
# TODO: mb units are always expected, but we should consider checking
# the units to confirm whether or not unit conversion is needed.
z_axis, _ = xc.create_axis("lev", plevs, generate_bounds=False)

pressure_grid = xc.create_grid(z=z_axis)
pressure_coords = _hybrid_to_pressure(ds, var_key)

# Make sure that the input dataset has Z axis bounds, which are required for
# getting grid positions during vertical regridding.
try:
ds.bounds.get_bounds("Z")
except KeyError:
ds = ds.bounds.add_bounds("Z")

# Keep the "axis" and "coordinate" attributes for CF mapping.
with xr.set_options(keep_attrs=True):
result = ds.regridder.vertical(
@@ -497,7 +498,7 @@ def _hybrid_to_pressure(ds: xr.Dataset, var_key: str) -> xr.DataArray:
"'hyam' and/or 'hybm' to use for reconstructing to pressure data."
)

ps = _convert_units_to_mb(ps)
ps = _convert_dataarray_units_to_mb(ps)

pressure_coords = hyam * p0 + hybm * ps
pressure_coords.attrs["units"] = "mb"
@@ -564,14 +565,13 @@ def _pressure_to_plevs(
-----
Replaces `e3sm_diags.driver.utils.general.pressure_to_plevs`.
"""
# Convert pressure coordinates and bounds to mb if it is not already in mb.
ds = _convert_dataset_units_to_mb(ds, var_key)

# Create the output pressure grid to regrid to using the `plevs` array.
z_axis, _ = xc.create_axis("lev", plevs, generate_bounds=False)
pressure_grid = xc.create_grid(z=z_axis)

# Convert pressure coordinates to mb if it is not already in mb.
lev_key = xc.get_dim_keys(ds[var_key], axis="Z")
ds[lev_key] = _convert_units_to_mb(ds[lev_key])

# Keep the "axis" and "coordinate" attributes for CF mapping.
with xr.set_options(keep_attrs=True):
result = ds.regridder.vertical(
@@ -584,10 +584,57 @@
return result


def _convert_units_to_mb(da: xr.DataArray) -> xr.DataArray:
"""Convert DataArray to mb (millibars) if not in mb.
def _convert_dataset_units_to_mb(ds: xr.Dataset, var_key: str) -> xr.Dataset:
"""Convert a dataset's Z axis and bounds to mb if they are not in mb.
Parameters
----------
ds : xr.Dataset
The dataset.
var_key : str
The key of the variable.
Returns
-------
xr.Dataset
The dataset with a Z axis in mb units.
Raises
------
RuntimeError
If the Z axis units does not align with the Z bounds units.
"""
z_axis = xc.get_dim_coords(ds[var_key], axis="Z")
z_bnds = ds.bounds.get_bounds(axis="Z", var_key=var_key)

# Make sure that Z and Z bounds units are aligned. If units do not exist
# assume they are the same because bounds usually don't have a units attr.
z_axis_units = z_axis.attrs["units"]
z_bnds_units = z_bnds.attrs.get("units")
if z_bnds_units is not None and z_bnds_units != z_axis_units:
raise RuntimeError(
f"The units for '{z_bnds.name}' ({z_bnds_units}) "
f"does not align with '{z_axis.name}' ({z_axis_units}). "
)
else:
z_bnds.attrs["units"] = z_axis_units

# Convert Z and Z bounds and update them in the Dataset.
z_axis_new = _convert_dataarray_units_to_mb(z_axis)
ds = ds.assign_coords({z_axis.name: z_axis_new})

z_bnds_new = _convert_dataarray_units_to_mb(z_bnds)
z_bnds_new[z_axis.name] = z_axis_new
ds[z_bnds.name] = z_bnds_new

return ds


def _convert_dataarray_units_to_mb(da: xr.DataArray) -> xr.DataArray:
"""Convert a dataarray to mb (millibars) if they are not in mb.
Unit conversion formulas:
* hPa = mb
* mb = Pa / 100
* Pa = (mb * 100)
@@ -614,17 +661,19 @@ def _convert_units_to_mb(da: xr.DataArray) -> xr.DataArray:

if units is None:
raise ValueError(
"'{ps.name}' has no 'units' attribute to determine if data is in 'mb' or "
"'Pa' units."
f"'{da.name}' has no 'units' attribute to determine if data is in 'mb', "
"'hPa', or 'Pa' units."
)

if units == "mb":
pass
elif units == "Pa":
if units == "Pa":
with xr.set_options(keep_attrs=True):
da = da / 100.0

da.attrs["units"] = "mb"
elif units == "hPa":
da.attrs["units"] = "mb"
elif units == "mb":
pass
else:
raise ValueError(
f"'{da.name}' should be in 'mb' or 'Pa' (which gets converted to 'mb'), "
1 change: 1 addition & 0 deletions e3sm_diags/plot/cartopy/arm_diags_plot.py
@@ -174,6 +174,7 @@ def plot_convection_onset_statistics(
var_time_absolute = cwv.getTime().asComponentTime()
time_interval = int(var_time_absolute[1].hour - var_time_absolute[0].hour)

# FIXME: UnboundLocalError: local variable 'cwv_max' referenced before assignment
number_of_bins = int(np.ceil((cwv_max - cwv_min) / bin_width))
bin_center = np.arange(
(cwv_min + (bin_width / 2)),
4 changes: 4 additions & 0 deletions e3sm_diags/run.py
@@ -41,6 +41,10 @@ def get_final_parameters(self, parameters):
"""
Based on sets_to_run and the list of parameters,
get the final list of parameters to run the diags on.
FIXME: This function was only designed to take in 1 parameter at a
time or a mix of different parameters. If there are two
CoreParameter objects, it will break.
"""
if not parameters or not isinstance(parameters, list):
msg = "You must pass in a list of parameter objects."