From 7bc2657da3e77a2f0eb5996cae83c0df015d8fc3 Mon Sep 17 00:00:00 2001 From: ChengzhuZhang Date: Thu, 21 Mar 2024 16:16:57 -0700 Subject: [PATCH] update driver --- .../driver/annual_cycle_zonal_mean_driver.py | 127 +++++++----------- 1 file changed, 49 insertions(+), 78 deletions(-) diff --git a/e3sm_diags/driver/annual_cycle_zonal_mean_driver.py b/e3sm_diags/driver/annual_cycle_zonal_mean_driver.py index ef1955854..44eca4f9e 100755 --- a/e3sm_diags/driver/annual_cycle_zonal_mean_driver.py +++ b/e3sm_diags/driver/annual_cycle_zonal_mean_driver.py @@ -9,9 +9,9 @@ from e3sm_diags.driver.utils.dataset_xr import Dataset from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots from e3sm_diags.driver.utils.regrid import ( + align_grids_to_lower_res, get_z_axis, has_z_axis, - regrid_z_axis_to_plevs, ) from e3sm_diags.logger import custom_logger from e3sm_diags.metrics.metrics import spatial_avg @@ -69,9 +69,9 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter._set_name_yrs_attrs(test_ds, ref_ds, "01") - ds_test = test_ds.get_climo_dataset(var_key, "01") + ds_test = test_ds.get_climo_dataset(var_key, "ANNUALCYCLE") # TODO consider to refactor the behavior of get_ref_climo_dataset - ds_ref = ref_ds.get_ref_climo_dataset(var_key, "01", ds_test) + ds_ref = ref_ds.get_ref_climo_dataset(var_key, "ANNUALCYCLE", ds_test) # Store the variable's DataArray objects for reuse. dv_test = ds_test[var_key] @@ -81,11 +81,10 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: is_dims_diff = has_z_axis(dv_test) != has_z_axis(dv_ref) if not is_vars_3d: - _run_diags_2d( + _run_diags_annual_cycle( parameter, ds_test, ds_ref, - "AC", regions, var_key, ref_name, @@ -102,16 +101,15 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: return parameter -def _run_diags_2d( +def _run_diags_annual_cycle( parameter: CoreParameter, ds_test: xr.Dataset, ds_ref: xr.Dataset, - season: str, regions: List[str], var_key: str, ref_name: str, ): - """Run diagnostics on a 2D variable. + """Run annual cycle zonal run diagnostics. This function gets the variable's metrics by region, then saves the metrics, metric plots, and data (optional, `CoreParameter.save_netcdf`). @@ -125,8 +123,6 @@ def _run_diags_2d( ds_ref : xr.Dataset The dataset containing the ref variable. If this is a model-only run then it will be the same dataset as ``ds_test``. - season : str - The season. regions : List[str] The list of regions. var_key : str @@ -137,78 +133,53 @@ def _run_diags_2d( for region in regions: logger.info(f"Selected region: {region}") - parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None) - # Calculate annual cycle - # Regridding - # Calculate zonal mean - da_test_1d, da_ref_1d = _calc_zonal_mean(ds_test, ds_ref, var_key) - da_diff_1d = _get_diff_of_zonal_means(da_test_1d, da_ref_1d) - - _save_data_metrics_and_plots( - parameter, - plot_func, - var_key, - da_test_1d.to_dataset(), - da_ref_1d.to_dataset(), - da_diff_1d.to_dataset(), - metrics_dict=None, + parameter._set_param_output_attrs( + var_key, "ANNUALCYCLE", region, ref_name, ilev=None ) + # align grids -def _get_annual_cycle( - ds_test: xr.Dataset, - var_key: str, -) -> xr.DataArray: - """Get annual cycle. - - # TODO: Write unit tests for this function. - - Parameters - ---------- - ds_test : xr.Dataset - The dataset containing the test variable. - var_key : str - The key of the variable. - - Returns - ------- - xr.DataArray - xarray DatAarray - """ - months = range(1, 13) - month_list = [f"{x:02}" for x in list(months)] - - for index, month in enumerate(month_list): - ds_test_mon = ds_test.get_climo_dataset(var_key, month) - print(ds_test_mon) - - -def _calc_zonal_mean( - ds_test: xr.Dataset, - ds_ref: xr.Dataset, - var_key: str, -) -> Tuple[xr.DataArray, xr.DataArray]: - """Calculate zonal mean metrics. + ds_test_reg, ds_ref_reg = align_grids_to_lower_res( + ds_test, + ds_ref, + var_key, + parameter.regrid_tool, + parameter.regrid_method, + ) - # TODO: Write unit tests for this function. + test_zonal_mean = spatial_avg(ds_test, var_key, axis=["X"], as_list=False) + test_reg_zonal_mean = spatial_avg( + ds_test_reg, var_key, axis=["X"], as_list=False + ) - Parameters - ---------- - ds_test : xr.Dataset - The dataset containing the test variable. - ds_ref : xr.Dataset - The dataset containing the ref variable. If this is a model-only run - then it will be the same dataset as ``ds_test``. - var_key : str - The key of the variable. + if ( + parameter.ref_name == "OMI-MLS" + ): # SCO from OMI-MLS only available as (time, lat) + test_zonal_mean = test_zonal_mean.sel(lat=(-60, 60)) + test_reg_zonal_mean = test_reg_zonal_mean.sel(lat=(-60, 60)) + if var == "SCO": + ref_zonal_mean = ref_ac + ref_reg_zonal_mean = ref_ac_reg + else: + ref_zonal_mean = spatial_avg(ds_ref, var_key, axis=["X"], as_list=False) + ref_reg_zonal_mean = spatial_avg( + ds_ref_reg, var_key, axis=["X"], as_list=False + ) + + else: + ref_zonal_mean = spatial_avg(ds_ref, var_key, axis=["X"], as_list=False) + ref_reg_zonal_mean = spatial_avg( + ds_ref_reg, var_key, axis=["X"], as_list=False + ) - Returns - ------- - Tuple[xr.DataArray, xr.DataArray] - A Tuple containing the zonal mean for the test variable and the ref - variable. - """ - da_test = spatial_avg(ds_test, var_key, axis=["X"], as_list=False) - da_ref = spatial_avg(ds_ref, var_key, axis=["X"], as_list=False) + diff = test_reg_zonal_mean - ref_reg_zonal_mean - return da_test, da_ref # type: ignore + _save_data_metrics_and_plots( + parameter, + plot_func, + var_key, + test_zonal_mean.to_dataset(), + ref_zonal_mean.to_dataset(), + diff.to_dataset(), + metrics_dict=None, + )