Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
coxipi committed Jan 14, 2025
1 parent 8eb2b51 commit 2ef071c
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/xclim/indices/_agro.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ def standardized_precipitation_index(
i.e. a monthly resampling, the window is an integer number of months.
dist : {"gamma", "fisk"}
Name of the univariate distribution. (see :py:mod:`scipy.stats`).
method : {"APP", "ML"}
method : {"APP", "ML", "PWM"}
Name of the fitting method, such as `ML` (maximum likelihood), `APP` (approximate). The approximate method
uses a deterministic function that does not involve any optimization.
fitkwargs : dict, optional
Expand Down Expand Up @@ -1219,7 +1219,7 @@ def standardized_precipitation_index(
>>> spi_3_fitted = standardized_precipitation_index(pr, params=params)
"""
fitkwargs = fitkwargs or {}
dist_methods = {"gamma": ["ML", "APP"], "fisk": ["ML", "APP"]}
dist_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP"]}
if dist in dist_methods:
if method not in dist_methods[dist]:
raise NotImplementedError(
Expand Down Expand Up @@ -1313,7 +1313,7 @@ def standardized_precipitation_evapotranspiration_index(
"""
fitkwargs = fitkwargs or {}

dist_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP", "PWM"]}
dist_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP"]}
if dist in dist_methods:
if method not in dist_methods[dist]:
raise NotImplementedError(
Expand Down
2 changes: 1 addition & 1 deletion src/xclim/indices/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def standardized_index_fit_params(
"Pass a value for `floc` in `fitkwargs`."
)

dist_and_methods = {"gamma": ["ML", "APP"], "fisk": ["ML", "APP"]}
dist_and_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP"]}
dist = get_dist(dist)
if dist.name not in dist_and_methods:
raise NotImplementedError(f"The distribution `{dist.name}` is not supported.")
Expand Down
75 changes: 75 additions & 0 deletions tests/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,38 @@ class TestStandardizedIndices:
[-1.08627775, -0.46491398, -0.77806462, 0.31759127, 0.03794528],
2e-2,
),
(
"D",
1,
"gamma",
"PWM",
[-0.13002, 1.346689, 0.965731, 0.245408, -0.427896],
2e-2,
),
(
"D",
12,
"gamma",
"PWM",
[-0.209411, -0.086357, 0.636851, 1.022608, 0.634409],
2e-2,
),
(
"MS",
1,
"gamma",
"PWM",
[1.364243, 1.478565, 1.915559, -3.055828, 0.905304],
2e-2,
),
(
"MS",
12,
"gamma",
"PWM",
[0.577214, 1.522867, 1.634222, 0.967847, 0.689001],
2e-2,
),
],
)
def test_standardized_precipitation_index(
Expand All @@ -671,6 +703,13 @@ def test_standardized_precipitation_index(
and Version(__numpy_version__) < Version("2.0.0")
):
pytest.skip("Skipping SPI/ML/D on older numpy")

# change `dist` to a lmoments3 object if needed
if method == "PWM":
lmom = pytest.importorskip("lmoments3.distr")
scipy2lmom = {"gamma": "gam"}
dist = getattr(lmom, scipy2lmom[dist])

ds = open_dataset("sdba/CanESM2_1950-2100.nc").isel(location=1)
if freq == "D":
# to compare with ``climate_indices``
Expand Down Expand Up @@ -922,6 +961,42 @@ def test_zero_inflated(self, open_dataset):
np.all(np.not_equal(spid[False].values, spid[True].values)), True
)

def test_PWM_and_fitkwargs(self, open_dataset):
pr = (
open_dataset("sdba/CanESM2_1950-2100.nc")
.isel(location=1)
.sel(time=slice("1950", "1980"))
).pr

lmom = pytest.importorskip("lmoments3.distr")
# for now, only one function used
scipy2lmom = {"gamma": "gam"}
dist = getattr(lmom, scipy2lmom["gamma"])
fitkwargs = {"floc": 0}
input_params = dict(
freq=None,
window=1,
method="PWM",
dist=dist,
fitkwargs=fitkwargs,
# doy_bounds=(180, 180),
)
# this should not cause a problem
params_d0 = xci.stats.standardized_index_fit_params(pr, **input_params).isel(
dayofyear=0
)
np.testing.assert_allclose(
params_d0, np.array([5.63e-01, 0, 3.37e-05]), rtol=0, atol=2e-2
)
# this should cause a problem
fitkwargs["fscale"] = 1
input_params["fitkwargs"] = fitkwargs
with pytest.raises(
ValueError,
match="Lmoments3 does not use `fitkwargs` arguments, except for `floc` with the Gamma distribution.",
):
xci.stats.standardized_index_fit_params(pr, **input_params)


class TestDailyFreezeThawCycles:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 2ef071c

Please sign in to comment.