From 5b447d3c9c5175132d503a8b8963400875ca24e5 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Mon, 8 Apr 2024 11:06:13 -0400
Subject: [PATCH] Do not chunk when the input isn't chunked (#348)

* Refine dask testing - warn when number of chunks increases too much

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change default when input is not chunked

* update documentation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update other test

* upd changes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGES.rst                  |   4 +
 doc/notebooks/Dask.ipynb     | 755 ++++++++----------------------
 xesmf/frontend.py            |  34 +-
 xesmf/tests/test_frontend.py |  40 +-
 4 files changed, 231 insertions(+), 602 deletions(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index 4e08402e..55ddf13d 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,6 +1,10 @@
 What's new
 ==========
 
+0.8.5 (unreleased)
+------------------
+* Reverted to the chunking behaviour of xESMF 0.7 for cases where the spatial dimensions are not chunked on the source data. (:pull:`348`) By `Pascal Bourgault `_.
+
 0.8.4 (2024-02-26)
 ------------------
 * Fix regression from :pull:`332` that made ``Regridder`` fail with rectilinear datasets and ``parallel=True``. (:issue:`343`, :pull:`344`).
diff --git a/doc/notebooks/Dask.ipynb b/doc/notebooks/Dask.ipynb
index 4b2cde3c..9cab732a 100644
--- a/doc/notebooks/Dask.ipynb
+++ b/doc/notebooks/Dask.ipynb
@@ -12,25 +12,19 @@
    "metadata": {},
    "source": [
     "If you are unfamiliar with Dask, read\n",
-    "[Parallel computing with Dask](http://xarray.pydata.org/en/stable/dask.html) in\n",
-    "Xarray documentation first. The current version only supports dask arrays on a\n",
-    "single machine. Support of [Dask.distributed](https://distributed.dask.org) is\n",
-    "in roadmap.\n",
+    "[Parallel computing with Dask](https://docs.xarray.dev/en/stable/user-guide/dask.html)\n",
+    "in Xarray documentation first.\n",
     "\n",
-    "xESMF's Dask support is mostly for\n",
-    "[lazy evaluation](https://en.wikipedia.org/wiki/Lazy_evaluation) and\n",
+    "Recall that the regridding process is divided into two steps: computing the\n",
+    "weights and applying the weights. Dask support is much more advanced for the\n",
+    "latter, and this is what the first part of this notebook is about.\n",
+    "\n",
+    "Dask allows [lazy evaluation](https://en.wikipedia.org/wiki/Lazy_evaluation) and\n",
     "[out-of-core computing](https://en.wikipedia.org/wiki/External_memory_algorithm),\n",
-    "to allow processing large volumes of data with limited memory. You might also\n",
-    "get moderate speed-up on a multi-core machine by\n",
-    "[choosing proper chunk sizes](http://xarray.pydata.org/en/stable/dask.html#chunking-and-performance),\n",
-    "but that generally won't help your entire pipeline too much, because the\n",
-    "read-regrid-write pipeline is severely I/O limited (see\n",
-    "[this issue](https://github.com/pangeo-data/pangeo/issues/334) for more\n",
-    "discussions). On a single machine, the disk bandwidth is typically limited to\n",
-    "~500 MB/s, and you cannot process data faster than such rate. If you need much\n",
-    "faster data processing rate, you should resort to parallel file systems on HPC\n",
-    "clusters or distributed storage on public cloud platforms.
Please refer to the\n", - "[Pangeo project](http://pangeo.io/) for more information.\n" + "to allow processing large volumes of data with limited memory. You may also get\n", + "a speed-up by parallelizing the process in some cases, but a general rule of\n", + "thumb is that if the data fits in memory, regridding will be faster without\n", + "dask.\n" ] }, { @@ -445,17 +439,17 @@ " title: 4x daily NMC reanalysis (1948)\n", " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n", " platform: Model\n", - " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly..." ], "text/plain": [ @@ -712,7 +706,7 @@ "" ], "text/plain": [ - "dask.array" + "dask.array" ] }, "execution_count": 4, @@ -774,15 +768,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2.04 s, sys: 43.8 ms, total: 2.08 s\n", - "Wall time: 2.22 s\n" + "CPU times: user 8.17 ms, sys: 2.32 ms, total: 10.5 ms\n", + "Wall time: 10.2 ms\n" ] }, { @@ -1158,16 +1152,16 @@ " * lat (lat) float64 16.0 17.0 18.0 19.0 20.0 ... 70.0 71.0 72.0 73.0 74.0\n", " * lon (lon) float64 200.0 201.5 203.0 204.5 ... 324.5 326.0 327.5 329.0\n", "Data variables:\n", - " air (time, lat, lon) float32 dask.array<chunksize=(500, 25, 53), meta=np.ndarray>\n", + " air (time, lat, lon) float32 dask.array<chunksize=(500, 59, 87), meta=np.ndarray>\n", "Attributes:\n", - " regrid_method: bilinear
  • " ], "text/plain": [ "\n", @@ -1306,12 +1292,12 @@ " * lat (lat) float64 16.0 17.0 18.0 19.0 20.0 ... 70.0 71.0 72.0 73.0 74.0\n", " * lon (lon) float64 200.0 201.5 203.0 204.5 ... 324.5 326.0 327.5 329.0\n", "Data variables:\n", - " air (time, lat, lon) float32 dask.array\n", + " air (time, lat, lon) float32 dask.array\n", "Attributes:\n", " regrid_method: bilinear" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -1324,7 +1310,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -1346,17 +1332,17 @@ " \n", " Bytes \n", " 57.18 MiB \n", - " 2.53 MiB \n", + " 9.79 MiB \n", " \n", " \n", " \n", " Shape \n", " (2920, 59, 87) \n", - " (500, 25, 53) \n", + " (500, 59, 87) \n", " \n", " \n", " Dask graph \n", - " 36 chunks in 8 graph layers \n", + " 6 chunks in 8 graph layers \n", " \n", " \n", " Data type \n", @@ -1370,8 +1356,6 @@ "\n", " \n", " \n", - " \n", - " \n", " \n", "\n", " \n", @@ -1397,7 +1381,6 @@ "\n", " \n", " \n", - " \n", " \n", "\n", " \n", @@ -1405,13 +1388,10 @@ "\n", " \n", " \n", - " \n", - " \n", " \n", "\n", " \n", " \n", - " \n", " \n", "\n", " \n", @@ -1427,29 +1407,29 @@ "" ], "text/plain": [ - "dask.array" + "dask.array" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "ds_out[\"air\"].data # chunks are preserved" + "ds_out[\"air\"].data" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1.62 s, sys: 386 ms, total: 2.01 s\n", - "Wall time: 908 ms\n" + "CPU times: user 756 ms, sys: 155 ms, total: 911 ms\n", + "Wall time: 600 ms\n" ] } ], @@ -1459,7 +1439,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1468,7 +1448,7 @@ "(numpy.ndarray, (2920, 59, 87))" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1481,20 +1461,28 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Spatial chunks\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Dask support also includes chunking over horizontal/core dimensions (`lat`,\n", - "`lon`, or `x`, `y`).\n" + "## Chunking behaviour\n", + "\n", + "xESMF will adjust its default behaviour according to the input data. On spatial\n", + "dimensions where the data has only one chunk, the output of a `Regridder` call\n", + "will also have only one chunk, no matter the new dimension size. This like the\n", + "previous example.\n", + "\n", + "However, if the input has more than one chunk along a spatial dimension, then\n", + "the regridder will try to preserve the chunk size. When upscaling data, this\n", + "means the number of chunks increases and with it the number of dask tasks added\n", + "to the graph. This can actually decrease performance if the graph becomes too\n", + "large, filled up with many small tasks.\n", + "\n", + "One can always override xESMF's default behaviour by passing `output_chunks` to\n", + "the `Regridder` call.\n", + "\n", + "In the example below, the input has three chunks along `lon`:\n" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": { "scrolled": true }, @@ -1502,393 +1490,7 @@ { "data": { "text/html": [ - "
-       "<xarray.Dataset>\n",
    -       "Dimensions:  (lat: 25, time: 2920, lon: 53)\n",
    -       "Coordinates:\n",
    -       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
    -       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
    -       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
    -       "Data variables:\n",
    -       "    air      (time, lat, lon) float32 dask.array<chunksize=(2920, 25, 25), meta=np.ndarray>\n",
    -       "Attributes:\n",
    -       "    Conventions:  COARDS\n",
    -       "    title:        4x daily NMC reanalysis (1948)\n",
    -       "    description:  Data is from NMC initialized reanalysis\\n(4x/day).  These a...\n",
    -       "    platform:     Model\n",
    -       "    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
    " ], "text/plain": [ "\n", @@ -3345,7 +2926,7 @@ " *empty*" ] }, - "execution_count": 16, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -3362,7 +2943,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -3371,7 +2952,7 @@ "Frozen({})" ] }, - "execution_count": 17, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -3392,7 +2973,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -3401,7 +2982,7 @@ "Frozen({'lat': (25, 25, 9), 'lon': (25, 25, 25, 12)})" ] }, - "execution_count": 18, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -3431,7 +3012,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/xesmf/frontend.py b/xesmf/frontend.py index dd5bc1ac..e685f502 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -502,11 +502,14 @@ def __call__(self, indata, keep_attrs=False, skipna=False, na_thres=1.0, output_ mask the output value. output_chunks: dict or tuple, optional - If indata is a dask_array_type, the desired chunks to have on the - output data along the spatial axes. Other non-spatial axes inherit - the same chunks as indata as those are not affected by the application - of the weights. Default behavior is to have the outdata chunks be like - the indata chunks. Chunks have to be specified for all spatial dimensions + The desired chunks to have on the output along the spatial axes, if indata is a dask array. + Other non-spatial axes inherit the same chunks as indata. + Default behavior depends on the chunking of indata. If it is not chunked along + the spatial dimension, the output will also not be chunked, + equivalent to passing ``output_chunks=(-1, -1)``. + If it is chunked, the output will preserve the chunk sizes, + equivalent to passing ``output_chunks=ìndata.chunks``. + Chunks have to be specified for all spatial dimensions of the output data otherwise regridding will fail. output_chunks can either be a tuple the same size as the spatial axes of outdata or it can be a dict with defined dims. If output_chunks is a dict, the @@ -598,9 +601,28 @@ def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunk weights = self.weights.data.reshape(self.shape_out + self.shape_in) if isinstance(indata, dask_array_type): # dask if output_chunks is None: + # Default : same chunk size as the input to preserve chunksize + # Unless the input is not chunked along the dimension (shape_in == in_chunk_size), in which case we do not chunk along the dimension + # This preserves the pre-0.8 behaviour. output_chunks = tuple( - [min(shp, inchnk) for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:])] + min(chnkin, shpout) if shpin != chnkin else shpout + for shpout, shpin, chnkin in zip( + self.shape_out, self.shape_in, indata.chunksize[-2:] + ) + ) + fac = np.prod( + [np.ceil(shp / chnk) for shp, chnk in zip(self.shape_out, output_chunks)] ) + if fac > 4: # Dask's built-in threshold is 10 + warnings.warn( + ( + f'Regridding is increasing the number of chunks by a factor of {fac}, ' + 'you might want to specify sizes in `output_chunks` in the regridder call. ' + f'Default behaviour is to preserve the chunk sizes from the input {indata.chunksize[-2:]}.' 
+                        ),
+                        da.core.PerformanceWarning,
+                        stacklevel=3,
+                    )
         if len(output_chunks) != len(self.shape_out):
             if len(output_chunks) == 1 and self.sequence_out:
                 output_chunks = (1, output_chunks[0])
diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py
index e4684c26..3c584df0 100644
--- a/xesmf/tests/test_frontend.py
+++ b/xesmf/tests/test_frontend.py
@@ -507,23 +507,46 @@ def test_regrid_dask(request, scheduler):
     scheduler = request.getfixturevalue(scheduler)
     regridder = xe.Regridder(ds_in, ds_out, 'conservative')
 
-    indata = ds_in_chunked['data4D'].data
-    # Use ridiculous small chunk size value to be sure it _isn't_ impacting computation.
-    with dask.config.set({'array.chunk-size': '1MiB'}):
-        outdata = regridder(indata)
+    indata = ds_in_chunked['data'].data
+    outdata = regridder(indata)
 
     assert dask.is_dask_collection(outdata)
 
     # lazy dask arrays have incorrect shape attribute due to last chunk
     assert outdata.shape == indata.shape[:-2] + horiz_shape_out
-    assert outdata.chunksize == indata.chunksize
 
-    # Check that the number of tasks hasn't exploded.
+    # Check that the number of tasks is as predicted:
+    # ds_in has 1 chunk, thus the output also has 1 chunk
+    # (the output is not chunked if the input isn't).
+    # Regridding adds 3 tasks, wrapping the weights adds 2.
+    n_task_out = len(outdata.__dask_graph__().keys())
     n_task_in = len(indata.__dask_graph__().keys())
+    assert n_task_out == n_task_in + 5
+
+    # Use very small chunks
+    indata_chunked = indata.rechunk((5, 6))  # Now has 9 chunks of (5, 6)
+    outdata = regridder(indata_chunked)
+    # This is the case where the input chunk sizes are preserved
+    assert outdata.chunksize == indata_chunked.chunksize
     n_task_out = len(outdata.__dask_graph__().keys())
-    assert (n_task_out / n_task_in) < 15
+    n_task_in = len(indata_chunked.__dask_graph__().keys())
+    # input has 9 chunks
+    # output has 16 chunks
+    # Regridding adds 2 * 9 * 16 + 16 + 64 (I'm not sure I fully understand how dask sums at the end)
+    # Wrapping the weights adds 9 * 16 + 1
+    assert n_task_out == n_task_in + 513
+
+    # Prescribe chunks
+    outdata = regridder(indata, output_chunks=(-1, 12))
+    n_task_out = len(outdata.__dask_graph__().keys())
+    n_task_in = len(indata.__dask_graph__().keys())
+    # input has 1 chunk
+    # output has 2 chunks
+    # Regridding adds 2 * 1 * 2 + 2
+    # Wrapping the weights adds 1 * 2 + 1
+    assert n_task_out == n_task_in + 9
 
-    outdata_ref = ds_out['data4D_ref'].values
+    outdata_ref = ds_out['data_ref'].values
     rel_err = (outdata.compute() - outdata_ref) / outdata_ref
     assert np.max(np.abs(rel_err)) < 0.05
 
@@ -562,7 +585,6 @@ def test_regrid_dataarray_dask(request, scheduler):
 
     assert dask.is_dask_collection(dr_out)
     assert dr_out.data.shape == dr_in.data.shape[:-2] + horiz_shape_out
-    assert dr_out.data.chunksize == dr_in.data.chunksize
 
     # data over broadcasting dimensions should agree
     assert_almost_equal(dr_in.values.mean(axis=(2, 3)), dr_out.values.mean(axis=(2, 3)), decimal=10)
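
The new default in `regrid_array` is compact, so here is a minimal standalone sketch of the same rule. The helper name `default_output_chunks` is hypothetical; only the tuple expression mirrors the patched code:

```python
def default_output_chunks(shape_in, shape_out, in_chunksize):
    """Preserve input chunk sizes, except along dimensions where the input
    has a single chunk (shape == chunk size), which stay unchunked."""
    return tuple(
        min(chnkin, shpout) if shpin != chnkin else shpout
        for shpout, shpin, chnkin in zip(shape_out, shape_in, in_chunksize)
    )


# Input is one (25, 53) chunk -> output is one (59, 87) chunk (pre-0.8 behaviour).
assert default_output_chunks((25, 53), (59, 87), (25, 53)) == (59, 87)
# Input chunked by 25 along lon -> output keeps 25-wide lon chunks.
assert default_output_chunks((25, 53), (59, 87), (25, 25)) == (59, 25)
```

With a 53-point `lon` axis chunked by 25, the 87-point output keeps 25-wide chunks, i.e. `(25, 25, 25, 12)`, consistent with the `Frozen(...)` output shown in the notebook.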
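For completeness, a usage sketch of the `output_chunks` argument described in the `__call__` docstring above, mirroring the notebook's air-temperature setup; the grid bounds and chunk values are illustrative assumptions:

```python
import numpy as np
import xarray as xr
import xesmf as xe

ds = xr.tutorial.open_dataset('air_temperature', chunks={'time': 500})
ds_out = xr.Dataset(
    {
        'lat': (['lat'], np.arange(16, 75, 1.0)),
        'lon': (['lon'], np.arange(200, 330, 1.5)),
    }
)
regridder = xe.Regridder(ds, ds_out, 'bilinear')

# Input is unchunked spatially: the default output has one spatial chunk.
out_default = regridder(ds)
# Tuple form: one entry per spatial axis of the output, -1 meaning one chunk.
out_tuple = regridder(ds, output_chunks=(-1, 12))
# Dict form: keyed by the output's spatial dimensions.
out_dict = regridder(ds, output_chunks={'lat': 30, 'lon': 30})
```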
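And a sketch of the new warning path, reusing `ds` and `regridder` from the previous sketch; the tiny chunk sizes are arbitrary, chosen only to push the output chunk count past the factor-of-4 threshold:

```python
import warnings

import dask.array as da

# Chunk the input very finely in space: the default preserves these chunk
# sizes, so the output has ceil(59/5) * ceil(87/6) = 180 spatial chunks and
# the check added above emits a PerformanceWarning.
indata = ds['air'].data.rechunk((500, 5, 6))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    outdata = regridder(indata)
assert any(issubclass(w.category, da.core.PerformanceWarning) for w in caught)

# Graph growth, counted the same way as in the test above.
print(len(outdata.__dask_graph__()) - len(indata.__dask_graph__()))
```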