From 2ef071c03d07ffc4670d427326a816c6340e4638 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Tue, 14 Jan 2025 15:58:21 -0500 Subject: [PATCH] add tests --- src/xclim/indices/_agro.py | 6 +-- src/xclim/indices/stats.py | 2 +- tests/test_indices.py | 75 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/src/xclim/indices/_agro.py b/src/xclim/indices/_agro.py index baa08662a..fd9355959 100644 --- a/src/xclim/indices/_agro.py +++ b/src/xclim/indices/_agro.py @@ -1145,7 +1145,7 @@ def standardized_precipitation_index( i.e. a monthly resampling, the window is an integer number of months. dist : {"gamma", "fisk"} Name of the univariate distribution. (see :py:mod:`scipy.stats`). - method : {"APP", "ML"} + method : {"APP", "ML", "PWM"} Name of the fitting method, such as `ML` (maximum likelihood), `APP` (approximate). The approximate method uses a deterministic function that does not involve any optimization. fitkwargs : dict, optional @@ -1219,7 +1219,7 @@ def standardized_precipitation_index( >>> spi_3_fitted = standardized_precipitation_index(pr, params=params) """ fitkwargs = fitkwargs or {} - dist_methods = {"gamma": ["ML", "APP"], "fisk": ["ML", "APP"]} + dist_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP"]} if dist in dist_methods: if method not in dist_methods[dist]: raise NotImplementedError( @@ -1313,7 +1313,7 @@ def standardized_precipitation_evapotranspiration_index( """ fitkwargs = fitkwargs or {} - dist_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP", "PWM"]} + dist_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP"]} if dist in dist_methods: if method not in dist_methods[dist]: raise NotImplementedError( diff --git a/src/xclim/indices/stats.py b/src/xclim/indices/stats.py index b021857e7..477f8d7cf 100644 --- a/src/xclim/indices/stats.py +++ b/src/xclim/indices/stats.py @@ -808,7 +808,7 @@ def standardized_index_fit_params( "Pass a value for `floc` in `fitkwargs`." ) - dist_and_methods = {"gamma": ["ML", "APP"], "fisk": ["ML", "APP"]} + dist_and_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP"]} dist = get_dist(dist) if dist.name not in dist_and_methods: raise NotImplementedError(f"The distribution `{dist.name}` is not supported.") diff --git a/tests/test_indices.py b/tests/test_indices.py index 22978b49d..c8444fec4 100644 --- a/tests/test_indices.py +++ b/tests/test_indices.py @@ -660,6 +660,38 @@ class TestStandardizedIndices: [-1.08627775, -0.46491398, -0.77806462, 0.31759127, 0.03794528], 2e-2, ), + ( + "D", + 1, + "gamma", + "PWM", + [-0.13002, 1.346689, 0.965731, 0.245408, -0.427896], + 2e-2, + ), + ( + "D", + 12, + "gamma", + "PWM", + [-0.209411, -0.086357, 0.636851, 1.022608, 0.634409], + 2e-2, + ), + ( + "MS", + 1, + "gamma", + "PWM", + [1.364243, 1.478565, 1.915559, -3.055828, 0.905304], + 2e-2, + ), + ( + "MS", + 12, + "gamma", + "PWM", + [0.577214, 1.522867, 1.634222, 0.967847, 0.689001], + 2e-2, + ), ], ) def test_standardized_precipitation_index( @@ -671,6 +703,13 @@ def test_standardized_precipitation_index( and Version(__numpy_version__) < Version("2.0.0") ): pytest.skip("Skipping SPI/ML/D on older numpy") + + # change `dist` to a lmoments3 object if needed + if method == "PWM": + lmom = pytest.importorskip("lmoments3.distr") + scipy2lmom = {"gamma": "gam"} + dist = getattr(lmom, scipy2lmom[dist]) + ds = open_dataset("sdba/CanESM2_1950-2100.nc").isel(location=1) if freq == "D": # to compare with ``climate_indices`` @@ -922,6 +961,42 @@ def test_zero_inflated(self, open_dataset): np.all(np.not_equal(spid[False].values, spid[True].values)), True ) + def test_PWM_and_fitkwargs(self, open_dataset): + pr = ( + open_dataset("sdba/CanESM2_1950-2100.nc") + .isel(location=1) + .sel(time=slice("1950", "1980")) + ).pr + + lmom = pytest.importorskip("lmoments3.distr") + # for now, only one function used + scipy2lmom = {"gamma": "gam"} + dist = getattr(lmom, scipy2lmom["gamma"]) + fitkwargs = {"floc": 0} + input_params = dict( + freq=None, + window=1, + method="PWM", + dist=dist, + fitkwargs=fitkwargs, + # doy_bounds=(180, 180), + ) + # this should not cause a problem + params_d0 = xci.stats.standardized_index_fit_params(pr, **input_params).isel( + dayofyear=0 + ) + np.testing.assert_allclose( + params_d0, np.array([5.63e-01, 0, 3.37e-05]), rtol=0, atol=2e-2 + ) + # this should cause a problem + fitkwargs["fscale"] = 1 + input_params["fitkwargs"] = fitkwargs + with pytest.raises( + ValueError, + match="Lmoments3 does not use `fitkwargs` arguments, except for `floc` with the Gamma distribution.", + ): + xci.stats.standardized_index_fit_params(pr, **input_params) + class TestDailyFreezeThawCycles: @pytest.mark.parametrize(