Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Jan 25, 2024
1 parent 0b7a18d commit 0ae2a46
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 55 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[report]
exclude_also =
if TYPE_CHECKING:
Original file line number Diff line number Diff line change
@@ -1,26 +1,3 @@
"""
The template run script used for generating results on your development branch.
Steps:
1. Activate your conda dev env for your branch
2. `make install` to install the latest version of your branch code into the env
3. Copy this script into `auxiliary_tools/cdat_regression_testing/<ISSUE>-<SET_NAME>`
4. Update `SET_DIR` string variable
5. Update `SET_NAME` string variable.
- Options include: "lat_lon", "zonal_mean_xy", "zonal_mean_2d",
"zonal_mean_2d_stratosphere", "polar", "cosp_histogram",
"meridional_mean_2d", "annual_cycle_zonal_mean", "enso_diags", "qbo",
"area_mean_time_series", "diurnal_cycle", "streamflow", "arm_diags",
"tc_analysis", "aerosol_aeronet", "aerosol_budget", "mp_partition",
6.
6. Run this script
- Make sure to run this command on NERSC perlmutter cpu:
`salloc --nodes 1 --qos interactive --time 01:00:00 --constraint cpu --account=e3sm
conda activate <NAME-OF-DEV-ENV>`
- python auxiliary_tools/cdat_regression_testing/<ISSUE-<SET_NAME>
7. Make a copy of the CDAT regression testing notebook in the same directory
as this script and follow the instructions there to start testing.
"""
from auxiliary_tools.cdat_regression_testing.base_run_script import run_set

SET_NAME = "cosp_histogram"
Expand Down
11 changes: 6 additions & 5 deletions e3sm_diags/driver/utils/dataset_xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ def _open_climo_dataset(self, filepath: str) -> xr.Dataset:

try:
ds = xc.open_dataset(**args)
except ValueError as e:
except ValueError as e: # pragma: no cover
# FIXME: Need to fix the test that covers this code block.
msg = str(e)

if "dimension 'time' already exists as a scalar variable" in msg:
Expand Down Expand Up @@ -837,11 +838,11 @@ def _get_dataset_with_derivation_func(
# extensive refactoring of the structure for derived variables (e.g.,
# the massive derived variables dictionary).
if func in FUNC_REQUIRES_DATASET_AND_TARGET_VAR:
func_args = [ds, target_var_key] + func_args # type: ignore
ds_final = func(*func_args)
func_args = [ds, target_var_key] + func_args # type: ignore # pragma: nocover
ds_final = func(*func_args) # pragma: nocover
elif func in FUNC_REQUIRES_TARGET_VAR:
func_args = [target_var_key] + func_args # type: ignore
ds_final = func(*func_args)
func_args = [target_var_key] + func_args # type: ignore # pragma: nocover
ds_final = func(*func_args) # pragma: nocover
else:
derived_var = func(*func_args)
ds_final = ds.copy()
Expand Down
167 changes: 141 additions & 26 deletions tests/e3sm_diags/driver/utils/test_dataset_xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,44 @@
)
from e3sm_diags.parameter.core_parameter import CoreParameter

# Reusable spatial coords dictionary for composing an xr.Dataest.
spatial_coords = {
"lat": xr.DataArray(
dims="lat",
data=np.array([-90.0, 90]),
attrs={
"axis": "Y",
"long_name": "latitude",
"standard_name": "latitude",
"bounds": "lat_bnds",
},
),
"lon": xr.DataArray(
dims="lon",
data=np.array([0.0, 180]),
attrs={
"axis": "X",
"long_name": "longitude",
"standard_name": "longitude",
"bounds": "lon_bnds",
},
),
}

# Reusable spatial bounds dictionary for composing an xr.Dataest.
spatial_bounds = {
"lat_bnds": xr.DataArray(
name="lat_bnds",
data=[[-90.0, 0.0], [0.0, 90.0]],
dims=["lat", "bnds"],
),
"lon_bnds": xr.DataArray(
name="lat_bnds",
data=[[-90.0, 90.0], [90, 270]],
dims=["lon", "bnds"],
),
}


def _create_parameter_object(
dataset_type: Literal["ref", "test"],
Expand Down Expand Up @@ -216,11 +254,9 @@ def setup(self, tmp_path):
self.data_path.mkdir()

# Set up climatology dataset and save to a temp file.
# TODO: Update this to an actual climatology dataset structure
self.ds_climo = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand All @@ -240,6 +276,7 @@ def setup(self, tmp_path):
),
},
data_vars={
**spatial_bounds,
"ts": xr.DataArray(
name="ts",
data=np.array(
Expand All @@ -248,7 +285,7 @@ def setup(self, tmp_path):
]
),
dims=["time", "lat", "lon"],
)
),
},
)
self.ds_climo.time.encoding = {"units": "days since 2000-01-01"}
Expand Down Expand Up @@ -390,12 +427,46 @@ def setup(self, tmp_path):
self.data_path = tmp_path / "input_data"
self.data_path.mkdir()

self.spatial_coords = {
"lat": xr.DataArray(
dims="lat",
data=np.array([-90.0, 90]),
attrs={
"axis": "Y",
"long_name": "latitude",
"standard_name": "latitude",
"bounds": "lat_bnds",
},
),
"lon": xr.DataArray(
dims="lon",
data=np.array([0.0, 180]),
attrs={
"axis": "X",
"long_name": "longitude",
"standard_name": "longitude",
"bounds": "lon_bnds",
},
),
}
self.spatial_bounds = {
"lat_bnds": xr.DataArray(
name="lat_bnds",
data=[[-90.0, 0.0], [0.0, 90.0]],
dims=["lat", "bnds"],
),
"lon_bnds": xr.DataArray(
name="lat_bnds",
data=[[-90.0, 90.0], [90, 270]],
dims=["lon", "bnds"],
),
}

# Set up climatology dataset and save to a temp file.
# TODO: Update this to an actual climatology dataset structure
self.ds_climo = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand All @@ -415,6 +486,7 @@ def setup(self, tmp_path):
),
},
data_vars={
**spatial_bounds,
"ts": xr.DataArray(
name="ts",
data=np.array(
Expand All @@ -423,16 +495,15 @@ def setup(self, tmp_path):
]
),
dims=["time", "lat", "lon"],
)
),
},
)
self.ds_climo.time.encoding = {"units": "days since 2000-01-01"}

# Set up time series dataset and save to a temp file.
self.ds_ts = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand Down Expand Up @@ -573,9 +644,7 @@ def test_returns_climo_dataset_using_test_file_variable(self):

xr.testing.assert_identical(result, expected)

def test_returns_climo_dataset_using_ref_file_variable_test_name_and_season(
self,
):
def test_returns_climo_dataset_using_ref_file_variable_test_name_and_season(self):
# Example: {test_data_path}/{test_name}_{season}.nc
parameter = _create_parameter_object(
"ref", "climo", self.data_path, "2000", "2001"
Expand All @@ -589,9 +658,7 @@ def test_returns_climo_dataset_using_ref_file_variable_test_name_and_season(

xr.testing.assert_identical(result, expected)

def test_returns_climo_dataset_using_test_file_variable_test_name_and_season(
self,
):
def test_returns_climo_dataset_using_test_file_variable_test_name_and_season(self):
# Example: {test_data_path}/{test_name}_{season}.nc
parameter = _create_parameter_object(
"test", "climo", self.data_path, "2000", "2001"
Expand Down Expand Up @@ -651,8 +718,7 @@ def test_returns_climo_dataset_with_derived_variable(self):
# We will derive the "PRECT" variable using the "pr" variable.
ds_pr = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand All @@ -672,6 +738,7 @@ def test_returns_climo_dataset_with_derived_variable(self):
),
},
data_vars={
**spatial_bounds,
"pr": xr.DataArray(
xr.DataArray(
data=np.array(
Expand Down Expand Up @@ -702,11 +769,56 @@ def test_returns_climo_dataset_with_derived_variable(self):

xr.testing.assert_identical(result, expected)

@pytest.mark.xfail
def test_returns_climo_dataset_using_derived_var_directly_from_dataset_and_replaces_scalar_time_var(
self,
):
# FIXME: This test needs to cover `except` block in `_open_dataset()`.
# The issue is that we can't create a dummy dataset with an incorrect
# time scalar variable using Xarray because it just throws the error
# below. We might need to use another library like netCDF4 to create
# a dummy dataset.
ds_precst = xr.Dataset(
coords={
**spatial_coords,
},
data_vars={
**spatial_bounds,
"time": xr.DataArray(
dims="time",
data=0,
),
"PRECST": xr.DataArray(
xr.DataArray(
data=np.array(
[
[[1.0, 1.0], [1.0, 1.0]],
]
),
dims=["time", "lat", "lon"],
attrs={"units": "mm/s"},
)
),
},
)

parameter = _create_parameter_object(
"ref", "climo", self.data_path, "2000", "2001"
)
parameter.ref_file = "pr_200001_200112.nc"
ds_precst.to_netcdf(f"{self.data_path}/{parameter.ref_file}")

ds = Dataset(parameter, data_type="ref")

result = ds.get_climo_dataset("PRECST", season="ANN")
expected = ds_precst.squeeze(dim="time").drop_vars("time")

xr.testing.assert_identical(result, expected)

def test_returns_climo_dataset_using_derived_var_directly_from_dataset(self):
ds_precst = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand All @@ -726,6 +838,7 @@ def test_returns_climo_dataset_using_derived_var_directly_from_dataset(self):
),
},
data_vars={
**spatial_bounds,
"PRECST": xr.DataArray(
xr.DataArray(
data=np.array(
Expand Down Expand Up @@ -756,8 +869,7 @@ def test_returns_climo_dataset_using_derived_var_directly_from_dataset(self):
def test_returns_climo_dataset_using_source_variable_with_wildcard(self):
ds_precst = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand All @@ -777,6 +889,7 @@ def test_returns_climo_dataset_using_source_variable_with_wildcard(self):
),
},
data_vars={
**spatial_bounds,
"bc_a?DDF": xr.DataArray(
xr.DataArray(
data=np.array(
Expand Down Expand Up @@ -879,8 +992,7 @@ def test_raises_error_if_dataset_has_no_matching_source_variables_to_derive_vari
def test_raises_error_if_no_datasets_found_to_derive_variable(self):
ds_precst = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand All @@ -900,6 +1012,7 @@ def test_raises_error_if_no_datasets_found_to_derive_variable(self):
),
},
data_vars={
**spatial_bounds,
"invalid": xr.DataArray(
xr.DataArray(
data=np.array(
Expand Down Expand Up @@ -1327,10 +1440,10 @@ def setup(self, tmp_path):
self.data_path = tmp_path / "input_data"
self.data_path.mkdir()
# Set up climatology dataset and save to a temp file.

self.ds_climo = xr.Dataset(
coords={
"lat": [-90, 90],
"lon": [0, 180],
**spatial_coords,
"time": xr.DataArray(
dims="time",
data=np.array(
Expand All @@ -1350,6 +1463,7 @@ def setup(self, tmp_path):
),
},
data_vars={
**spatial_bounds,
"ts": xr.DataArray(
name="ts",
data=np.array(
Expand All @@ -1358,7 +1472,7 @@ def setup(self, tmp_path):
]
),
dims=["time", "lat", "lon"],
)
),
},
)
self.ds_climo.time.encoding = {"units": "days since 2000-01-01"}
Expand Down Expand Up @@ -1580,3 +1694,4 @@ def test_returns_test_name_and_years_averaged_as_single_string_with_timeseries_d
expected = "short_test_name (1800-1850)"

assert result == expected
assert result == expected
Loading

0 comments on commit 0ae2a46

Please sign in to comment.