diff --git a/ci/doc.yml b/ci/doc.yml index 414cd07..e4e46d6 100644 --- a/ci/doc.yml +++ b/ci/doc.yml @@ -15,6 +15,7 @@ dependencies: - furo>=2024.8.6 - myst-nb - xarray + - zarr - pip: # relative to this file. Needs to be editable to be accepted. - --editable .. diff --git a/cupy_xarray/__init__.py b/cupy_xarray/__init__.py index 5c3a06c..0bb96aa 100644 --- a/cupy_xarray/__init__.py +++ b/cupy_xarray/__init__.py @@ -1,4 +1,5 @@ from . import _version -from .accessors import CupyDataArrayAccessor, CupyDatasetAccessor # noqa +from .accessors import CupyDataArrayAccessor, CupyDatasetAccessor # noqa: F401 +from .kvikio import KvikioBackendEntrypoint # noqa: F401 __version__ = _version.get_versions()["version"] diff --git a/cupy_xarray/kvikio.py b/cupy_xarray/kvikio.py new file mode 100644 index 0000000..3004ef1 --- /dev/null +++ b/cupy_xarray/kvikio.py @@ -0,0 +1,230 @@ +""" +:doc:`kvikIO ` backend for xarray to read Zarr stores directly into CuPy +arrays in GPU memory. +""" + +import os +import warnings + +import cupy as cp +import numpy as np +from xarray import Variable +from xarray.backends import zarr as zarr_backend +from xarray.backends.common import _normalize_path # TODO: can this be public +from xarray.backends.store import StoreBackendEntrypoint +from xarray.backends.zarr import ZarrArrayWrapper, ZarrBackendEntrypoint, ZarrStore +from xarray.core import indexing +from xarray.core.utils import close_on_error # TODO: can this be public. + +try: + import kvikio.zarr + import zarr + + has_kvikio = True +except ImportError: + has_kvikio = False + + +# TODO: minimum kvikio version for supporting consolidated +# TODO: minimum xarray version for ZarrArrayWrapper._array 2023.10.0? + + +class DummyZarrArrayWrapper(ZarrArrayWrapper): + def __init__(self, array: np.ndarray): + assert isinstance(array, np.ndarray) + self._array = array + self.filters = None + self.dtype = array.dtype + self.shape = array.shape + + def __array__(self): + return self._array + + def get_array(self): + return self._array + + def __getitem__(self, key): + return self._array[key] + + +class CupyZarrArrayWrapper(ZarrArrayWrapper): + def __array__(self): + return self.get_array() + + +class EagerCupyZarrArrayWrapper(ZarrArrayWrapper): + """Used to wrap dimension coordinates.""" + + def __array__(self): + return self._array[:].get() + + def get_duck_array(self): + # total hack: make a numpy array look like a Zarr array + # this gets us through Xarray's backend layers + return DummyZarrArrayWrapper(self._array[:].get()) + + +class GDSZarrStore(ZarrStore): + @classmethod + def open_group( + cls, + store, + mode="r", + synchronizer=None, + group=None, + consolidated=False, + consolidate_on_close=False, + chunk_store=None, + storage_options=None, + append_dim=None, + write_region=None, + safe_chunks=True, + stacklevel=2, + ): + # zarr doesn't support pathlib.Path objects yet. zarr-python#601 + if isinstance(store, os.PathLike): + store = os.fspath(store) + + open_kwargs = { + "mode": mode, + "synchronizer": synchronizer, + "path": group, + ########## NEW STUFF + "meta_array": cp.empty(()), + } + open_kwargs["storage_options"] = storage_options + + if chunk_store: + open_kwargs["chunk_store"] = chunk_store + if consolidated is None: + consolidated = False + + store = kvikio.zarr.GDSStore(store) + + if consolidated is None: + try: + zarr_group = zarr.open_consolidated(store, **open_kwargs) + except KeyError: + warnings.warn( + "Failed to open Zarr store with consolidated metadata, " + "falling back to try reading non-consolidated metadata. " + "This is typically much slower for opening a dataset. " + "To silence this warning, consider:\n" + "1. Consolidating metadata in this existing store with " + "zarr.consolidate_metadata().\n" + "2. Explicitly setting consolidated=False, to avoid trying " + "to read consolidate metadata, or\n" + "3. Explicitly setting consolidated=True, to raise an " + "error in this case instead of falling back to try " + "reading non-consolidated metadata.", + RuntimeWarning, + stacklevel=stacklevel, + ) + zarr_group = zarr.open_group(store, **open_kwargs) + elif consolidated: + # TODO: an option to pass the metadata_key keyword + zarr_group = zarr.open_consolidated(store, **open_kwargs) + else: + zarr_group = zarr.open_group(store, **open_kwargs) + + return cls( + zarr_group, + mode, + consolidate_on_close, + append_dim, + write_region, + safe_chunks, + ) + + def open_store_variable(self, name, zarr_array): + try_nczarr = self._mode == "r" + dimensions, attributes = zarr_backend._get_zarr_dims_and_attrs( + zarr_array, zarr_backend.DIMENSION_KEY, try_nczarr + ) + + #### Changed from zarr array wrapper + # we want indexed dimensions to be loaded eagerly + # Right now we load in to device and then transfer to host + # But these should be small-ish arrays + # TODO: can we tell GDSStore to load as numpy array directly + # not cupy array? + array_wrapper = EagerCupyZarrArrayWrapper if name in dimensions else CupyZarrArrayWrapper + data = indexing.LazilyIndexedArray(array_wrapper(zarr_array)) + + attributes = dict(attributes) + encoding = { + "chunks": zarr_array.chunks, + "preferred_chunks": dict(zip(dimensions, zarr_array.chunks)), + "compressor": zarr_array.compressor, + "filters": zarr_array.filters, + } + # _FillValue needs to be in attributes, not encoding, so it will get + # picked up by decode_cf + if zarr_array.fill_value is not None: + attributes["_FillValue"] = zarr_array.fill_value + + return Variable(dimensions, data, attributes, encoding) + + +class KvikioBackendEntrypoint(ZarrBackendEntrypoint): + """ + Xarray backend to read Zarr stores using 'kvikio' engine. + + For more information about the underlying library, visit + :doc:`kvikIO's Zarr page`. + """ + + available = has_kvikio + description = "Open zarr files (.zarr) using Kvikio" + url = "https://docs.rapids.ai/api/kvikio/stable/api/#zarr" + + # disabled by default + # We need to provide this because of the subclassing from + # ZarrBackendEntrypoint + def guess_can_open(self, filename_or_obj): + return False + + def open_dataset( + self, + filename_or_obj, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + synchronizer=None, + consolidated=None, + chunk_store=None, + storage_options=None, + stacklevel=3, + ): + filename_or_obj = _normalize_path(filename_or_obj) + store = GDSZarrStore.open_group( + filename_or_obj, + group=group, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=False, + chunk_store=chunk_store, + storage_options=storage_options, + stacklevel=stacklevel + 1, + ) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds diff --git a/cupy_xarray/tests/test_kvikio.py b/cupy_xarray/tests/test_kvikio.py new file mode 100644 index 0000000..ba64fbb --- /dev/null +++ b/cupy_xarray/tests/test_kvikio.py @@ -0,0 +1,54 @@ +import cupy as cp +import numpy as np +import pytest +import xarray as xr +from xarray.core.indexing import ExplicitlyIndexedNDArrayMixin + +kvikio = pytest.importorskip("kvikio") +zarr = pytest.importorskip("zarr") + +import kvikio.zarr # noqa +import xarray.core.indexing # noqa + + +@pytest.fixture +def store(tmp_path): + ds = xr.Dataset( + { + "a": ("x", np.arange(10), {"foo": "bar"}), + "scalar": np.array(1), + }, + coords={"x": ("x", np.arange(-5, 5))}, + ) + + for var in ds.variables: + ds[var].encoding["compressor"] = None + + store_path = tmp_path / "kvikio.zarr" + ds.to_zarr(store_path, consolidated=True) + return store_path + + +def test_entrypoint(): + assert "kvikio" in xr.backends.list_engines() + + +@pytest.mark.parametrize("consolidated", [True, False]) +def test_lazy_load(consolidated, store): + with xr.open_dataset(store, engine="kvikio", consolidated=consolidated) as ds: + for _, da in ds.data_vars.items(): + assert isinstance(da.variable._data, ExplicitlyIndexedNDArrayMixin) + + +@pytest.mark.parametrize("indexer", [slice(None), slice(2, 4), 2, [2, 3, 5]]) +def test_lazy_indexing(indexer, store): + with xr.open_dataset(store, engine="kvikio") as ds: + ds = ds.isel(x=indexer) + for _, da in ds.data_vars.items(): + assert isinstance(da.variable._data, ExplicitlyIndexedNDArrayMixin) + + loaded = ds.compute() + for _, da in loaded.data_vars.items(): + if da.ndim == 0: + continue + assert isinstance(da.data, cp.ndarray) diff --git a/docs/api.rst b/docs/api.rst index 70d22b0..17bdb12 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -51,3 +51,16 @@ Methods Dataset.cupy.as_cupy Dataset.cupy.as_numpy + + +KvikIO engine +------------- + +.. currentmodule:: cupy_xarray + +.. automodule:: cupy_xarray.kvikio + +.. autosummary:: + :toctree: generated/ + + KvikioBackendEntrypoint diff --git a/docs/conf.py b/docs/conf.py index 2dffa80..ebba5ed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,6 +57,7 @@ "python": ("https://docs.python.org/3/", None), "dask": ("https://docs.dask.org/en/latest", None), "cupy": ("https://docs.cupy.dev/en/latest", None), + "kvikio": ("https://docs.rapids.ai/api/kvikio/stable", None), "xarray": ("http://docs.xarray.dev/en/latest/", None), } diff --git a/docs/index.md b/docs/index.md index 56f1995..b649c66 100644 --- a/docs/index.md +++ b/docs/index.md @@ -59,6 +59,7 @@ Large parts of this documentations comes from [SciPy 2023 Xarray on GPUs tutoria source/high-level-api source/apply-ufunc source/real-example-1 + source/kvikio **Tutorials & Presentations**: diff --git a/docs/source/kvikio.ipynb b/docs/source/kvikio.ipynb new file mode 100644 index 0000000..6117eb2 --- /dev/null +++ b/docs/source/kvikio.ipynb @@ -0,0 +1,5032 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5920bb97-1d76-4363-9aee-d1c5cd395409", + "metadata": {}, + "source": [ + "# Kvikio demo\n", + "\n", + "Requires\n", + "- [ ] https://github.com/pydata/xarray/pull/8100\n", + "- [ ] Some updates to `dask.array.core.getter`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c9ee3a73-6f7b-4875-b5a6-2e6d48fade44", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exception reporting mode: Minimal\n", + "kvikio : 23.2.0\n", + "xarray : 2022.6.1.dev458+g83c2919b2\n", + "numpy_groupies: 0.9.22+2.gd148074\n", + "json : 2.0.9\n", + "numpy : 1.24.4\n", + "flox : 0.7.3.dev12+g796dcd2\n", + "zarr : 2.16.1\n", + "dask : 2023.8.1\n", + "cupy_xarray : 0.1.1+21.gd2da1e4.dirty\n", + "sys : 3.9.17 | packaged by conda-forge | (main, Aug 10 2023, 07:02:31) \n", + "[GCC 12.3.0]\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%xmode minimal\n", + "\n", + "# These imports are currently unnecessary. I import them to show versions\n", + "# cupy_xarray registers the kvikio entrypoint on install.\n", + "# import cupy as cp\n", + "# import cudf\n", + "import cupy_xarray # registers cupy accessor\n", + "import kvikio.zarr\n", + "\n", + "import flox\n", + "import numpy_groupies\n", + "import numpy as np\n", + "import xarray as xr\n", + "import zarr\n", + "\n", + "import dask\n", + "\n", + "dask.config.set(scheduler=\"sync\")\n", + "\n", + "store = \"./air-temperature.zarr\"\n", + "\n", + "%watermark -iv" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "83b1b514-eeb8-4a81-a3e8-3a7dc82ffce4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'kvikio': \n", + " Open zarr files (.zarr) using Kvikio\n", + " Learn more at https://docs.rapids.ai/api/kvikio/nightly/api.html#zarr,\n", + " 'store': \n", + " Open AbstractDataStore instances in Xarray\n", + " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html,\n", + " 'zarr': \n", + " Open zarr files (.zarr) using zarr in Xarray\n", + " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xr.backends.list_engines()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "81b2e5cb-4b2d-4a31-b7a0-961aadbc321d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> \u001b[0;32m/glade/u/home/dcherian/python/xarray/xarray/core/indexing.py\u001b[0m(485)\u001b[0;36m__array__\u001b[0;34m()\u001b[0m\n", + "\u001b[0;32m 484 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m--> 485 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_duck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m 486 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> c\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> \u001b[0;32m/glade/u/home/dcherian/python/xarray/xarray/core/indexing.py\u001b[0m(485)\u001b[0;36m__array__\u001b[0;34m()\u001b[0m\n", + "\u001b[0;32m 484 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m--> 485 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_duck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\u001b[0;32m 486 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0m\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ipdb> c\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(, scale_factor=0.01, add_offset=None, dtype=), dtype=dtype('float32')), key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None))))))\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset>\n",
+       "Dimensions:  (time: 2920, lat: 25, lon: 53)\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Data variables:\n",
+       "    air      (time, lat, lon) float32 ...\n",
+       "    scalar   float64 ...\n",
+       "Attributes:\n",
+       "    Conventions:  COARDS\n",
+       "    description:  Data is from NMC initialized reanalysis\\n(4x/day).  These a...\n",
+       "    platform:     Model\n",
+       "    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...\n",
+       "    title:        4x daily NMC reanalysis (1948)
" + ], + "text/plain": [ + "\n", + "Dimensions: (time: 2920, lat: 25, lon: 53)\n", + "Coordinates:\n", + " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", + "Data variables:\n", + " air (time, lat, lon) float32 ...\n", + " scalar float64 ...\n", + "Attributes:\n", + " Conventions: COARDS\n", + " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n", + " platform: Model\n", + " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...\n", + " title: 4x daily NMC reanalysis (1948)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%autoreload\n", + "\n", + "# Consolidated must be False\n", + "ds = xr.open_dataset(store, engine=\"kvikio\", consolidated=False)\n", + "print(ds.air._variable._data)\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "id": "6d301bec-e64b-4a8f-9c20-5dab56721561", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Create example dataset\n", + "\n", + "- cannot be compressed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d481cc3b-420e-4b7c-8c5e-77d874128b12", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "airt = xr.tutorial.open_dataset(\"air_temperature\", engine=\"netcdf4\")\n", + "for var in airt.variables:\n", + " airt[var].encoding[\"compressor\"] = None\n", + "airt[\"scalar\"] = 12.0\n", + "airt.to_zarr(store, mode=\"w\", consolidated=True)" + ] + }, + { + "cell_type": "markdown", + "id": "883d5507-988f-453a-b576-87bb563b540f", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Test opening\n", + "\n", + "### Standard usage" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "58063142-b69b-46a5-9e4d-a83944e57857", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>\n",
+       "[3869000 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Attributes:\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    parent_stat:   Other\n",
+       "    precision:     2\n",
+       "    statistic:     Individual Obs\n",
+       "    units:         degK\n",
+       "    var_desc:      Air temperature
" + ], + "text/plain": [ + "\n", + "[3869000 values with dtype=float32]\n", + "Coordinates:\n", + " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", + "Attributes:\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " actual_range: [185.16000366210938, 322.1000061035156]\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " parent_stat: Other\n", + " precision: 2\n", + " statistic: Individual Obs\n", + " units: degK\n", + " var_desc: Air temperature" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xr.open_dataset(store, engine=\"zarr\").air" + ] + }, + { + "cell_type": "markdown", + "id": "95161182-6b58-4dbd-9752-9961c251be1a", + "metadata": {}, + "source": [ + "### Now with kvikio!\n", + "\n", + " - must read with `consolidated=False` (https://github.com/rapidsai/kvikio/issues/119)\n", + " - dask.from_zarr to GDSStore / open_mfdataset" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8fd27bdf-e317-4de3-891e-41d38d06dcaf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(, scale_factor=0.01, add_offset=None, dtype=), dtype=dtype('float32')), key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None))))))\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset>\n",
+       "Dimensions:  (time: 2920, lat: 25, lon: 53)\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Data variables:\n",
+       "    air      (time, lat, lon) float32 ...\n",
+       "    scalar   float64 ...\n",
+       "Attributes:\n",
+       "    Conventions:  COARDS\n",
+       "    description:  Data is from NMC initialized reanalysis\\n(4x/day).  These a...\n",
+       "    platform:     Model\n",
+       "    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...\n",
+       "    title:        4x daily NMC reanalysis (1948)
" + ], + "text/plain": [ + "\n", + "Dimensions: (time: 2920, lat: 25, lon: 53)\n", + "Coordinates:\n", + " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", + "Data variables:\n", + " air (time, lat, lon) float32 ...\n", + " scalar float64 ...\n", + "Attributes:\n", + " Conventions: COARDS\n", + " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n", + " platform: Model\n", + " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...\n", + " title: 4x daily NMC reanalysis (1948)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Consolidated must be False\n", + "ds = xr.open_dataset(store, engine=\"kvikio\", consolidated=False)\n", + "print(ds.air._variable._data)\n", + "ds" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6c939a04-1588-4693-9483-c6ad7152951a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'scalar' ()>\n",
+       "[1 values with dtype=float64]
" + ], + "text/plain": [ + "\n", + "[1 values with dtype=float64]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.scalar" + ] + }, + { + "cell_type": "markdown", + "id": "bb84a7ad-84dc-4bb3-8636-3f9416953089", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Lazy reading" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1ecc39b1-b788-4831-9160-5b35afb83598", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>\n",
+       "[3869000 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Attributes:\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    parent_stat:   Other\n",
+       "    precision:     2\n",
+       "    statistic:     Individual Obs\n",
+       "    units:         degK\n",
+       "    var_desc:      Air temperature
" + ], + "text/plain": [ + "\n", + "[3869000 values with dtype=float32]\n", + "Coordinates:\n", + " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", + "Attributes:\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " actual_range: [185.16000366210938, 322.1000061035156]\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " parent_stat: Other\n", + " precision: 2\n", + " statistic: Individual Obs\n", + " units: degK\n", + " var_desc: Air temperature" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.air" + ] + }, + { + "cell_type": "markdown", + "id": "7d366864-a2b3-4573-9bf7-41d1f6ee457c", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Data load for repr" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "00205e73-9b43-4254-9cba-f75435251391", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (lon: 53)>\n",
+       "array([277.29   , 277.4    , 277.79   , 278.6    , 279.5    , 280.1    ,\n",
+       "       280.6    , 280.9    , 280.79   , 280.69998, 280.79   , 281.     ,\n",
+       "       280.29   , 277.69998, 273.5    , 269.     , 265.5    , 264.     ,\n",
+       "       265.19998, 268.1    , 269.79   , 267.9    , 263.     , 258.1    ,\n",
+       "       254.59999, 251.79999, 249.59999, 249.89   , 252.29999, 254.     ,\n",
+       "       254.29999, 255.89   , 260.     , 263.     , 261.5    , 257.29   ,\n",
+       "       255.5    , 258.29   , 264.     , 268.69998, 270.5    , 270.6    ,\n",
+       "       271.19998, 272.9    , 274.79   , 276.4    , 278.19998, 280.5    ,\n",
+       "       282.9    , 284.69998, 286.1    , 286.9    , 286.6    ],\n",
+       "      dtype=float32)\n",
+       "Coordinates:\n",
+       "    lat      float32 50.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "    time     datetime64[ns] 2013-01-01\n",
+       "Attributes:\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    parent_stat:   Other\n",
+       "    precision:     2\n",
+       "    statistic:     Individual Obs\n",
+       "    units:         degK\n",
+       "    var_desc:      Air temperature
" + ], + "text/plain": [ + "\n", + "array([277.29 , 277.4 , 277.79 , 278.6 , 279.5 , 280.1 ,\n", + " 280.6 , 280.9 , 280.79 , 280.69998, 280.79 , 281. ,\n", + " 280.29 , 277.69998, 273.5 , 269. , 265.5 , 264. ,\n", + " 265.19998, 268.1 , 269.79 , 267.9 , 263. , 258.1 ,\n", + " 254.59999, 251.79999, 249.59999, 249.89 , 252.29999, 254. ,\n", + " 254.29999, 255.89 , 260. , 263. , 261.5 , 257.29 ,\n", + " 255.5 , 258.29 , 264. , 268.69998, 270.5 , 270.6 ,\n", + " 271.19998, 272.9 , 274.79 , 276.4 , 278.19998, 280.5 ,\n", + " 282.9 , 284.69998, 286.1 , 286.9 , 286.6 ],\n", + " dtype=float32)\n", + "Coordinates:\n", + " lat float32 50.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " time datetime64[ns] 2013-01-01\n", + "Attributes:\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " actual_range: [185.16000366210938, 322.1000061035156]\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " parent_stat: Other\n", + " precision: 2\n", + " statistic: Individual Obs\n", + " units: degK\n", + " var_desc: Air temperature" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[\"air\"].isel(time=0, lat=10).load()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "80aa6892-8c7f-44b3-bd52-9795ec4ea6f3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'scalar' ()>\n",
+       "[1 values with dtype=float64]
" + ], + "text/plain": [ + "\n", + "[1 values with dtype=float64]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.scalar" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ba48a2c0-96e0-41d7-9e07-381e05e8dc33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>\n",
+       "[3869000 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Attributes:\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    parent_stat:   Other\n",
+       "    precision:     2\n",
+       "    statistic:     Individual Obs\n",
+       "    units:         degK\n",
+       "    var_desc:      Air temperature
" + ], + "text/plain": [ + "\n", + "[3869000 values with dtype=float32]\n", + "Coordinates:\n", + " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", + "Attributes:\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " actual_range: [185.16000366210938, 322.1000061035156]\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " parent_stat: Other\n", + " precision: 2\n", + " statistic: Individual Obs\n", + " units: degK\n", + " var_desc: Air temperature" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.air" + ] + }, + { + "cell_type": "markdown", + "id": "d0ea31d2-6c52-4346-b489-fc1e43200213", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## CuPy array on load" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1b34a68a-a6b3-4273-bf7c-28814ebfce11", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(, scale_factor=0.01, add_offset=None, dtype=), dtype=dtype('float32')), key=BasicIndexer((0, 10, slice(None, None, None))))))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[\"air\"].isel(time=0, lat=10).variable._data" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "db69559c-1fde-4b3b-914d-87d8437ec256", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cupy.ndarray" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(ds[\"air\"].isel(time=0, lat=10).load().data)" + ] + }, + { + "cell_type": "markdown", + "id": "d34a5cce-7bbc-408f-b643-05da1e121c78", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Load to host" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "09b40d7d-ed38-4a50-af11-c2e5f0242a97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>\n",
+       "[3869000 values with dtype=float32]\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Attributes:\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    parent_stat:   Other\n",
+       "    precision:     2\n",
+       "    statistic:     Individual Obs\n",
+       "    units:         degK\n",
+       "    var_desc:      Air temperature
" + ], + "text/plain": [ + "\n", + "[3869000 values with dtype=float32]\n", + "Coordinates:\n", + " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", + "Attributes:\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " actual_range: [185.16000366210938, 322.1000061035156]\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " parent_stat: Other\n", + " precision: 2\n", + " statistic: Individual Obs\n", + " units: degK\n", + " var_desc: Air temperature" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.air" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "eeb9ad78-1353-464f-8419-4c44ea499f17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "numpy.ndarray" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(ds.air.as_numpy().data)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "615efd76-2194-4604-9ab8-61499e7d725d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cupy.ndarray" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(ds.air.data)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "140fe3e2-ea9b-445d-8401-5c624384c182", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cupy.ndarray" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(ds.air.mean(\"time\").load().data)" + ] + }, + { + "cell_type": "markdown", + "id": "cab539a7-d952-4b38-b515-712c52c62501", + "metadata": { + "tags": [] + }, + "source": [ + "## Doesn't work: Chunk with dask" + ] + }, + { + "cell_type": "markdown", + "id": "62c084eb-8df4-4b7f-a187-a736d68d430d", + "metadata": {}, + "source": [ + "`meta` is wrong" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "68f93bfe-fe56-488a-a10b-dc4f48029367", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>\n",
+       "dask.array<xarray-air, shape=(2920, 25, 53), dtype=float32, chunksize=(10, 25, 53), chunktype=numpy.ndarray>\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Attributes:\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    parent_stat:   Other\n",
+       "    precision:     2\n",
+       "    statistic:     Individual Obs\n",
+       "    units:         degK\n",
+       "    var_desc:      Air temperature
" + ], + "text/plain": [ + "\n", + "dask.array\n", + "Coordinates:\n", + " * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00\n", + "Attributes:\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " actual_range: [185.16000366210938, 322.1000061035156]\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " parent_stat: Other\n", + " precision: 2\n", + " statistic: Individual Obs\n", + " units: degK\n", + " var_desc: Air temperature" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.chunk(time=10).air" + ] + }, + { + "cell_type": "markdown", + "id": "3f4c72f6-22e7-4e99-9f4e-2524d6ab4226", + "metadata": {}, + "source": [ + "`dask.array.core.getter` calls `np.asarray` on each chunk.\n", + "\n", + "This calls `ImplicitToExplicitIndexingAdapter.__array__` which calls `np.asarray(cupy.array)` which raises.\n", + "\n", + "Xarray uses `.get_duck_array` internally to remove these adapters. We might need to add\n", + "```python\n", + "# handle xarray internal classes that might wrap cupy\n", + "if hasattr(c, \"get_duck_array\"):\n", + " c = c.get_duck_array()\n", + "else:\n", + " c = np.asarray(c)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e1256d03-9701-433a-8291-80dc8dccffce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dask.utils import is_arraylike\n", + "\n", + "data = ds.air.variable._data\n", + "is_arraylike(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "308affa5-9fb9-4638-989b-97aac2604c16", + "metadata": {}, + "outputs": [], + "source": [ + "from xarray.core.indexing import ImplicitToExplicitIndexingAdapter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "985cd2f8-406e-4e9e-8017-42efb16aa40e", + "metadata": {}, + "outputs": [], + "source": [ + "ImplicitToExplicitIndexingAdapter(data).get_duck_array()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa8ef4f7-5014-476f-b4c0-ec2f9abdb6e2", + "metadata": {}, + "outputs": [], + "source": [ + "ds.chunk(time=10).air.compute()" + ] + }, + { + "cell_type": "markdown", + "id": "17dc1bf6-7548-4eee-a5f3-ebcc20d41567", + "metadata": {}, + "source": [ + "### explicit meta" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdd4b4e6-d69a-4898-964a-0e6096ca1942", + "metadata": {}, + "outputs": [], + "source": [ + "import cupy as cp\n", + "\n", + "chunked = ds.chunk(time=10, from_array_kwargs={\"meta\": cp.array([])})\n", + "chunked.air" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74f80d94-ebb6-43c3-9411-79e0442d894e", + "metadata": {}, + "outputs": [], + "source": [ + "%autoreload\n", + "\n", + "chunked.compute()" + ] + }, + { + "cell_type": "markdown", + "id": "1c07c449-bc43-490a-ac38-11e93200133d", + "metadata": {}, + "source": [ + "## GroupBy with flox\n", + "\n", + "Requires\n", + "\n", + "1. flox main branch?\n", + "2. https://github.com/ml31415/numpy-groupies/pull/63" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c292cf77-c99e-40fa-8cad-d8914c346b29", + "metadata": {}, + "outputs": [], + "source": [ + "ds.air.groupby(\"time.month\").mean()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "miniconda3-kvikio_env", + "language": "python", + "name": "conda-env-miniconda3-kvikio_env-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.17" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index d98b3fe..2d5094e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ test = [ "pytest", ] +[project.entry-points."xarray.backends"] +kvikio = "cupy_xarray.kvikio:KvikioBackendEntrypoint" + [tool.ruff] line-length = 100 # E501 (line-too-long) exclude = [