Skip to content

Commit

Permalink
Add initial cosp_histogram refactor updates
Browse files Browse the repository at this point in the history
- Refactor most of `cosp_histogram_driver.py`
- Move `cosp_histogram_plot.py` up a directory level
- Add `template_cdat_regression_script.py`
  • Loading branch information
tomvothecoder committed Dec 5, 2023
1 parent 340d3b3 commit 295c027
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 100 deletions.
46 changes: 46 additions & 0 deletions auxiliary_tools/template_cdat_regression_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# %%
import os

from e3sm_diags.parameter.core_parameter import CoreParameter
from e3sm_diags.run import runner

param = CoreParameter()

# %%
param.sets = ["cosp_histogram"]
param.case_id = "MISR-COSP"
param.variables = ["COSP_HISTOGRAM_MISR"]
param.seasons = ["ANN"]
param.contour_levels = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5]
param.diff_levels = [
-3.0,
-2.5,
-2.0,
-1.5,
-1.0,
-0.5,
0,
0.5,
1.0,
1.5,
2.0,
2.5,
3.0,
]

param.test_name = "system tests"
param.short_test_name = "short_system tests"
param.ref_name = "MISRCOSP"
param.reference_name = "MISR COSP (2000-2009)"
param.reference_data_path = "tests/integration"
param.ref_file = "CLDMISR_ERA-Interim_ANN_198001_201401_climo.nc"
param.test_data_path = "tests/integration"
param.test_file = "CLD_MISR_20161118.beta0.FC5COSP.ne30_ne30.edison_ANN_climo.nc"

param.backend = "mpl"
prefix = "/global/cfs/cdirs/e3sm/www/vo13/examples"
param.results_dir = os.path.join(prefix, "cdat_regression_tests/", param.sets[0])
param.multiprocessing = False

# %%
runner.run_diags([param])
157 changes: 58 additions & 99 deletions e3sm_diags/driver/cosp_histogram_driver.py
Original file line number Diff line number Diff line change
@@ -1,130 +1,89 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

import cdms2

import e3sm_diags
from e3sm_diags.driver import utils
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots
from e3sm_diags.driver.utils.regrid import _subset_on_region
from e3sm_diags.logger import custom_logger
from e3sm_diags.metrics import corr, max_cdms, mean, min_cdms, rmse
from e3sm_diags.plot import plot
from e3sm_diags.metrics.metrics import spatial_avg
from e3sm_diags.plot.cosp_histogram_plot import plot as plot_func

if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter

logger = custom_logger(__name__)


def create_metrics(ref, test, ref_regrid, test_regrid, diff):
"""Creates the mean, max, min, rmse, corr in a dictionary"""
metrics_dict = {}
metrics_dict["ref"] = {
"min": min_cdms(ref),
"max": max_cdms(ref),
"mean": mean(ref),
}
metrics_dict["test"] = {
"min": min_cdms(test),
"max": max_cdms(test),
"mean": mean(test),
}

metrics_dict["diff"] = {
"min": min_cdms(diff),
"max": max_cdms(diff),
"mean": mean(diff),
}
metrics_dict["misc"] = {
"rmse": rmse(test_regrid, ref_regrid),
"corr": corr(test_regrid, ref_regrid),
}

return metrics_dict
def run_diag(parameter: CoreParameter) -> CoreParameter:
"""Get metrics for the cosp_histogram diagnostic set.
This funciton loops over each variable, season, pressure level, and region.
It subsets the test and reference variables on the selected region, then
calculates the spatial average for both variables. The difference between
the test and reference spatial averages is calculated. Afterwards, the
spatial averages for the test, ref, and differences are plotted.
def run_diag(parameter: CoreParameter) -> CoreParameter:
Parameters
----------
parameter : CoreParameter
The parameter for the diagnostic.
Returns
-------
CoreParameter
The parameter for the diagnostic with the result (completed or failed).
"""
variables = parameter.variables
seasons = parameter.seasons
ref_name = getattr(parameter, "ref_name", "")
regions = parameter.regions

test_data = utils.dataset.Dataset(parameter, test=True)
ref_data = utils.dataset.Dataset(parameter, ref=True)

for season in seasons:
# Get the name of the data, appended with the years averaged.
parameter.test_name_yrs = utils.general.get_name_and_yrs(
parameter, test_data, season
)
parameter.ref_name_yrs = utils.general.get_name_and_yrs(
parameter, ref_data, season
)

# Get land/ocean fraction for masking.
try:
land_frac = test_data.get_climo_variable("LANDFRAC", season)
ocean_frac = test_data.get_climo_variable("OCNFRAC", season)
except Exception:
mask_path = os.path.join(
e3sm_diags.INSTALL_PATH, "acme_ne30_ocean_land_mask.nc"
)
with cdms2.open(mask_path) as f:
land_frac = f("LANDFRAC")
ocean_frac = f("OCNFRAC")

for var in variables:
logger.info("Variable: {}".format(var))
parameter.var_id = var

mv1 = test_data.get_climo_variable(var, season)
mv2 = ref_data.get_climo_variable(var, season)

parameter.viewer_descr[var] = (
mv1.long_name
if hasattr(mv1, "long_name")
else "No long_name attr in test data."
)
test_ds = Dataset(parameter, data_type="test")
ref_ds = Dataset(parameter, data_type="ref")

for region in regions:
logger.info("Selected region: {}".format(region))
for var_key in variables:
logger.info("Variable: {}".format(var_key))
parameter.var_id = var_key

mv1_domain = utils.general.select_region(
region, mv1, land_frac, ocean_frac, parameter
)
mv2_domain = utils.general.select_region(
region, mv2, land_frac, ocean_frac, parameter
)
for season in seasons:
parameter._set_name_yrs_attrs(test_ds, ref_ds, season)

parameter.output_file = "-".join([ref_name, var, season, region])
parameter.main_title = str(" ".join([var, season, region]))
ds_test = test_ds.get_climo_dataset(var_key, season)
ds_ref = ref_ds.get_ref_climo_dataset(var_key, season, ds_test)

mv1_domain_mean = mean(mv1_domain)
mv2_domain_mean = mean(mv2_domain)
diff = mv1_domain_mean - mv2_domain_mean
for region in regions:
logger.info("Selected region: {}".format(region))

mv1_domain_mean.id = var
mv2_domain_mean.id = var
diff.id = var
ds_test_region = _subset_on_region(ds_test, var_key, region)
ds_ref_region = _subset_on_region(ds_ref, var_key, region)

parameter.backend = (
"mpl" # For now, there's no vcs support for this set.
parameter._set_param_output_attrs(
var_key, season, region, ref_name, ilev=None
)
plot(
parameter.current_set,
mv2_domain_mean,
mv1_domain_mean,
diff,
{},
parameter,

# Make a copy of the regional datasets to overwrite the existing
# variable with its spatial average.
ds_test_region_avg = ds_test_region.copy()
ds_ref_region_avg = ds_ref_region.copy()
ds_test_region_avg[var_key] = spatial_avg(
ds_test_region, var_key, as_list=False
)
ds_ref_region_avg[var_key] = spatial_avg(
ds_ref_region, var_key, as_list=False
)
utils.general.save_ncfiles(
parameter.current_set,
mv1_domain,
mv2_domain,
diff,
ds_diff_region_avg = ds_test_region_avg - ds_ref_region_avg

# TODO: Need to update this function to use cosp_histogram_plot.py
_save_data_metrics_and_plots(
parameter,
plot_func,
var_key,
ds_test_region_avg,
ds_ref_region_avg,
ds_diff_region_avg,
metrics_dict=None,
)

return parameter
1 change: 1 addition & 0 deletions e3sm_diags/driver/lat_lon_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

logger = custom_logger(__name__)


if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter

Expand Down
1 change: 0 additions & 1 deletion e3sm_diags/parameter/core_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

logger = custom_logger(__name__)


if TYPE_CHECKING:
from e3sm_diags.driver.utils.dataset_xr import Dataset

Expand Down
Loading

0 comments on commit 295c027

Please sign in to comment.