diff --git a/e3sm_diags/driver/lat_lon_driver.py b/e3sm_diags/driver/lat_lon_driver.py index f821c7a721..bea50fb944 100755 --- a/e3sm_diags/driver/lat_lon_driver.py +++ b/e3sm_diags/driver/lat_lon_driver.py @@ -1,13 +1,11 @@ from __future__ import annotations -import json -import os -from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple import xarray as xr from e3sm_diags.driver.utils.dataset_xr import Dataset -from e3sm_diags.driver.utils.io import _get_output_dir, _write_vars_to_netcdf +from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots from e3sm_diags.driver.utils.regrid import ( _apply_land_sea_mask, _subset_on_region, @@ -16,20 +14,13 @@ has_z_axis, regrid_z_axis_to_plevs, ) +from e3sm_diags.driver.utils.type_annotations import MetricsDict from e3sm_diags.logger import custom_logger from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg, std -from e3sm_diags.plot.lat_lon_plot import plot +from e3sm_diags.plot.lat_lon_plot import plot as plot_func logger = custom_logger(__name__) -# The type annotation for the metrics dictionary. The key is the -# type of metrics and the value is a sub-dictionary of metrics (key is metrics -# type and value is float). There is also a "unit" key representing the -# units for the variable. -UnitAttr = str -MetricsSubDict = Dict[str, Union[float, None, List[float]]] -MetricsDict = Dict[str, Union[UnitAttr, MetricsSubDict]] - if TYPE_CHECKING: from e3sm_diags.parameter.core_parameter import CoreParameter @@ -72,8 +63,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter.var_id = var_key for season in seasons: - parameter.test_name_yrs = test_ds.get_name_yrs_attr(season) - parameter.ref_name_yrs = ref_ds.get_name_yrs_attr(season) + parameter._set_name_yrs_attrs(test_ds, ref_ds, season) # The land sea mask dataset that is used for masking if the region # is either land or sea. This variable is instantiated here to get @@ -81,22 +71,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask(season) ds_test = test_ds.get_climo_dataset(var_key, season) - - # If the reference climatology dataset cannot be retrieved - # it will be set the to the test climatology dataset which means - # analysis is only performed on the test dataset. - # TODO: This logic was carried over from legacy implementation. It - # can probably be improved on by setting `ds_ref = None` and not - # performing unnecessary operations on `ds_ref` for model-only runs, - # since it is the same as `ds_test``. - try: - ds_ref = ref_ds.get_climo_dataset(var_key, season) - parameter.model_only = False - except (RuntimeError, IOError): - ds_ref = ds_test - parameter.model_only = True - - logger.info("Cannot process reference data, analyzing test data only.") + ds_ref = ref_ds.get_ref_climo_dataset(var_key, season, ds_test) # Store the variable's DataArray objects for reuse. dv_test = ds_test[var_key] @@ -173,9 +148,8 @@ def _run_diags_2d( The reference name. """ for region in regions: - parameter = _set_param_output_attrs( - parameter, var_key, season, region, ref_name, ilev=None - ) + parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None) + ( metrics_dict, ds_test_region, @@ -191,11 +165,12 @@ def _run_diags_2d( ) _save_data_metrics_and_plots( parameter, + plot_func, var_key, - metrics_dict, ds_test_region, ds_ref_region, ds_diff_region, + metrics_dict, ) @@ -261,64 +236,18 @@ def _run_diags_3d( region, ) - parameter = _set_param_output_attrs( - parameter, var_key, season, region, ref_name, ilev - ) + parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev) _save_data_metrics_and_plots( parameter, + plot_func, var_key, - metrics_dict, ds_test_region, ds_ref_region, ds_diff_region, + metrics_dict, ) -def _set_param_output_attrs( - parameter: CoreParameter, - var_key: str, - season: str, - region: str, - ref_name: str, - ilev: float | None, -) -> CoreParameter: - """Set the parameter output attributes based on argument values. - - Parameters - ---------- - parameter : CoreParameter - The parameter. - var_key : str - The variable key. - season : str - The season. - region : str - The region. - ref_name : str - The reference name, - ilev : float | None - The pressure level, by default None. This option is only set if the - variable is 3D. - - Returns - ------- - CoreParameter - The parameter with updated output attributes. - """ - if ilev is None: - output_file = f"{ref_name}-{var_key}-{season}-{region}" - main_title = f"{var_key} {season} {region}" - else: - ilev_str = str(int(ilev)) - output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}" - main_title = f"{var_key} {ilev_str} 'mb' {season} {region}" - - parameter.output_file = output_file - parameter.main_title = main_title - - return parameter - - def _get_metrics_by_region( parameter: CoreParameter, ds_test: xr.Dataset, @@ -516,62 +445,3 @@ def _create_metrics_dict( } return metrics_dict - - -def _save_data_metrics_and_plots( - parameter: CoreParameter, - var_key: str, - metrics_dict: MetricsDict, - ds_test: xr.Dataset, - ds_ref: xr.Dataset | None, - ds_diff: xr.Dataset | None, -): - """Save data (optional), metrics, and plots. - - Parameters - ---------- - parameter : CoreParameter - The parameter for the diagnostic. - var_key : str - The variable key. - metrics_dict : Metrics - The dictionary containing metrics for the variable. - ds_test : xr.Dataset - The test dataset. - ds_ref : xr.Dataset | None - The optional reference dataset. If the diagnostic is a model-only run, - then it will be None. - ds_diff : xr.Dataset | None - The optional difference dataset. If the diagnostic is a model-only run, - then it will be None. - """ - if parameter.save_netcdf: - _write_vars_to_netcdf( - parameter, - var_key, - ds_test, - ds_ref, - ds_diff, - ) - - output_dir = _get_output_dir(parameter) - filename = f"{parameter.output_file}.json" - filepath = os.path.join(output_dir, filename) - - with open(filepath, "w") as outfile: - json.dump(metrics_dict, outfile) - - logger.info(f"Metrics saved in {filepath}") - - # Set the viewer description to the "long_name" attr of the variable. - parameter.viewer_descr[var_key] = ds_test[var_key].attrs.get( - "long_name", "No long_name attr in test data" - ) - - plot( - ds_test[var_key], - ds_ref[var_key] if ds_ref is not None else None, - ds_diff[var_key] if ds_diff is not None else None, - metrics_dict, - parameter, - ) diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py index 5e27b2c139..251e56c493 100644 --- a/e3sm_diags/driver/utils/dataset_xr.py +++ b/e3sm_diags/driver/utils/dataset_xr.py @@ -16,7 +16,7 @@ import glob import os import re -from typing import Callable, Dict, Literal, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Literal, Tuple import xarray as xr import xcdat as xc @@ -29,7 +29,10 @@ from e3sm_diags.driver import LAND_FRAC_KEY, LAND_OCEAN_MASK_PATH, OCEAN_FRAC_KEY from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ, CLIMO_FREQS, climo from e3sm_diags.logger import custom_logger -from e3sm_diags.parameter.core_parameter import CoreParameter + +if TYPE_CHECKING: + from e3sm_diags.parameter.core_parameter import CoreParameter + logger = custom_logger(__name__) @@ -256,6 +259,58 @@ def _get_global_attr_from_climo_dataset( # -------------------------------------------------------------------------- # Climatology related methods # -------------------------------------------------------------------------- + def get_ref_climo_dataset( + self, var_key: str, season: CLIMO_FREQ, ds_test: xr.Dataset + ): + """Get the reference climatology dataset for the variable and season. + + If the reference climatatology does not exist or could not be found, it + will be considered a model-only run. For this case the test dataset + is returned as a default value and subsequent metrics calculations will + only be performed on the original test dataset. + + Parameters + ---------- + var_key : str + The key of the variable. + season : CLIMO_FREQ + The climatology frequency. + ds_test : xr.Dataset + The test dataset, which is returned if the reference climatology + does not exist or could not be found. + + Returns + ------- + xr.Dataset + The reference climatology if it exists or a copy of the test dataset + if it does not exist. + + Raises + ------ + RuntimeError + If `self.data_type` is not "ref". + """ + # TODO: This logic was carried over from legacy implementation. It + # can probably be improved on by setting `ds_ref = None` and not + # performing unnecessary operations on `ds_ref` for model-only runs, + # since it is the same as `ds_test``. + if self.data_type == "ref": + try: + ds_ref = self.get_climo_dataset(var_key, season) + self.model_only = False + except (RuntimeError, IOError): + ds_ref = ds_test.copy() + self.model_only = True + + logger.info("Cannot process reference data, analyzing test data only.") + else: + raise RuntimeError( + "`Dataset._get_ref_dataset` only works with " + f"`self.data_type == 'ref'`, not {self.data_type}." + ) + + return ds_ref + def get_climo_dataset(self, var: str, season: CLIMO_FREQ) -> xr.Dataset: """Get the dataset containing the climatology variable. diff --git a/e3sm_diags/driver/utils/io.py b/e3sm_diags/driver/utils/io.py index 72b3309773..09e4794da4 100644 --- a/e3sm_diags/driver/utils/io.py +++ b/e3sm_diags/driver/utils/io.py @@ -1,16 +1,82 @@ from __future__ import annotations import errno +import json import os +from typing import Callable import xarray as xr +from e3sm_diags.driver.utils.type_annotations import MetricsDict from e3sm_diags.logger import custom_logger from e3sm_diags.parameter.core_parameter import CoreParameter logger = custom_logger(__name__) +def _save_data_metrics_and_plots( + parameter: CoreParameter, + plot_func: Callable, + var_key: str, + ds_test: xr.Dataset, + ds_ref: xr.Dataset | None, + ds_diff: xr.Dataset | None, + metrics_dict: MetricsDict | None, +): + """Save data (optional), metrics, and plots. + + Parameters + ---------- + parameter : CoreParameter + The parameter for the diagnostic. + plot_func: Callable + The plot function for the diagnostic set. + var_key : str + The variable key. + ds_test : xr.Dataset + The test dataset. + ds_ref : xr.Dataset | None + The optional reference dataset. If the diagnostic is a model-only run, + then it will be None. + ds_diff : xr.Dataset | None + The optional difference dataset. If the diagnostic is a model-only run, + then it will be None. + metrics_dict : Metrics + The dictionary containing metrics for the variable. + """ + if parameter.save_netcdf: + _write_vars_to_netcdf( + parameter, + var_key, + ds_test, + ds_ref, + ds_diff, + ) + + output_dir = _get_output_dir(parameter) + filename = f"{parameter.output_file}.json" + filepath = os.path.join(output_dir, filename) + + if metrics_dict is not None: + with open(filepath, "w") as outfile: + json.dump(metrics_dict, outfile) + + logger.info(f"Metrics saved in {filepath}") + + # Set the viewer description to the "long_name" attr of the variable. + parameter.viewer_descr[var_key] = ds_test[var_key].attrs.get( + "long_name", "No long_name attr in test data" + ) + + plot_func( + parameter, + ds_test[var_key], + ds_ref[var_key] if ds_ref is not None else None, + ds_diff[var_key] if ds_diff is not None else None, + metrics_dict, + ) + + def _write_vars_to_netcdf( parameter: CoreParameter, var_key, diff --git a/e3sm_diags/driver/utils/type_annotations.py b/e3sm_diags/driver/utils/type_annotations.py new file mode 100644 index 0000000000..4132e65147 --- /dev/null +++ b/e3sm_diags/driver/utils/type_annotations.py @@ -0,0 +1,9 @@ +from typing import Dict, List, Union + +# The type annotation for the metrics dictionary. The key is the +# type of metrics and the value is a sub-dictionary of metrics (key is metrics +# type and value is float). There is also a "unit" key representing the +# units for the variable. +UnitAttr = str +MetricsSubDict = Dict[str, Union[float, None, List[float]]] +MetricsDict = Dict[str, Union[UnitAttr, MetricsSubDict]] diff --git a/e3sm_diags/metrics/metrics.py b/e3sm_diags/metrics/metrics.py index 27f140a467..68e5a4fc13 100644 --- a/e3sm_diags/metrics/metrics.py +++ b/e3sm_diags/metrics/metrics.py @@ -1,4 +1,6 @@ """This module stores functions to calculate metrics using Xarray objects.""" +from __future__ import annotations + from typing import List import xarray as xr @@ -28,7 +30,9 @@ def get_weights(ds: xr.Dataset): return ds.spatial.get_weights(axis=["X", "Y"]) -def spatial_avg(ds: xr.Dataset, var_key: str) -> List[float]: +def spatial_avg( + ds: xr.Dataset, var_key: str, as_list: bool = True +) -> List[float] | xr.DataArray: """Compute a variable's weighted spatial average. Parameters @@ -37,10 +41,13 @@ def spatial_avg(ds: xr.Dataset, var_key: str) -> List[float]: The dataset containing the variable. var_key : str The key of the varible. + as_list : bool + Return the spatial average as a list of floats, by default True. + If False, return an xr.DataArray. Returns ------- - List[float] + List[float] | xr.DataArray The spatial average of the variable based on the specified axis. Raises @@ -55,9 +62,10 @@ def spatial_avg(ds: xr.Dataset, var_key: str) -> List[float]: ds_avg = ds.spatial.average(var_key, axis=AXES, weights="generate") results = ds_avg[var_key] - results_list = results.data.tolist() + if as_list: + return results.data.tolist() - return results_list + return results def std(ds: xr.Dataset, var_key: str) -> List[float]: diff --git a/e3sm_diags/parameter/core_parameter.py b/e3sm_diags/parameter/core_parameter.py index b76ed97c8f..4e97e0de45 100644 --- a/e3sm_diags/parameter/core_parameter.py +++ b/e3sm_diags/parameter/core_parameter.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import copy import importlib import sys -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List from e3sm_diags.derivations.derivations import DerivedVariablesMap from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ @@ -11,6 +13,10 @@ logger = custom_logger(__name__) +if TYPE_CHECKING: + from e3sm_diags.driver.utils.dataset_xr import Dataset + + class CoreParameter: def __init__(self): # File I/O @@ -120,7 +126,7 @@ def __init__(self): self.dpi: int = 150 self.arrows: bool = True self.logo: bool = False - self.contour_levels: List[str] = [] + self.contour_levels: List[float] = [] # Test plot settings self.test_name: str = "" @@ -155,7 +161,7 @@ def __init__(self): self.diff_name: str = "" self.diff_title: str = "Model - Observation" self.diff_colormap: str = "diverging_bwr.rgb" - self.diff_levels: List[str] = [] + self.diff_levels: List[float] = [] self.diff_units: str = "" self.diff_type: str = "absolute" @@ -227,6 +233,58 @@ def check_values(self): msg = "You need to define both the 'test_start_yr' and 'test_end_yr' parameter." raise RuntimeError(msg) + def _set_param_output_attrs( + self, + var_key: str, + season: str, + region: str, + ref_name: str, + ilev: float | None, + ): + """Set the parameter output attributes based on argument values. + + Parameters + ---------- + var_key : str + The variable key. + season : str + The season. + region : str + The region. + ref_name : str + The reference name. + ilev : float | None + The pressure level, by default None. This option is only set if the + variable is 3D. + """ + if ilev is None: + output_file = f"{ref_name}-{var_key}-{season}-{region}" + main_title = f"{var_key} {season} {region}" + else: + ilev_str = str(int(ilev)) + output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}" + main_title = f"{var_key} {ilev_str} 'mb' {season} {region}" + + self.output_file = output_file + self.main_title = main_title + + def _set_name_yrs_attrs( + self, ds_test: Dataset, ds_ref: Dataset, season: CLIMO_FREQ + ): + """Set the test_name_yrs and ref_name_yrs attributes. + + Parameters + ---------- + ds_test : Dataset + The test dataset object used for setting ``self.test_name_yrs``. + ds_ref : Dataset + The ref dataset object used for setting ``self.ref_name_yrs``. + season : CLIMO_FREQ + The climatology frequency. + """ + self.test_name_yrs = ds_test.get_name_yrs_attr(season) + self.ref_name_yrs = ds_ref.get_name_yrs_attr(season) + def _run_diag(self) -> List[Any]: """Run the diagnostics for each set in the parameter. diff --git a/e3sm_diags/plot/lat_lon_plot.py b/e3sm_diags/plot/lat_lon_plot.py index 7c40232e17..a6b8a0da02 100644 --- a/e3sm_diags/plot/lat_lon_plot.py +++ b/e3sm_diags/plot/lat_lon_plot.py @@ -20,16 +20,18 @@ def plot( + parameter: CoreParameter, da_test: xr.DataArray, da_ref: xr.DataArray | None, da_diff: xr.DataArray | None, metrics_dict: MetricsDict, - parameter: CoreParameter, ): """Plot the variable's metrics generated for the lat_lon set. Parameters ---------- + parameter : CoreParameter + The CoreParameter object containing plot configurations. da_test : xr.DataArray The test data. da_ref : xr.DataArray | None @@ -38,8 +40,6 @@ def plot( The difference between ``ds_test_regrid`` and ``ds_ref_regrid``. metrics_dict : Metrics The metrics. - parameter : CoreParameter - The CoreParameter object containing plot configurations. """ fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi) fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18) diff --git a/e3sm_diags/plot/utils.py b/e3sm_diags/plot/utils.py index b2ab1be2a3..c4057cbbe6 100644 --- a/e3sm_diags/plot/utils.py +++ b/e3sm_diags/plot/utils.py @@ -101,7 +101,7 @@ def _add_colormap( fig: plt.figure, parameter: CoreParameter, color_map: str, - contour_levels: List[str], + contour_levels: List[float], title: Tuple[str | None, str, str], metrics: Tuple[float, ...], ): @@ -123,7 +123,7 @@ def _add_colormap( The CoreParameter object containing plot configurations. color_map : str The colormap styling to use (e.g., "cet_rainbow.rgb"). - contour_levels : List[str] + contour_levels : List[float] The map contour levels. title : Tuple[str | None, str, str] A tuple of strings to form the title of the colormap, in the format @@ -427,14 +427,12 @@ def _determine_tick_step(degrees_covered: float) -> int: return 1 -def _get_contour_label_format_and_pad( - c_levels: List[str] | List[str | float], -) -> Tuple[str, int]: +def _get_contour_label_format_and_pad(c_levels: List[float]) -> Tuple[str, int]: """Get the label format and padding for each contour level. Parameters ---------- - c_levels : List[str] | List[str | float] + c_levels : List[float] The contour levels. Returns diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py index 3ddc6de364..1fdf6de3e3 100644 --- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py +++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py @@ -205,6 +205,181 @@ def test_property_is_timeseries_returns_false_and_is_climo_returns_true_for_ref( assert ds.is_climo +class TestGetReferenceClimoDataset: + @pytest.fixture(autouse=True) + def setup(self, tmp_path): + # Create temporary directory to save files. + self.data_path = tmp_path / "input_data" + self.data_path.mkdir() + + # Set up climatology dataset and save to a temp file. + # TODO: Update this to an actual climatology dataset structure + self.ds_climo = xr.Dataset( + coords={ + "lat": [-90, 90], + "lon": [0, 180], + "time": xr.DataArray( + dims="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 12, 0, 0, 0, has_year_zero=False + ) + ], + dtype="object", + ), + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array( + [ + [[1.0, 1.0], [1.0, 1.0]], + ] + ), + dims=["time", "lat", "lon"], + ) + }, + ) + self.ds_climo.time.encoding = {"units": "days since 2000-01-01"} + + # Set up time series dataset and save to a temp file. + self.ds_ts = xr.Dataset( + coords={ + "lat": [-90, 90], + "lon": [0, 180], + "time": xr.DataArray( + dims="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 1, 1, 12, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + data_vars={ + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2001, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + ], + dtype=object, + ), + dims=["time", "bnds"], + ), + "ts": xr.DataArray( + xr.DataArray( + data=np.array( + [ + [[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]], + ] + ), + dims=["time", "lat", "lon"], + ) + ), + }, + ) + self.ds_ts.time.encoding = {"units": "days since 2000-01-01"} + + def test_raises_error_if_dataset_data_type_is_not_ref(self): + parameter = _create_parameter_object( + "test", "climo", self.data_path, "2000", "2001" + ) + parameter.ref_file = "test.nc" + ds = Dataset(parameter, data_type="test") + + with pytest.raises(RuntimeError): + ds.get_ref_climo_dataset("ts", "ANN", self.ds_climo.copy()) + + def test_returns_reference_climo_dataset_from_file(self): + parameter = _create_parameter_object( + "ref", "climo", self.data_path, "2000", "2001" + ) + parameter.ref_file = "ref_file.nc" + + self.ds_climo.to_netcdf(f"{self.data_path}/{parameter.ref_file}") + + ds = Dataset(parameter, data_type="ref") + result = ds.get_ref_climo_dataset("ts", "ANN", self.ds_climo.copy()) + expected = self.ds_climo.squeeze(dim="time").drop_vars("time") + + assert result.identical(expected) + assert not ds.model_only + + def test_returns_test_dataset_as_default_value_if_climo_dataset_not_found(self): + parameter = _create_parameter_object( + "ref", "climo", self.data_path, "2000", "2001" + ) + parameter.ref_file = "ref_file.nc" + ds = Dataset(parameter, data_type="ref") + + ds_test = self.ds_climo.copy() + result = ds.get_ref_climo_dataset("ts", "ANN", ds_test) + + assert result.identical(ds_test) + assert ds.model_only + + class TestGetClimoDataset: @pytest.fixture(autouse=True) def setup(self, tmp_path): diff --git a/tests/e3sm_diags/metrics/test_metrics.py b/tests/e3sm_diags/metrics/test_metrics.py index bda3d85030..141b5e0d67 100644 --- a/tests/e3sm_diags/metrics/test_metrics.py +++ b/tests/e3sm_diags/metrics/test_metrics.py @@ -75,6 +75,13 @@ def test_returns_spatial_avg_for_x_y(self): np.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) + def test_returns_spatial_avg_for_x_y_as_xr_dataarray(self): + expected = [1.5, 1.333299, 1.5] + result = spatial_avg(self.ds, "ts", as_list=False) + + assert isinstance(result, xr.DataArray) + np.testing.assert_allclose(expected, result, atol=1e-5, rtol=1e-5) + class TestStd: @pytest.fixture(autouse=True)