diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8303113600b..670ee021107 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix bug with rechunking to a frequency when some periods contain no data (:issue:`9360`). + By `Deepak Cherian <https://github.com/dcherian>`_. - Fix bug causing `DataTree.from_dict` to be sensitive to insertion order (:issue:`9276`, :pull:`9292`). By `Tom Nicholas <https://github.com/TomNicholas>`_. - Fix resampling error with monthly, quarterly, or yearly frequencies with diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 62183adcf71..628a2efa61e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2752,7 +2752,7 @@ def _resolve_frequency( ) assert variable.ndim == 1 - chunks: tuple[int, ...] = tuple( + chunks = ( DataArray( np.ones(variable.shape, dtype=int), dims=(name,), @@ -2760,9 +2760,13 @@ def _resolve_frequency( ) .resample({name: resampler}) .sum() - .data.tolist() ) - return chunks + # When bins (binning) or time periods are missing (resampling) + # we can end up with NaNs. Drop them. + if chunks.dtype.kind == "f": + chunks = chunks.dropna(name).astype(int) + chunks_tuple: tuple[int, ...] 
= tuple(chunks.data.tolist()) + return chunks_tuple chunks_mapping_ints: Mapping[Any, T_ChunkDim] = { name: ( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f859ff38ccb..a43e3d15e32 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1209,24 +1209,34 @@ def get_dask_names(ds): ), ) @pytest.mark.parametrize("freq", ["D", "W", "5ME", "YE"]) - def test_chunk_by_frequency(self, freq, calendar) -> None: + @pytest.mark.parametrize("add_gap", [True, False]) + def test_chunk_by_frequency(self, freq: str, calendar: str, add_gap: bool) -> None: import dask.array N = 365 * 2 + ΔN = 28 + time = xr.date_range( + "2001-01-01", periods=N + ΔN, freq="D", calendar=calendar + ).to_numpy() + if add_gap: + # introduce an empty bin + time[31 : 31 + ΔN] = np.datetime64("NaT") + time = time[~np.isnat(time)] + else: + time = time[:N] + ds = Dataset( { "pr": ("time", dask.array.random.random((N), chunks=(20))), "pr2d": (("x", "time"), dask.array.random.random((10, N), chunks=(20))), "ones": ("time", np.ones((N,))), }, - coords={ - "time": xr.date_range( - "2001-01-01", periods=N, freq="D", calendar=calendar - ) - }, + coords={"time": time}, ) rechunked = ds.chunk(x=2, time=TimeResampler(freq)) - expected = tuple(ds.ones.resample(time=freq).sum().data.tolist()) + expected = tuple( + ds.ones.resample(time=freq).sum().dropna("time").astype(int).data.tolist() + ) assert rechunked.chunksizes["time"] == expected assert rechunked.chunksizes["x"] == (2,) * 5