Skip to content

Commit

Permalink
update driver
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzhuzhang committed Mar 21, 2024
1 parent b327d93 commit 7bc2657
Showing 1 changed file with 49 additions and 78 deletions.
127 changes: 49 additions & 78 deletions e3sm_diags/driver/annual_cycle_zonal_mean_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots
from e3sm_diags.driver.utils.regrid import (
align_grids_to_lower_res,
get_z_axis,
has_z_axis,
regrid_z_axis_to_plevs,
)
from e3sm_diags.logger import custom_logger
from e3sm_diags.metrics.metrics import spatial_avg
Expand Down Expand Up @@ -69,9 +69,9 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:

parameter._set_name_yrs_attrs(test_ds, ref_ds, "01")

ds_test = test_ds.get_climo_dataset(var_key, "01")
ds_test = test_ds.get_climo_dataset(var_key, "ANNUALCYCLE")
# TODO consider to refactor the behavior of get_ref_climo_dataset
ds_ref = ref_ds.get_ref_climo_dataset(var_key, "01", ds_test)
ds_ref = ref_ds.get_ref_climo_dataset(var_key, "ANNUALCYCLE", ds_test)

# Store the variable's DataArray objects for reuse.
dv_test = ds_test[var_key]
Expand All @@ -81,11 +81,10 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
is_dims_diff = has_z_axis(dv_test) != has_z_axis(dv_ref)

if not is_vars_3d:
_run_diags_2d(
_run_diags_annual_cycle(
parameter,
ds_test,
ds_ref,
"AC",
regions,
var_key,
ref_name,
Expand All @@ -102,16 +101,15 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
return parameter


def _run_diags_2d(
def _run_diags_annual_cycle(
parameter: CoreParameter,
ds_test: xr.Dataset,
ds_ref: xr.Dataset,
season: str,
regions: List[str],
var_key: str,
ref_name: str,
):
"""Run diagnostics on a 2D variable.
"""Run annual cycle zonal run diagnostics.
This function gets the variable's metrics by region, then saves the
metrics, metric plots, and data (optional, `CoreParameter.save_netcdf`).
Expand All @@ -125,8 +123,6 @@ def _run_diags_2d(
ds_ref : xr.Dataset
The dataset containing the ref variable. If this is a model-only run
then it will be the same dataset as ``ds_test``.
season : str
The season.
regions : List[str]
The list of regions.
var_key : str
Expand All @@ -137,78 +133,53 @@ def _run_diags_2d(
for region in regions:
logger.info(f"Selected region: {region}")

parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None)
# Calculate annual cycle
# Regridding
# Calculate zonal mean
da_test_1d, da_ref_1d = _calc_zonal_mean(ds_test, ds_ref, var_key)
da_diff_1d = _get_diff_of_zonal_means(da_test_1d, da_ref_1d)

_save_data_metrics_and_plots(
parameter,
plot_func,
var_key,
da_test_1d.to_dataset(),
da_ref_1d.to_dataset(),
da_diff_1d.to_dataset(),
metrics_dict=None,
parameter._set_param_output_attrs(
var_key, "ANNUALCYCLE", region, ref_name, ilev=None
)

# align grids

def _get_annual_cycle(
ds_test: xr.Dataset,
var_key: str,
) -> xr.DataArray:
"""Get annual cycle.
# TODO: Write unit tests for this function.
Parameters
----------
ds_test : xr.Dataset
The dataset containing the test variable.
var_key : str
The key of the variable.
Returns
-------
xr.DataArray
xarray DatAarray
"""
months = range(1, 13)
month_list = [f"{x:02}" for x in list(months)]

for index, month in enumerate(month_list):
ds_test_mon = ds_test.get_climo_dataset(var_key, month)
print(ds_test_mon)


def _calc_zonal_mean(
ds_test: xr.Dataset,
ds_ref: xr.Dataset,
var_key: str,
) -> Tuple[xr.DataArray, xr.DataArray]:
"""Calculate zonal mean metrics.
ds_test_reg, ds_ref_reg = align_grids_to_lower_res(
ds_test,
ds_ref,
var_key,
parameter.regrid_tool,
parameter.regrid_method,
)

# TODO: Write unit tests for this function.
test_zonal_mean = spatial_avg(ds_test, var_key, axis=["X"], as_list=False)
test_reg_zonal_mean = spatial_avg(
ds_test_reg, var_key, axis=["X"], as_list=False
)

Parameters
----------
ds_test : xr.Dataset
The dataset containing the test variable.
ds_ref : xr.Dataset
The dataset containing the ref variable. If this is a model-only run
then it will be the same dataset as ``ds_test``.
var_key : str
The key of the variable.
if (
parameter.ref_name == "OMI-MLS"
): # SCO from OMI-MLS only available as (time, lat)
test_zonal_mean = test_zonal_mean.sel(lat=(-60, 60))
test_reg_zonal_mean = test_reg_zonal_mean.sel(lat=(-60, 60))
if var == "SCO":
ref_zonal_mean = ref_ac
ref_reg_zonal_mean = ref_ac_reg
else:
ref_zonal_mean = spatial_avg(ds_ref, var_key, axis=["X"], as_list=False)
ref_reg_zonal_mean = spatial_avg(
ds_ref_reg, var_key, axis=["X"], as_list=False
)

else:
ref_zonal_mean = spatial_avg(ds_ref, var_key, axis=["X"], as_list=False)
ref_reg_zonal_mean = spatial_avg(
ds_ref_reg, var_key, axis=["X"], as_list=False
)

Returns
-------
Tuple[xr.DataArray, xr.DataArray]
A Tuple containing the zonal mean for the test variable and the ref
variable.
"""
da_test = spatial_avg(ds_test, var_key, axis=["X"], as_list=False)
da_ref = spatial_avg(ds_ref, var_key, axis=["X"], as_list=False)
diff = test_reg_zonal_mean - ref_reg_zonal_mean

return da_test, da_ref # type: ignore
_save_data_metrics_and_plots(
parameter,
plot_func,
var_key,
test_zonal_mean.to_dataset(),
ref_zonal_mean.to_dataset(),
diff.to_dataset(),
metrics_dict=None,
)

0 comments on commit 7bc2657

Please sign in to comment.