diff --git a/.github/release.yml b/.github/release.yml
new file mode 100644
index 00000000000..9d1e0987bfc
--- /dev/null
+++ b/.github/release.yml
@@ -0,0 +1,5 @@
+changelog:
+  exclude:
+    authors:
+      - dependabot
+      - pre-commit-ci
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 01a2b4afc55..c0a3e6bbf4e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,7 +13,7 @@ repos:
       - id: mixed-line-ending
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: 'v0.4.7'
+    rev: 'v0.5.0'
    hooks:
      - id: ruff
        args: ["--fix", "--show-fixes"]
@@ -30,7 +30,7 @@ repos:
        additional_dependencies: ["black==24.4.2"]
      - id: blackdoc-autoupdate-black
  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.10.0
+    rev: v1.10.1
    hooks:
      - id: mypy
        # Copied from setup.cfg
diff --git a/HOW_TO_RELEASE.md b/HOW_TO_RELEASE.md
index 9d1164547b9..28f7b8e8775 100644
--- a/HOW_TO_RELEASE.md
+++ b/HOW_TO_RELEASE.md
@@ -23,14 +23,13 @@ upstream        https://github.com/pydata/xarray (push)
    ```sh
    git fetch upstream --tags
    ```
-   This will return a list of all the contributors since the last release:
+   Then run
    ```sh
-   git log "$(git tag --sort=v:refname | tail -1).." --format=%aN | sort -u | perl -pe 's/\n/$1, /'
-   ```
-   This will return the total number of contributors:
-   ```sh
-   git log "$(git tag --sort=v:refname | tail -1).." --format=%aN | sort -u | wc -l
+   python ci/release_contributors.py
    ```
+   (needs `gitpython` and `toolz` / `cytoolz`)
+
+   and copy the output.
 3. Write a release summary: ~50 words describing the high level features. This
    will be used in the release emails, tweets, GitHub release notes, etc.
 4. Look over whats-new.rst and the docs. Make sure "What's New" is complete
diff --git a/asv_bench/benchmarks/coding.py b/asv_bench/benchmarks/coding.py
new file mode 100644
index 00000000000..c39555243c0
--- /dev/null
+++ b/asv_bench/benchmarks/coding.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+import xarray as xr
+
+from . import parameterized
+
+
+@parameterized(["calendar"], [("standard", "noleap")])
+class EncodeCFDatetime:
+    def setup(self, calendar):
+        self.units = "days since 2000-01-01"
+        self.dtype = np.dtype("int64")
+        self.times = xr.date_range(
+            "2000", freq="D", periods=10000, calendar=calendar
+        ).values
+
+    def time_encode_cf_datetime(self, calendar):
+        xr.coding.times.encode_cf_datetime(self.times, self.units, calendar, self.dtype)
diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh
index d728768439a..855336ad8bd 100755
--- a/ci/install-upstream-wheels.sh
+++ b/ci/install-upstream-wheels.sh
@@ -13,7 +13,7 @@ $conda remove -y numba numbagg sparse
 # temporarily remove numexpr
 $conda remove -y numexpr
 # temporarily remove backends
-$conda remove -y cf_units hdf5 h5py netcdf4 pydap
+$conda remove -y pydap
 # forcibly remove packages to avoid artifacts
 $conda remove -y --force \
     numpy \
@@ -37,8 +37,7 @@ python -m pip install \
     numpy \
     scipy \
     matplotlib \
-    pandas \
-    h5py
+    pandas
 # for some reason pandas depends on pyarrow already.
 # Remove once a `pyarrow` version compiled with `numpy>=2.0` is on `conda-forge`
 python -m pip install \
diff --git a/ci/release_contributors.py b/ci/release_contributors.py
new file mode 100644
index 00000000000..dab95c651c5
--- /dev/null
+++ b/ci/release_contributors.py
@@ -0,0 +1,49 @@
+import re
+import textwrap
+
+import git
+from tlz.itertoolz import last, unique
+
+co_author_re = re.compile(r"Co-authored-by: (?P<name>[^<]+?) <(?P<email>.+)>")
<(?P.+)>") + + +def main(): + repo = git.Repo(".") + + most_recent_release = last(repo.tags) + + # extract information from commits + contributors = {} + for commit in repo.iter_commits(f"{most_recent_release.name}.."): + matches = co_author_re.findall(commit.message) + if matches: + contributors.update({email: name for name, email in matches}) + contributors[commit.author.email] = commit.author.name + + # deduplicate and ignore + # TODO: extract ignores from .github/release.yml + ignored = ["dependabot", "pre-commit-ci"] + unique_contributors = unique( + contributor + for contributor in contributors.values() + if contributor.removesuffix("[bot]") not in ignored + ) + + sorted_ = sorted(unique_contributors) + if len(sorted_) > 1: + names = f"{', '.join(sorted_[:-1])} and {sorted_[-1]}" + else: + names = "".join(sorted_) + + statement = textwrap.dedent( + f"""\ + Thanks to the {len(sorted_)} contributors to this release: + {names} + """.rstrip() + ) + + print(statement) + + +if __name__ == "__main__": + main() diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index abf6a88690a..119db282ad9 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -22,7 +22,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy<2 + - numpy - packaging - pandas - pint>=0.22 diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index 116eee7f702..d0484c5e999 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -8,6 +8,7 @@ dependencies: - bottleneck - cartopy - cfgrib + - kerchunk - dask-core>=2022.1 - dask-expr - hypothesis>=6.75.8 @@ -21,7 +22,7 @@ dependencies: - nbsphinx - netcdf4>=1.5 - numba - - numpy>=1.21,<2 + - numpy>=2 - packaging>=21.3 - pandas>=1.4,!=2.1.0 - pooch diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 2eedc9b0621..896e390ea3e 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -23,7 +23,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy<2 + - numpy - packaging - pandas # - pint>=0.22 diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 317e1fe5f41..ef02a3e7f23 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -26,7 +26,7 @@ dependencies: - numba - numbagg - numexpr - - numpy<2 + - numpy - opt_einsum - packaging - pandas diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index d9c89649358..c38b4f2d11c 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -693,3 +693,7 @@ coding.times.CFTimedeltaCoder coding.times.CFDatetimeCoder + + groupers.Grouper + groupers.Resampler + groupers.EncodedGroups diff --git a/doc/api.rst b/doc/api.rst index a8f8ea7dd1c..6ed8d513934 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -111,6 +111,7 @@ Dataset contents Dataset.drop_duplicates Dataset.drop_dims Dataset.drop_encoding + Dataset.drop_attrs Dataset.set_coords Dataset.reset_coords Dataset.convert_calendar @@ -306,6 +307,7 @@ DataArray contents DataArray.drop_indexes DataArray.drop_duplicates DataArray.drop_encoding + DataArray.drop_attrs DataArray.reset_coords DataArray.copy DataArray.convert_calendar @@ -801,6 +803,18 @@ DataArray DataArrayGroupBy.dims DataArrayGroupBy.groups +Grouper Objects +--------------- + +.. currentmodule:: xarray + +.. 
+.. autosummary::
+   :toctree: generated/
+
+   groupers.BinGrouper
+   groupers.UniqueGrouper
+   groupers.TimeResampler
+
 Rolling objects
 ===============
@@ -1026,17 +1040,20 @@ DataArray
 Accessors
 =========
-.. currentmodule:: xarray
+.. currentmodule:: xarray.core
 .. autosummary::
    :toctree: generated/
-   core.accessor_dt.DatetimeAccessor
-   core.accessor_dt.TimedeltaAccessor
-   core.accessor_str.StringAccessor
+   accessor_dt.DatetimeAccessor
+   accessor_dt.TimedeltaAccessor
+   accessor_str.StringAccessor
+
 Custom Indexes
 ==============
+.. currentmodule:: xarray
+
 .. autosummary::
    :toctree: generated/
diff --git a/doc/combined.json b/doc/combined.json
new file mode 100644
index 00000000000..345462e055f
--- /dev/null
+++ b/doc/combined.json
@@ -0,0 +1,30 @@
+{
+  "version": 1,
+  "refs": {
+    ".zgroup": "{\"zarr_format\":2}",
+    "foo/.zarray": "{\"chunks\":[4,5],\"compressor\":null,\"dtype\":\"
diff --git a/doc/conf.py b/doc/conf.py
--- a/doc/conf.py
+++ b/doc/conf.py
-copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
+copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.{3,}: | {5,8}: "
 copybutton_prompt_is_regexp = True
 # nbsphinx configurations
@@ -158,6 +159,8 @@
     "Variable": "~xarray.Variable",
     "DatasetGroupBy": "~xarray.core.groupby.DatasetGroupBy",
     "DataArrayGroupBy": "~xarray.core.groupby.DataArrayGroupBy",
+    "Grouper": "~xarray.groupers.Grouper",
+    "Resampler": "~xarray.groupers.Resampler",
     # objects without namespace: numpy
     "ndarray": "~numpy.ndarray",
     "MaskedArray": "~numpy.ma.MaskedArray",
@@ -169,6 +172,7 @@
     "CategoricalIndex": "~pandas.CategoricalIndex",
     "TimedeltaIndex": "~pandas.TimedeltaIndex",
     "DatetimeIndex": "~pandas.DatetimeIndex",
+    "IntervalIndex": "~pandas.IntervalIndex",
     "Series": "~pandas.Series",
     "DataFrame": "~pandas.DataFrame",
     "Categorical": "~pandas.Categorical",
diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst
index 63f60cd0090..62927103c7e 100644
--- a/doc/ecosystem.rst
+++ b/doc/ecosystem.rst
@@ -17,6 +17,7 @@ Geosciences
 - `climpred `_: Analysis of ensemble forecast models for climate prediction.
 - `geocube `_: Tool to convert geopandas vector data into rasterized xarray data.
 - `GeoWombat `_: Utilities for analysis of remotely sensed and gridded raster data at scale (easily tame Landsat, Sentinel, Quickbird, and PlanetScope).
+- `grib2io `_: Utility to work with GRIB2 files including an xarray backend, Dask support for parallel reading in open_mfdataset, lazy loading of data, editing of GRIB2 attributes and GRIB2IO DataArray attrs, and spatial interpolation and reprojection of GRIB2 messages and GRIB2IO Datasets/DataArrays for both grid to grid and grid to stations.
 - `gsw-xarray `_: a wrapper around `gsw `_ that adds CF compliant attributes when possible, units, name.
 - `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meteorology data
 - `marc_analysis `_: Analysis package for CESM/MARC experiments and output.
diff --git a/doc/getting-started-guide/faq.rst b/doc/getting-started-guide/faq.rst
index 7f99fa77e3a..58d5448cdf5 100644
--- a/doc/getting-started-guide/faq.rst
+++ b/doc/getting-started-guide/faq.rst
@@ -352,9 +352,9 @@ Some packages may have additional functionality beyond what is shown here. You c
 How does xarray handle missing values?
 --------------------------------------
-**xarray can handle missing values using ``np.NaN``**
+**xarray can handle missing values using ``np.nan``**
-- ``np.NaN`` is used to represent missing values in labeled arrays and datasets. It is a commonly used standard for representing missing or undefined numerical data in scientific computing. ``np.NaN`` is a constant value in NumPy that represents "Not a Number" or missing values.
+- ``np.nan`` is used to represent missing values in labeled arrays and datasets. It is a commonly used standard for representing missing or undefined numerical data in scientific computing. ``np.nan`` is a constant value in NumPy that represents "Not a Number" or missing values.
 - Most of xarray's computation methods are designed to automatically handle missing values appropriately.
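A minimal sketch of that behaviour (the array here is illustrative):

.. code-block:: python

    import numpy as np
    import xarray as xr

    da = xr.DataArray([1.0, np.nan, 3.0])
    da.mean()              # skips missing values by default -> 2.0
    da.mean(skipna=False)  # propagates them -> nan
    da.isnull()            # elementwise missing-value mask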
diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst
index ca12ae62440..823c50f333b 100644
--- a/doc/getting-started-guide/installing.rst
+++ b/doc/getting-started-guide/installing.rst
@@ -8,8 +8,8 @@ Required dependencies
 - Python (3.9 or later)
 - `numpy `__ (1.23 or later)
-- `packaging `__ (22 or later)
-- `pandas `__ (1.5 or later)
+- `packaging `__ (23.1 or later)
+- `pandas `__ (2.0 or later)
 .. _optional-dependencies:
diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst
index 4352dd3df5b..a979abe34e2 100644
--- a/doc/internals/how-to-add-new-backend.rst
+++ b/doc/internals/how-to-add-new-backend.rst
@@ -4,7 +4,7 @@ How to add a new backend
 ------------------------
 Adding a new backend for read support to Xarray does not require
-to integrate any code in Xarray; all you need to do is:
+one to integrate any code in Xarray; all you need to do is:
 - Create a class that inherits from Xarray :py:class:`~xarray.backends.BackendEntrypoint`
   and implements the method ``open_dataset`` see :ref:`RST backend_entrypoint`
diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst
index f99d41be538..23ecbedf61e 100644
--- a/doc/user-guide/computation.rst
+++ b/doc/user-guide/computation.rst
@@ -426,7 +426,7 @@ However, the functions also take missing values in the data into account:
 .. ipython:: python
-    data = xr.DataArray([np.NaN, 2, 4])
+    data = xr.DataArray([np.nan, 2, 4])
     weights = xr.DataArray([8, 1, 1])
     data.weighted(weights).mean()
@@ -444,7 +444,7 @@ If the weights add up to to 0, ``sum`` returns 0:
     data.weighted(weights).sum()
-and ``mean``, ``std`` and ``var`` return ``NaN``:
+and ``mean``, ``std`` and ``var`` return ``nan``:
 .. ipython:: python
diff --git a/doc/user-guide/dask.rst b/doc/user-guide/dask.rst
index 27e7449b7c3..f71969066f9 100644
--- a/doc/user-guide/dask.rst
+++ b/doc/user-guide/dask.rst
@@ -296,6 +296,12 @@ loaded into Dask or not:
 Automatic parallelization with ``apply_ufunc`` and ``map_blocks``
 -----------------------------------------------------------------
+.. tip::
+
+    Some problems become embarrassingly parallel, and thus easy to parallelize
+    automatically, by rechunking to a frequency, e.g. ``ds.chunk(time=TimeResampler("YE"))``.
+    See :py:meth:`Dataset.chunk` for more.
+
 Almost all of xarray's built-in operations work on Dask arrays. If you want to
 use a function that isn't wrapped by xarray, and have it applied in parallel on
 each block of your xarray object, you have three options:
@@ -551,6 +557,16 @@ larger chunksizes.
 Check out the `dask documentation on chunks `_.
+.. tip::
+
+    Many time domain problems become amenable to an embarrassingly parallel or blockwise solution
+    (e.g. using :py:func:`xarray.map_blocks`, :py:func:`dask.array.map_blocks`, or
+    :py:func:`dask.array.blockwise`) by rechunking to a frequency along the time dimension.
+    Provide :py:class:`xarray.groupers.TimeResampler` objects to :py:meth:`Dataset.chunk` to do so.
+    For example ``ds.chunk(time=TimeResampler("MS"))`` will set the chunks so that a month of
+    data is contained in one chunk. The resulting chunk sizes need not be uniform; they depend
+    on the frequency of the data and the calendar.
+
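+A minimal sketch of this pattern, using synthetic data (the variable names here are
+illustrative and dask must be installed):
+
+.. code-block:: python
+
+    import numpy as np
+    import pandas as pd
+    import xarray as xr
+    from xarray.groupers import TimeResampler
+
+    time = pd.date_range("2000-01-01", periods=365, freq="D")
+    ds = xr.Dataset({"t": ("time", np.random.rand(365))}, coords={"time": time})
+    ds = ds.chunk(time=TimeResampler("MS"))  # one dask chunk per calendar month
+    ds.chunksizes["time"]  # (31, 29, 31, 30, ...)
+
+    # same-shape blockwise work is now embarrassingly parallel per month;
+    # no template is needed since the shape is unchanged
+    anom = xr.map_blocks(lambda block: block - block.mean("time"), ds)
+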
 Optimization Tips
 -----------------
diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst
index 1ad2d52fc00..783b4f13c2e 100644
--- a/doc/user-guide/groupby.rst
+++ b/doc/user-guide/groupby.rst
@@ -1,3 +1,5 @@
+.. currentmodule:: xarray
+
 .. _groupby:
 GroupBy: Group and Bin Data
@@ -15,8 +17,8 @@ __ https://www.jstatsoft.org/v40/i01/paper
 - Apply some function to each group.
 - Combine your groups back into a single data object.
-Group by operations work on both :py:class:`~xarray.Dataset` and
-:py:class:`~xarray.DataArray` objects. Most of the examples focus on grouping by
+Group by operations work on both :py:class:`Dataset` and
+:py:class:`DataArray` objects. Most of the examples focus on grouping by
 a single one-dimensional variable, although support for grouping
 over a multi-dimensional variable has recently been implemented. Note that for
 one-dimensional data, it is usually faster to rely on pandas' implementation of
@@ -24,10 +26,11 @@ the same pipeline.
 .. tip::
-    To substantially improve the performance of GroupBy operations, particularly
-    with dask `install the flox package `_. flox
+    `Install the flox package `_ to substantially improve the performance
+    of GroupBy operations, particularly with dask. flox
     `extends Xarray's in-built GroupBy capabilities `_
-    by allowing grouping by multiple variables, and lazy grouping by dask arrays. If installed, Xarray will automatically use flox by default.
+    by allowing grouping by multiple variables, and lazy grouping by dask arrays.
+    If installed, Xarray will automatically use flox by default.
 Split
 ~~~~~
@@ -87,7 +90,7 @@ Binning
 Sometimes you don't want to use all the unique values to determine the groups
 but instead want to "bin" the data into coarser groups. You could always create
 a customized coordinate, but xarray facilitates this via the
-:py:meth:`~xarray.Dataset.groupby_bins` method.
+:py:meth:`Dataset.groupby_bins` method.
 .. ipython:: python
@@ -110,7 +113,7 @@ Apply
 ~~~~~
 To apply a function to each group, you can use the flexible
-:py:meth:`~xarray.core.groupby.DatasetGroupBy.map` method. The resulting objects are automatically
+:py:meth:`core.groupby.DatasetGroupBy.map` method. The resulting objects are automatically
 concatenated back together along the group axis:
 .. ipython:: python
@@ -121,8 +124,8 @@ concatenated back together along the group axis:
     arr.groupby("letters").map(standardize)
-GroupBy objects also have a :py:meth:`~xarray.core.groupby.DatasetGroupBy.reduce` method and
-methods like :py:meth:`~xarray.core.groupby.DatasetGroupBy.mean` as shortcuts for applying an
+GroupBy objects also have a :py:meth:`core.groupby.DatasetGroupBy.reduce` method and
+methods like :py:meth:`core.groupby.DatasetGroupBy.mean` as shortcuts for applying an
 aggregation function:
 .. ipython:: python
@@ -183,7 +186,7 @@ Iterating and Squeezing
 Previously, Xarray defaulted to squeezing out dimensions of size one when iterating over
 a GroupBy object. This behaviour is being removed.
 You can always squeeze explicitly later with the Dataset or DataArray
-:py:meth:`~xarray.DataArray.squeeze` methods.
+:py:meth:`DataArray.squeeze` methods.
 .. ipython:: python
@@ -217,7 +220,7 @@ __ https://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#_two_dime
     da.groupby("lon").map(lambda x: x - x.mean(), shortcut=False)
 Because multidimensional groups have the ability to generate a very large
-number of bins, coarse-binning via :py:meth:`~xarray.Dataset.groupby_bins`
+number of bins, coarse-binning via :py:meth:`Dataset.groupby_bins`
 may be desirable:
 .. ipython:: python
@@ -232,3 +235,71 @@ applying your function, and then unstacking the result:
     stacked = da.stack(gridcell=["ny", "nx"])
     stacked.groupby("gridcell").sum(...).unstack("gridcell")
+
+.. _groupby.groupers:
+
+Grouper Objects
+~~~~~~~~~~~~~~~
+
+Both ``groupby_bins`` and ``resample`` are specializations of the core ``groupby`` operation for binning
+and time resampling. Many problems demand more complex GroupBy application: for example, grouping by multiple
+variables with a combination of categorical grouping, binning, and resampling; or more specializations like
+spatial resampling; or more complex time grouping like special handling of seasons, or the ability to specify
+custom seasons. To handle these use-cases and more, Xarray is evolving to provide an
+extension point using ``Grouper`` objects.
+
+.. tip::
+
+   See the `grouper design`_ doc for more detail on the motivation and design ideas behind
+   Grouper objects.
+
+.. _grouper design: https://github.com/pydata/xarray/blob/main/design_notes/grouper_objects.md
+
+For now Xarray provides three specialized Grouper objects:
+
+1. :py:class:`groupers.UniqueGrouper` for categorical grouping
+2. :py:class:`groupers.BinGrouper` for binned grouping
+3. :py:class:`groupers.TimeResampler` for resampling along a datetime coordinate
+
+These provide functionality identical to the existing ``groupby``, ``groupby_bins``, and ``resample`` methods.
+That is,
+
+.. code-block:: python
+
+    ds.groupby("x")
+
+is identical to
+
+.. code-block:: python
+
+    from xarray.groupers import UniqueGrouper
+
+    ds.groupby(x=UniqueGrouper())
+
+Similarly,
+
+.. code-block:: python
+
+    ds.groupby_bins("x", bins=bins)
+
+is identical to
+
+.. code-block:: python
+
+    from xarray.groupers import BinGrouper
+
+    ds.groupby(x=BinGrouper(bins))
+
+and
+
+.. code-block:: python
+
+    ds.resample(time="ME")
+
+is identical to
+
+.. code-block:: python
+
+    from xarray.groupers import TimeResampler
+
+    ds.resample(time=TimeResampler("ME"))
diff --git a/doc/user-guide/interpolation.rst b/doc/user-guide/interpolation.rst
index 311e1bf0129..53d8a52b342 100644
--- a/doc/user-guide/interpolation.rst
+++ b/doc/user-guide/interpolation.rst
@@ -292,8 +292,8 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data.
     axes[0].set_title("Raw data")
     # Interpolated data
-    new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.sizes["lon"] * 4)
-    new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.sizes["lat"] * 4)
+    new_lon = np.linspace(ds.lon[0].item(), ds.lon[-1].item(), ds.sizes["lon"] * 4)
+    new_lat = np.linspace(ds.lat[0].item(), ds.lat[-1].item(), ds.sizes["lat"] * 4)
     dsi = ds.interp(lat=new_lat, lon=new_lon)
     dsi.air.plot(ax=axes[1])
     @savefig interpolation_sample3.png width=8in
diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst
index da414bc383e..fabff1000d7 100644
--- a/doc/user-guide/io.rst
+++ b/doc/user-guide/io.rst
@@ -19,6 +19,81 @@ format (recommended).
     np.random.seed(123456)
+You can `read different types of files `_
+in ``xr.open_dataset`` by specifying the engine to be used:
+
+.. ipython:: python
+    :okexcept:
+    :suppress:
+
+    import xarray as xr
+
+    xr.open_dataset("my_file.grib", engine="cfgrib")
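+
+A quick sketch of how to check which engines are importable in your environment
+(the result depends on which optional dependencies you have installed):
+
+.. code-block:: python
+
+    import xarray as xr
+
+    xr.backends.list_engines()  # mapping of engine name -> BackendEntrypoint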
+
+The "engine" provides a set of instructions that tells xarray how
+to read the data and pack them into a ``Dataset`` (or ``DataArray``).
+These instructions are stored in an underlying "backend".
+
+Xarray comes with several backends that cover many common data formats.
+Many more backends are available via external libraries, or you can `write your own `_.
+This diagram aims to help you determine - based on the format of the file you'd like to read -
+which type of backend you're using and how to use it.
+
+Text and boxes are clickable for more information.
+Following the diagram is detailed information on many popular backends.
+You can learn more about using and developing backends in the
+`Xarray tutorial JupyterBook `_.
+
+.. mermaid::
+    :alt: Flowchart illustrating how to choose the right backend engine to read your data
+
+    flowchart LR
+        built-in-eng["""Is your data stored in one of these formats?
+            - netCDF4 (netcdf4)
+            - netCDF3 (scipy)
+            - Zarr (zarr)
+            - DODS/OPeNDAP (pydap)
+            - HDF5 (h5netcdf)
+            """]
+
+        built-in("""You're in luck! Xarray bundles a backend for this format.
+            Open data using xr.open_dataset(). We recommend
+            always setting the engine you want to use.""")
+
+        installed-eng["""One of these formats?
+            - GRIB (cfgrib)
+            - TileDB (tiledb)
+            - GeoTIFF, JPEG-2000, ESRI-hdf (rioxarray, via GDAL)
+            - Sentinel-1 SAFE (xarray-sentinel)
+            """]
+
+        installed("""Install the package indicated in parentheses to your
+            Python environment. Restart the kernel and use
+            xr.open_dataset(files, engine='rioxarray').""")
+
+        other("""Ask around to see if someone in your data community
+            has created an Xarray backend for your data type.
+            If not, you may need to create your own or consider
+            exporting your data to a more common format.""")
+
+        built-in-eng -->|Yes| built-in
+        built-in-eng -->|No| installed-eng
+
+        installed-eng -->|Yes| installed
+        installed-eng -->|No| other
+
+        click built-in-eng "https://docs.xarray.dev/en/stable/getting-started-guide/faq.html#how-do-i-open-format-x-file-as-an-xarray-dataset"
+        click other "https://docs.xarray.dev/en/stable/internals/how-to-add-new-backend.html"
+
+        classDef quesNodefmt fill:#9DEEF4,stroke:#206C89,text-align:left
+        class built-in-eng,installed-eng quesNodefmt
+
+        classDef ansNodefmt fill:#FFAA05,stroke:#E37F17,text-align:left,white-space:nowrap
+        class built-in,installed,other ansNodefmt
+
+        linkStyle default font-size:20pt,color:#206C89
+
 .. _io.netcdf:
 netCDF
@@ -985,6 +1060,59 @@ reads. Because this fall-back option is so much slower, xarray issues a
 instead of falling back to try reading non-consolidated metadata.
+.. _io.kerchunk:
+
+Kerchunk
+--------
+
+`Kerchunk `_ is a Python library
+that allows you to access chunked and compressed data formats (such as NetCDF3, NetCDF4, HDF5, GRIB2, TIFF & FITS),
+many of which are primary data formats for many data archives, by viewing the
+whole archive as an ephemeral `Zarr`_ dataset which allows for parallel, chunk-specific access.
+
+Instead of creating a new copy of the dataset in the Zarr spec/format or
+downloading the files locally, Kerchunk reads through the data archive and extracts the
+byte range and compression information of each chunk and saves it as a ``reference``.
+These references are then saved as ``json`` files or ``parquet`` (more efficient)
+for later use. You can view some of these stored in the `references`
+directory `here `_.
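+
+As a sketch, creating such a reference for a single local netCDF4/HDF5 file might look
+like this (the filenames are illustrative; requires ``kerchunk`` and ``fsspec``):
+
+.. code-block:: python
+
+    import json
+
+    import fsspec
+    from kerchunk.hdf import SingleHdf5ToZarr
+
+    with fsspec.open("air.nc", mode="rb") as f:
+        refs = SingleHdf5ToZarr(f, "air.nc").translate()
+
+    # save the byte-range references for later use
+    with open("single_file.json", mode="w") as out:
+        json.dump(refs, out)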
+
+.. note::
+    These references follow this `specification `_.
+    Packages like `kerchunk`_ and `virtualizarr `_
+    help in creating and reading these references.
+
+
+Reading these data archives becomes really easy with ``kerchunk`` in combination
+with ``xarray``, especially when these archives are large in size. A single combined
+reference can refer to thousands of the original data files present in these archives.
+You can view the whole dataset from this `combined reference` using the above packages.
+
+The following example shows opening a combined reference generated from a ``.hdf`` file stored locally.
+
+.. ipython:: python
+
+    storage_options = {
+        "target_protocol": "file",
+    }
+
+    # add the `remote_protocol` key in `storage_options` if you're accessing a file remotely
+
+    ds1 = xr.open_dataset(
+        "./combined.json",
+        engine="kerchunk",
+        storage_options=storage_options,
+    )
+
+    ds1
+
+.. note::
+
+    You can refer to the `project pythia kerchunk cookbook `_
+    and the `pangeo guide on kerchunk `_ for more information.
+
+
 .. _io.iris:
 Iris
diff --git a/doc/user-guide/terminology.rst b/doc/user-guide/terminology.rst
index 55937310827..140f3dd1faa 100644
--- a/doc/user-guide/terminology.rst
+++ b/doc/user-guide/terminology.rst
@@ -221,7 +221,7 @@ complete examples, please consult the relevant documentation.*
         combined_ds
     lazy
-        Lazily-evaluated operations do not load data into memory until necessary.Instead of doing calculations
+        Lazily-evaluated operations do not load data into memory until necessary. Instead of doing calculations
         right away, xarray lets you plan what calculations you want to do, like finding the
         average temperature in a dataset.This planning is called "lazy evaluation." Later, when you're ready
         to see the final result, you tell xarray, "Okay, go ahead and do those calculations now!"
diff --git a/doc/user-guide/testing.rst b/doc/user-guide/testing.rst
index 13279eccb0b..d82d9d7d7d9 100644
--- a/doc/user-guide/testing.rst
+++ b/doc/user-guide/testing.rst
@@ -239,9 +239,10 @@ If the array type you want to generate has an array API-compliant top-level nam
 you can use this neat trick:
 .. ipython:: python
-    :okwarning:
-    from numpy import array_api as xp  # available in numpy 1.26.0
+    import numpy as xp  # compatible in numpy 2.0
+
+    # use `import numpy.array_api as xp` in numpy>=1.23,<2.0
     from hypothesis.extra.array_api import make_strategies_namespace
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 0c401c2348e..4cd34c4cf54 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -15,17 +15,14 @@ What's New
     np.random.seed(123456)
-.. _whats-new.2024.06.1:
+.. _whats-new.2024.07.1:
-v2024.06.1 (unreleased)
+v2024.07.1 (unreleased)
 -----------------------
 New Features
 ~~~~~~~~~~~~
-- Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`).
-  By `Martin Raspaud `_.
-- Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`).
-  By `Justus Magin `_.
+
 Breaking changes
 ~~~~~~~~~~~~~~~~
 Deprecations
 ~~~~~~~~~~~~
+Bug fixes
+~~~~~~~~~
+
+- Fix bug causing `DataTree.from_dict` to be sensitive to insertion order (:issue:`9276`, :pull:`9292`).
+  By `Tom Nicholas `_.
+- Fix resampling error with monthly, quarterly, or yearly frequencies with
+  cftime when the time bins straddle the date "0001-01-01". For example, this
+  can happen in certain circumstances when the time coordinate contains the
+  date "0001-01-01". (:issue:`9108`, :pull:`9116`) By `Spencer Clark
+  `_ and `Deepak Cherian
+  `_.
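+  A rough reproducer for the fixed case (the dates, calendar, and frequency here are
+  illustrative):
+
+  .. code-block:: python
+
+      import xarray as xr
+
+      time = xr.cftime_range("0001-01-01", periods=5, freq="D", calendar="noleap")
+      da = xr.DataArray(range(5), dims="time", coords={"time": time})
+      da.resample(time="MS").mean()  # previously could raise for bins near year 1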
+
+Documentation
+~~~~~~~~~~~~~
+
+
+Internal Changes
+~~~~~~~~~~~~~~~~
+
+
+
+.. _whats-new.2024.07.0:
+
+v2024.07.0 (Jul 30, 2024)
+-------------------------
+This release extends the API for groupby operations with various `grouper objects `, and includes improvements to the documentation and numerous bugfixes.
+
+Thanks to the 22 contributors to this release:
+Alfonso Ladino, ChrisCleaner, David Hoese, Deepak Cherian, Dieter Werthmüller, Illviljan, Jessica Scheick, Joel Jaeschke, Justus Magin, K. Arthur Endsley, Kai Mühlbauer, Mark Harfouche, Martin Raspaud, Mathijs Verhaegh, Maximilian Roos, Michael Niklas, Michał Górny, Moritz Schreiber, Pontus Lurcock, Spencer Clark, Stephan Hoyer and Tom Nicholas
+
+New Features
+~~~~~~~~~~~~
+- Use fastpath when grouping both monotonically increasing and decreasing variables
+  in :py:class:`GroupBy` (:issue:`6220`, :pull:`7427`).
+  By `Joel Jaeschke `_.
+- Introduce new :py:class:`groupers.UniqueGrouper`, :py:class:`groupers.BinGrouper`, and
+  :py:class:`groupers.TimeResampler` objects as a step towards supporting grouping by
+  multiple variables. See the `docs ` and the `grouper design doc
+  `_ for more.
+  (:issue:`6610`, :pull:`8840`).
+  By `Deepak Cherian `_.
+- Allow rechunking to a frequency using ``Dataset.chunk(time=TimeResampler("YE"))`` syntax. (:issue:`7559`, :pull:`9109`)
+  Such rechunking allows many time domain analyses to be executed in an embarrassingly parallel fashion.
+  By `Deepak Cherian `_.
+- Allow per-variable specification of ``mask_and_scale``, ``decode_times``, ``decode_timedelta``,
+  ``use_cftime`` and ``concat_characters`` params in :py:func:`~xarray.open_dataset` (:pull:`9218`).
+  By `Mathijs Verhaegh `_.
+- Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`).
+  By `Martin Raspaud `_.
+- Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`).
+  By `Justus Magin `_.
+- Add :py:meth:`DataArray.drop_attrs` & :py:meth:`Dataset.drop_attrs` methods,
+  to return an object without ``attrs``. A ``deep`` parameter controls whether
+  variables' ``attrs`` are also dropped.
+  By `Maximilian Roos `_. (:pull:`8288`)
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+- The ``base`` and ``loffset`` parameters to :py:meth:`Dataset.resample` and
+  :py:meth:`DataArray.resample` are now removed. These parameters have been deprecated since
+  v2023.03.0. Using the ``origin`` or ``offset`` parameters is recommended as a replacement for
+  using the ``base`` parameter and using time offset arithmetic is recommended as a replacement for
+  using the ``loffset`` parameter. (:pull:`9233`)
+  By `Deepak Cherian `_.
+- The ``squeeze`` kwarg to ``groupby`` is now ignored. This has been the source of some
+  quite confusing behaviour and has been deprecated since v2024.01.0. ``groupby`` behavior is now
+  always consistent with the existing ``.groupby(..., squeeze=False)`` behavior. No errors will
+  be raised if ``squeeze=False``. (:pull:`9280`)
+  By `Deepak Cherian `_.
+
+
 Bug fixes
 ~~~~~~~~~
 - Fix scatter plot broadcasting unneccesarily. (:issue:`9129`, :pull:`9206`)
@@ -45,28 +113,39 @@ Bug fixes
   By `Pontus Lurcock `_.
 - Allow diffing objects with array attributes on variables (:issue:`9153`, :pull:`9169`).
   By `Justus Magin `_.
+- ``numpy>=2`` compatibility in the ``netcdf4`` backend (:pull:`9136`).
+  By `Justus Magin `_ and `Kai Mühlbauer `_.
 - Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`).
   By `Justus Magin `_.
-- Fiy static typing of tolerance arguments by allowing `str` type (:issue:`8892`, :pull:`9194`).
+- Address regression introduced in :pull:`9002` that prevented objects returned
+  by :py:meth:`DataArray.convert_calendar` from being indexed by a time index in
+  certain circumstances (:issue:`9138`, :pull:`9192`).
+  By `Mark Harfouche `_ and `Spencer Clark `_.
+- Fix static typing of tolerance arguments by allowing `str` type (:issue:`8892`, :pull:`9194`).
   By `Michael Niklas `_.
 - Dark themes are now properly detected for ``html[data-theme=dark]``-tags (:pull:`9200`).
   By `Dieter Werthmüller `_.
 - Reductions no longer fail for ``np.complex_`` dtype arrays when numbagg is
-  installed.
-  By `Maximilian Roos `_
+  installed. (:pull:`9210`)
+  By `Maximilian Roos `_.
 Documentation
 ~~~~~~~~~~~~~
-- Adds a flow-chart diagram to help users navigate help resources (`Discussion #8990 `_).
+- Adds intro to backend section of docs, including a flow-chart to navigate types of backends (:pull:`9175`).
+  By `Jessica Scheick `_.
+- Adds a flow-chart diagram to help users navigate help resources (:discussion:`8990`, :pull:`9147`).
   By `Jessica Scheick `_.
 - Improvements to Zarr & chunking docs (:pull:`9139`, :pull:`9140`, :pull:`9132`)
-  By `Maximilian Roos `_
-
+  By `Maximilian Roos `_.
+- Fix copybutton for multi line examples and double digit ipython cell numbers (:pull:`9264`).
+  By `Moritz Schreiber `_.
 Internal Changes
 ~~~~~~~~~~~~~~~~
+- Enable typing checks of pandas (:pull:`9213`).
+  By `Michael Niklas `_.
 .. _whats-new.2024.06.0:
diff --git a/properties/__init__.py b/properties/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/pyproject.toml b/pyproject.toml
index 2ada0c1c171..3eafcda7670 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,7 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
 [tool.mypy]
 enable_error_code = "redundant-self"
 exclude = [
+  'build',
  'xarray/util/generate_.*\.py',
  'xarray/datatree_/doc/.*\.py',
 ]
@@ -119,7 +120,6 @@ module = [
  "netCDF4.*",
  "netcdftime.*",
  "opt_einsum.*",
- "pandas.*",
  "pint.*",
  "pooch.*",
  "pyarrow.*",
diff --git a/xarray/__init__.py b/xarray/__init__.py
index 0c0d5995f72..dd6e1cda1a9 100644
--- a/xarray/__init__.py
+++ b/xarray/__init__.py
@@ -1,6 +1,6 @@
 from importlib.metadata import version as _version
-from xarray import testing, tutorial
+from xarray import groupers, testing, tutorial
 from xarray.backends.api import (
     load_dataarray,
     load_dataset,
@@ -55,6 +55,7 @@
 # `mypy --strict` running in projects that import xarray.
__all__ = ( # Sub-packages + "groupers", "testing", "tutorial", # Top-level functions diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 521bdf65e6a..ece60a2b161 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -398,11 +398,11 @@ def open_dataset( chunks: T_Chunks = None, cache: bool | None = None, decode_cf: bool | None = None, - mask_and_scale: bool | None = None, - decode_times: bool | None = None, - decode_timedelta: bool | None = None, - use_cftime: bool | None = None, - concat_characters: bool | None = None, + mask_and_scale: bool | Mapping[str, bool] | None = None, + decode_times: bool | Mapping[str, bool] | None = None, + decode_timedelta: bool | Mapping[str, bool] | None = None, + use_cftime: bool | Mapping[str, bool] | None = None, + concat_characters: bool | Mapping[str, bool] | None = None, decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, @@ -451,25 +451,31 @@ def open_dataset( decode_cf : bool, optional Whether to decode these variables, assuming they were saved according to CF conventions. - mask_and_scale : bool, optional + mask_and_scale : bool or dict-like, optional If True, replace array values equal to `_FillValue` with NA and scale values according to the formula `original_values * scale_factor + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. This keyword may not be supported by all the backends. - decode_times : bool, optional + be replaced by NA. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + decode_times : bool or dict-like, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. This keyword may not be supported by all the backends. - decode_timedelta : bool, optional + decode_timedelta : bool or dict-like, optional If True, decode variables and coordinates with time units in {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} into timedelta objects. If False, leave them encoded as numbers. If None (default), assume the same value of decode_time. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. This keyword may not be supported by all the backends. - use_cftime: bool, optional + use_cftime: bool or dict-like, optional Only relevant if encoded dates come from a standard calendar (e.g. "gregorian", "proleptic_gregorian", "standard", or not specified). If None (default), attempt to decode times to @@ -478,12 +484,16 @@ def open_dataset( ``cftime.datetime`` objects, regardless of whether or not they can be represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible - raise an error. This keyword may not be supported by all the backends. - concat_characters : bool, optional + raise an error. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. 
+    concat_characters : bool or dict-like, optional
         If True, concatenate along the last dimension of character arrays to
         form string arrays. Dimensions will only be concatenated over (and
         removed) if they have no corresponding variable and if they are only
         used as the last dimension of character arrays.
+        Pass a mapping, e.g. ``{"my_variable": False}``,
+        to toggle this feature per-variable individually.
         This keyword may not be supported by all the backends.
     decode_coords : bool or {"coordinates", "all"}, optional
         Controls which variables are set as coordinate variables:
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index cd6bde45caa..3926ac051f5 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -438,7 +438,14 @@ def open_datatree(
         drop_variables: str | Iterable[str] | None = None,
         use_cftime=None,
         decode_timedelta=None,
+        format=None,
         group: str | Iterable[str] | Callable | None = None,
+        lock=None,
+        invalid_netcdf=None,
+        phony_dims=None,
+        decode_vlen_strings=True,
+        driver=None,
+        driver_kwds=None,
         **kwargs,
     ) -> DataTree:
         from xarray.backends.api import open_dataset
@@ -450,7 +457,14 @@ def open_datatree(
         filename_or_obj = _normalize_path(filename_or_obj)
         store = H5NetCDFStore.open(
             filename_or_obj,
+            format=format,
             group=group,
+            lock=lock,
+            invalid_netcdf=invalid_netcdf,
+            phony_dims=phony_dims,
+            decode_vlen_strings=decode_vlen_strings,
+            driver=driver,
+            driver_kwds=driver_kwds,
         )
         if group:
             parent = NodePath("/") / NodePath(group)
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index f8dd1c96572..302c002a779 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -680,6 +680,12 @@ def open_datatree(
         use_cftime=None,
         decode_timedelta=None,
         group: str | Iterable[str] | Callable | None = None,
+        format="NETCDF4",
+        clobber=True,
+        diskless=False,
+        persist=False,
+        lock=None,
+        autoclose=False,
         **kwargs,
     ) -> DataTree:
         from xarray.backends.api import open_dataset
@@ -691,6 +697,12 @@ def open_datatree(
         store = NetCDF4DataStore.open(
             filename_or_obj,
             group=group,
+            format=format,
+            clobber=clobber,
+            diskless=diskless,
+            persist=persist,
+            lock=lock,
+            autoclose=autoclose,
         )
         if group:
             parent = NodePath("/") / NodePath(group)
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index a62ca6c9862..f4890015040 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -204,6 +204,9 @@ def get_backend(engine: str | type[BackendEntrypoint]) -> BackendEntrypoint:
         if engine not in engines:
             raise ValueError(
                 f"unrecognized engine {engine} must be one of: {list(engines)}"
+                "\nTo install additional dependencies, see:\n"
+                "https://docs.xarray.dev/en/stable/user-guide/io.html \n"
+                "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html"
             )
         backend = engines[engine]
     elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint):
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 85a1a6e214c..8c526ddb58d 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -806,7 +806,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
             for k2, v2 in attrs.items():
                 encoded_attrs[k2] = self.encode_attribute(v2)
-            if coding.strings.check_vlen_dtype(dtype) == str:
+            if coding.strings.check_vlen_dtype(dtype) is str:
                 dtype = str
             if self._write_empty is not None:
diff --git a/xarray/coding/calendar_ops.py b/xarray/coding/calendar_ops.py
index c4fe9e1f4ae..6f492e78bf9 100644
--- 
a/xarray/coding/calendar_ops.py +++ b/xarray/coding/calendar_ops.py @@ -5,7 +5,10 @@ from xarray.coding.cftime_offsets import date_range_like, get_date_type from xarray.coding.cftimeindex import CFTimeIndex -from xarray.coding.times import _should_cftime_be_used, convert_times +from xarray.coding.times import ( + _should_cftime_be_used, + convert_times, +) from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like try: @@ -222,6 +225,13 @@ def convert_calendar( # Remove NaN that where put on invalid dates in target calendar out = out.where(out[dim].notnull(), drop=True) + if use_cftime: + # Reassign times to ensure time index of output is a CFTimeIndex + # (previously it was an Index due to the presence of NaN values). + # Note this is not needed in the case that the output time index is + # a DatetimeIndex, since DatetimeIndexes can handle NaN values. + out[dim] = CFTimeIndex(out[dim].data) + if missing is not None: time_target = date_range_like(time, calendar=calendar, use_cftime=use_cftime) out = out.reindex({dim: time_target}, fill_value=missing) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index c2712569782..b0caf2a0cd3 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -43,9 +43,11 @@ from __future__ import annotations import re +import warnings +from collections.abc import Mapping from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, ClassVar, Literal import numpy as np import pandas as pd @@ -74,7 +76,10 @@ if TYPE_CHECKING: - from xarray.core.types import InclusiveOptions, SideOptions + from xarray.core.types import InclusiveOptions, Self, SideOptions, TypeAlias + + +DayOption: TypeAlias = Literal["start", "end"] def _nanosecond_precision_timestamp(*args, **kwargs): @@ -109,9 +114,10 @@ def get_date_type(calendar, use_cftime=True): class BaseCFTimeOffset: _freq: ClassVar[str | None] = None - _day_option: ClassVar[str | None] = None + _day_option: ClassVar[DayOption | None] = None + n: int - def __init__(self, n: int = 1): + def __init__(self, n: int = 1) -> None: if not isinstance(n, int): raise TypeError( "The provided multiple 'n' must be an integer. 
" @@ -119,13 +125,15 @@ def __init__(self, n: int = 1): ) self.n = n - def rule_code(self): + def rule_code(self) -> str | None: return self._freq - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, BaseCFTimeOffset): + return NotImplemented return self.n == other.n and self.rule_code() == other.rule_code() - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other def __add__(self, other): @@ -142,12 +150,12 @@ def __sub__(self, other): else: return NotImplemented - def __mul__(self, other): + def __mul__(self, other: int) -> Self: if not isinstance(other, int): return NotImplemented return type(self)(n=other * self.n) - def __neg__(self): + def __neg__(self) -> Self: return self * -1 def __rmul__(self, other): @@ -161,10 +169,10 @@ def __rsub__(self, other): raise TypeError("Cannot subtract cftime offsets of differing types") return -self + other - def __apply__(self): + def __apply__(self, other): return NotImplemented - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" test_date = (self + date) - self @@ -197,22 +205,21 @@ def _get_offset_day(self, other): class Tick(BaseCFTimeOffset): # analogous https://github.com/pandas-dev/pandas/blob/ccb25ab1d24c4fb9691270706a59c8d319750870/pandas/_libs/tslibs/offsets.pyx#L806 - def _next_higher_resolution(self): + def _next_higher_resolution(self) -> Tick: self_type = type(self) - if self_type not in [Day, Hour, Minute, Second, Millisecond]: - raise ValueError("Could not convert to integer offset at any resolution") - if type(self) is Day: + if self_type is Day: return Hour(self.n * 24) - if type(self) is Hour: + if self_type is Hour: return Minute(self.n * 60) - if type(self) is Minute: + if self_type is Minute: return Second(self.n * 60) - if type(self) is Second: + if self_type is Second: return Millisecond(self.n * 1000) - if type(self) is Millisecond: + if self_type is Millisecond: return Microsecond(self.n * 1000) + raise ValueError("Could not convert to integer offset at any resolution") - def __mul__(self, other): + def __mul__(self, other: int | float) -> Tick: if not isinstance(other, (int, float)): return NotImplemented if isinstance(other, float): @@ -227,12 +234,12 @@ def __mul__(self, other): return new_self * other return type(self)(n=other * self.n) - def as_timedelta(self): + def as_timedelta(self) -> timedelta: """All Tick subclasses must implement an as_timedelta method.""" raise NotImplementedError -def _get_day_of_month(other, day_option): +def _get_day_of_month(other, day_option: DayOption) -> int: """Find the day in `other`'s month that satisfies a BaseCFTimeOffset's onOffset policy, as described by the `day_option` argument. 
@@ -252,22 +259,12 @@ def _get_day_of_month(other, day_option): if day_option == "start": return 1 elif day_option == "end": - return _days_in_month(other) + return other.daysinmonth elif day_option is None: # Note: unlike `_shift_month`, _get_day_of_month does not # allow day_option = None raise NotImplementedError() - else: - raise ValueError(day_option) - - -def _days_in_month(date): - """The number of days in the month of the given date""" - if date.month == 12: - reference = type(date)(date.year + 1, 1, 1) - else: - reference = type(date)(date.year, date.month + 1, 1) - return (reference - timedelta(days=1)).day + raise ValueError(day_option) def _adjust_n_months(other_day, n, reference_day): @@ -293,30 +290,44 @@ def _adjust_n_years(other, n, month, reference_day): return n -def _shift_month(date, months, day_option="start"): +def _shift_month(date, months, day_option: DayOption = "start"): """Shift the date to a month start or end a given number of months away.""" if cftime is None: raise ModuleNotFoundError("No module named 'cftime'") + has_year_zero = date.has_year_zero delta_year = (date.month + months) // 12 month = (date.month + months) % 12 if month == 0: month = 12 delta_year = delta_year - 1 + + if not has_year_zero: + if date.year < 0 and date.year + delta_year >= 0: + delta_year = delta_year + 1 + elif date.year > 0 and date.year + delta_year <= 0: + delta_year = delta_year - 1 + year = date.year + delta_year - if day_option == "start": - day = 1 - elif day_option == "end": - reference = type(date)(year, month, 1) - day = _days_in_month(reference) - else: - raise ValueError(day_option) - return date.replace(year=year, month=month, day=day) + # Silence warnings associated with generating dates with years < 1. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="this date/calendar/year zero") + + if day_option == "start": + day = 1 + elif day_option == "end": + reference = type(date)(year, month, 1, has_year_zero=has_year_zero) + day = reference.daysinmonth + else: + raise ValueError(day_option) + return date.replace(year=year, month=month, day=day) -def roll_qtrday(other, n, month, day_option, modby=3): +def roll_qtrday( + other, n: int, month: int, day_option: DayOption, modby: int = 3 +) -> int: """Possibly increment or decrement the number of periods to shift based on rollforward/rollbackward conventions. 
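# A quick sketch (illustrative, not part of the patch; assumes cftime >= 1.5 for
# the `has_year_zero` flag) of why `_shift_month` now adjusts `delta_year`:
# shifting a month across 1 BC in a calendar without a year zero must skip the
# nonexistent year 0.
#
#     import cftime
#     from xarray.coding.cftime_offsets import MonthBegin
#
#     date = cftime.DatetimeProlepticGregorian(-1, 12, 15, has_year_zero=False)
#     MonthBegin(1) + date  # -> 0001-01-01 00:00:00, not a date in year 0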
@@ -357,7 +368,7 @@ def roll_qtrday(other, n, month, day_option, modby=3): return n -def _validate_month(month, default_month): +def _validate_month(month: int | None, default_month: int) -> int: result_month = default_month if month is None else month if not isinstance(result_month, int): raise TypeError( @@ -381,7 +392,7 @@ def __apply__(self, other): n = _adjust_n_months(other.day, self.n, 1) return _shift_month(other, n, "start") - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == 1 @@ -391,13 +402,13 @@ class MonthEnd(BaseCFTimeOffset): _freq = "ME" def __apply__(self, other): - n = _adjust_n_months(other.day, self.n, _days_in_month(other)) + n = _adjust_n_months(other.day, self.n, other.daysinmonth) return _shift_month(other, n, "end") - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" - return date.day == _days_in_month(date) + return date.day == date.daysinmonth _MONTH_ABBREVIATIONS = { @@ -419,10 +430,10 @@ def onOffset(self, date): class QuarterOffset(BaseCFTimeOffset): """Quarter representation copied off of pandas/tseries/offsets.py""" - _freq: ClassVar[str] _default_month: ClassVar[int] + month: int - def __init__(self, n=1, month=None): + def __init__(self, n: int = 1, month: int | None = None) -> None: BaseCFTimeOffset.__init__(self, n) self.month = _validate_month(month, self._default_month) @@ -439,29 +450,28 @@ def __apply__(self, other): months = qtrs * 3 - months_since return _shift_month(other, months, self._day_option) - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" mod_month = (date.month - self.month) % 3 return mod_month == 0 and date.day == self._get_offset_day(date) - def __sub__(self, other): + def __sub__(self, other: Self) -> Self: if cftime is None: raise ModuleNotFoundError("No module named 'cftime'") if isinstance(other, cftime.datetime): raise TypeError("Cannot subtract cftime.datetime from offset.") - elif type(other) == type(self) and other.month == self.month: + if type(other) == type(self) and other.month == self.month: return type(self)(self.n - other.n, month=self.month) - else: - return NotImplemented + return NotImplemented def __mul__(self, other): if isinstance(other, float): return NotImplemented return type(self)(n=other * self.n, month=self.month) - def rule_code(self): + def rule_code(self) -> str: return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}" def __str__(self): @@ -519,11 +529,10 @@ def rollback(self, date): class YearOffset(BaseCFTimeOffset): - _freq: ClassVar[str] - _day_option: ClassVar[str] _default_month: ClassVar[int] + month: int - def __init__(self, n=1, month=None): + def __init__(self, n: int = 1, month: int | None = None) -> None: BaseCFTimeOffset.__init__(self, n) self.month = _validate_month(month, self._default_month) @@ -549,10 +558,10 @@ def __mul__(self, other): return NotImplemented return type(self)(n=other * self.n, month=self.month) - def rule_code(self): + def rule_code(self) -> str: return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}" - def __str__(self): + def __str__(self) -> str: return f"<{type(self).__name__}: n={self.n}, month={self.month}>" @@ -561,7 +570,7 @@ class 
YearBegin(YearOffset): _day_option = "start" _default_month = 1 - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == 1 and date.month == self.month @@ -586,10 +595,10 @@ class YearEnd(YearOffset): _day_option = "end" _default_month = 12 - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" - return date.day == _days_in_month(date) and date.month == self.month + return date.day == date.daysinmonth and date.month == self.month def rollforward(self, date): """Roll date forward to nearest end of year""" @@ -609,7 +618,7 @@ def rollback(self, date): class Day(Tick): _freq = "D" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(days=self.n) def __apply__(self, other): @@ -619,7 +628,7 @@ def __apply__(self, other): class Hour(Tick): _freq = "h" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(hours=self.n) def __apply__(self, other): @@ -629,7 +638,7 @@ def __apply__(self, other): class Minute(Tick): _freq = "min" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(minutes=self.n) def __apply__(self, other): @@ -639,7 +648,7 @@ def __apply__(self, other): class Second(Tick): _freq = "s" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(seconds=self.n) def __apply__(self, other): @@ -649,7 +658,7 @@ def __apply__(self, other): class Millisecond(Tick): _freq = "ms" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(milliseconds=self.n) def __apply__(self, other): @@ -659,30 +668,32 @@ def __apply__(self, other): class Microsecond(Tick): _freq = "us" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(microseconds=self.n) def __apply__(self, other): return other + self.as_timedelta() -def _generate_anchored_offsets(base_freq, offset): - offsets = {} +def _generate_anchored_offsets( + base_freq: str, offset: type[YearOffset | QuarterOffset] +) -> dict[str, type[BaseCFTimeOffset]]: + offsets: dict[str, type[BaseCFTimeOffset]] = {} for month, abbreviation in _MONTH_ABBREVIATIONS.items(): anchored_freq = f"{base_freq}-{abbreviation}" - offsets[anchored_freq] = partial(offset, month=month) + offsets[anchored_freq] = partial(offset, month=month) # type: ignore[assignment] return offsets -_FREQUENCIES = { +_FREQUENCIES: Mapping[str, type[BaseCFTimeOffset]] = { "A": YearEnd, "AS": YearBegin, "Y": YearEnd, "YE": YearEnd, "YS": YearBegin, - "Q": partial(QuarterEnd, month=12), - "QE": partial(QuarterEnd, month=12), - "QS": partial(QuarterBegin, month=1), + "Q": partial(QuarterEnd, month=12), # type: ignore[dict-item] + "QE": partial(QuarterEnd, month=12), # type: ignore[dict-item] + "QS": partial(QuarterBegin, month=1), # type: ignore[dict-item] "M": MonthEnd, "ME": MonthEnd, "MS": MonthBegin, @@ -717,7 +728,9 @@ def _generate_anchored_offsets(base_freq, offset): CFTIME_TICKS = (Day, Hour, Minute, Second) -def _generate_anchored_deprecated_frequencies(deprecated, recommended): +def _generate_anchored_deprecated_frequencies( + deprecated: str, recommended: str +) -> dict[str, str]: pairs = {} for abbreviation in _MONTH_ABBREVIATIONS.values(): anchored_deprecated = f"{deprecated}-{abbreviation}" @@ -726,7 +739,7 @@ def 
_generate_anchored_deprecated_frequencies(deprecated, recommended): return pairs -_DEPRECATED_FREQUENICES = { +_DEPRECATED_FREQUENICES: dict[str, str] = { "A": "YE", "Y": "YE", "AS": "YS", @@ -759,16 +772,16 @@ def _emit_freq_deprecation_warning(deprecated_freq): emit_user_level_warning(message, FutureWarning) -def to_offset(freq, warn=True): +def to_offset(freq: BaseCFTimeOffset | str, warn: bool = True) -> BaseCFTimeOffset: """Convert a frequency string to the appropriate subclass of BaseCFTimeOffset.""" if isinstance(freq, BaseCFTimeOffset): return freq - else: - try: - freq_data = re.match(_PATTERN, freq).groupdict() - except AttributeError: - raise ValueError("Invalid frequency string provided") + + match = re.match(_PATTERN, freq) + if match is None: + raise ValueError("Invalid frequency string provided") + freq_data = match.groupdict() freq = freq_data["freq"] if warn and freq in _DEPRECATED_FREQUENICES: @@ -909,7 +922,9 @@ def _translate_closed_to_inclusive(closed): return inclusive -def _infer_inclusive(closed, inclusive): +def _infer_inclusive( + closed: NoDefault | SideOptions, inclusive: InclusiveOptions | None +) -> InclusiveOptions: """Follows code added in pandas #43504.""" if closed is not no_default and inclusive is not None: raise ValueError( @@ -917,9 +932,9 @@ def _infer_inclusive(closed, inclusive): "passed if argument `inclusive` is not None." ) if closed is not no_default: - inclusive = _translate_closed_to_inclusive(closed) - elif inclusive is None: - inclusive = "both" + return _translate_closed_to_inclusive(closed) + if inclusive is None: + return "both" return inclusive @@ -933,7 +948,7 @@ def cftime_range( closed: NoDefault | SideOptions = no_default, inclusive: None | InclusiveOptions = None, calendar="standard", -): +) -> CFTimeIndex: """Return a fixed frequency CFTimeIndex. Parameters diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 6898809e3b0..ef01f4cc79a 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -45,6 +45,7 @@ import re import warnings from datetime import timedelta +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -64,6 +65,10 @@ except ImportError: cftime = None +if TYPE_CHECKING: + from xarray.coding.cftime_offsets import BaseCFTimeOffset + from xarray.core.types import Self + # constants for cftimeindex.repr CFTIME_REPR_LENGTH = 19 @@ -495,24 +500,28 @@ def get_value(self, series, key): else: return series.iloc[self.get_loc(key)] - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: """Adapted from pandas.tseries.base.DatetimeIndexOpsMixin.__contains__""" try: result = self.get_loc(key) return ( is_scalar(result) - or type(result) == slice - or (isinstance(result, np.ndarray) and result.size) + or isinstance(result, slice) + or (isinstance(result, np.ndarray) and result.size > 0) ) except (KeyError, TypeError, ValueError): return False - def contains(self, key): + def contains(self, key: Any) -> bool: """Needed for .loc based partial-string indexing""" return self.__contains__(key) - def shift(self, n: int | float, freq: str | timedelta): + def shift( # type: ignore[override] # freq is typed Any, we are more precise + self, + periods: int | float, + freq: str | timedelta | BaseCFTimeOffset | None = None, + ) -> Self: """Shift the CFTimeIndex a multiple of the given frequency. 
See the documentation for :py:func:`~xarray.cftime_range` for a @@ -520,9 +529,9 @@ def shift(self, n: int | float, freq: str | timedelta): Parameters ---------- - n : int, float if freq of days or below + periods : int, float if freq of days or below Periods to shift by - freq : str or datetime.timedelta + freq : str, datetime.timedelta or BaseCFTimeOffset A frequency string or datetime.timedelta object to shift by Returns @@ -546,33 +555,42 @@ def shift(self, n: int | float, freq: str | timedelta): CFTimeIndex([2000-02-01 12:00:00], dtype='object', length=1, calendar='standard', freq=None) """ - if isinstance(freq, timedelta): - return self + n * freq - elif isinstance(freq, str): - from xarray.coding.cftime_offsets import to_offset + from xarray.coding.cftime_offsets import BaseCFTimeOffset - return self + n * to_offset(freq) - else: + if freq is None: + # None type is required to be compatible with base pd.Index class raise TypeError( - f"'freq' must be of type str or datetime.timedelta, got {freq}." + f"`freq` argument cannot be None for {type(self).__name__}.shift" ) - def __add__(self, other): + if isinstance(freq, timedelta): + return self + periods * freq + + if isinstance(freq, (str, BaseCFTimeOffset)): + from xarray.coding.cftime_offsets import to_offset + + return self + periods * to_offset(freq) + + raise TypeError( + f"'freq' must be of type str or datetime.timedelta, got {type(freq)}." + ) + + def __add__(self, other) -> Self: if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() - return CFTimeIndex(np.array(self) + other) + return type(self)(np.array(self) + other) - def __radd__(self, other): + def __radd__(self, other) -> Self: if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() - return CFTimeIndex(other + np.array(self)) + return type(self)(other + np.array(self)) def __sub__(self, other): if _contains_datetime_timedeltas(other): - return CFTimeIndex(np.array(self) - other) - elif isinstance(other, pd.TimedeltaIndex): - return CFTimeIndex(np.array(self) - other.to_pytimedelta()) - elif _contains_cftime_datetimes(np.array(other)): + return type(self)(np.array(self) - other) + if isinstance(other, pd.TimedeltaIndex): + return type(self)(np.array(self) - other.to_pytimedelta()) + if _contains_cftime_datetimes(np.array(other)): try: return pd.TimedeltaIndex(np.array(self) - np.array(other)) except OUT_OF_BOUNDS_TIMEDELTA_ERRORS: @@ -580,8 +598,7 @@ def __sub__(self, other): "The time difference exceeds the range of values " "that can be expressed at the nanosecond resolution." 
) - else: - return NotImplemented + return NotImplemented def __rsub__(self, other): try: diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index db95286f6aa..d16ec52d645 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -39,11 +39,11 @@ def check_vlen_dtype(dtype): def is_unicode_dtype(dtype): - return dtype.kind == "U" or check_vlen_dtype(dtype) == str + return dtype.kind == "U" or check_vlen_dtype(dtype) is str def is_bytes_dtype(dtype): - return dtype.kind == "S" or check_vlen_dtype(dtype) == bytes + return dtype.kind == "S" or check_vlen_dtype(dtype) is bytes class EncodedStringCoder(VariableCoder): @@ -104,7 +104,7 @@ def encode_string_array(string_array, encoding="utf-8"): def ensure_fixed_length_bytes(var: Variable) -> Variable: """Ensure that a variable with vlen bytes is converted to fixed width.""" - if check_vlen_dtype(var.dtype) == bytes: + if check_vlen_dtype(var.dtype) is bytes: dims, data, attrs, encoding = unpack_for_encoding(var) # TODO: figure out how to handle this with dask data = np.asarray(data, dtype=np.bytes_) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 34d4f9a23ad..badb9259b06 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -5,7 +5,7 @@ from collections.abc import Hashable from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, Callable, Union +from typing import Callable, Literal, Union, cast import numpy as np import pandas as pd @@ -22,7 +22,7 @@ ) from xarray.core import indexing from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like -from xarray.core.duck_array_ops import asarray +from xarray.core.duck_array_ops import asarray, ravel, reshape from xarray.core.formatting import first_n_items, format_timestamp, last_item from xarray.core.pdcompat import nanosecond_precision_timestamp from xarray.core.utils import emit_user_level_warning @@ -36,10 +36,9 @@ except ImportError: cftime = None -if TYPE_CHECKING: - from xarray.core.types import CFCalendar, T_DuckArray +from xarray.core.types import CFCalendar, NPDatetimeUnitOptions, T_DuckArray - T_Name = Union[Hashable, None] +T_Name = Union[Hashable, None] # standard calendars recognized by cftime _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} @@ -111,22 +110,25 @@ def _is_numpy_compatible_time_range(times): return True -def _netcdf_to_numpy_timeunit(units: str) -> str: +def _netcdf_to_numpy_timeunit(units: str) -> NPDatetimeUnitOptions: units = units.lower() if not units.endswith("s"): units = f"{units}s" - return { - "nanoseconds": "ns", - "microseconds": "us", - "milliseconds": "ms", - "seconds": "s", - "minutes": "m", - "hours": "h", - "days": "D", - }[units] + return cast( + NPDatetimeUnitOptions, + { + "nanoseconds": "ns", + "microseconds": "us", + "milliseconds": "ms", + "seconds": "s", + "minutes": "m", + "hours": "h", + "days": "D", + }[units], + ) -def _numpy_to_netcdf_timeunit(units: str) -> str: +def _numpy_to_netcdf_timeunit(units: NPDatetimeUnitOptions) -> str: return { "ns": "nanoseconds", "us": "microseconds", @@ -252,12 +254,12 @@ def _decode_datetime_with_pandas( "pandas." ) - time_units, ref_date = _unpack_netcdf_time_units(units) + time_units, ref_date_str = _unpack_netcdf_time_units(units) time_units = _netcdf_to_numpy_timeunit(time_units) try: # TODO: the strict enforcement of nanosecond precision Timestamps can be # relaxed when addressing GitHub issue #7493. 
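An aside on the `strings.py` hunk above, which swaps `==` for `is` when testing `check_vlen_dtype` results: the helper returns the element type object itself (`str`, `bytes`, or `None`), so an identity check is exact and cannot be fooled by a value with a permissive `__eq__`. A minimal sketch, assuming xarray's vlen-dtype metadata convention; the helper below is simplified for illustration, not the actual implementation:

```python
import numpy as np


def check_vlen_dtype(dtype: np.dtype):
    """Return str or bytes for a variable-length string dtype, else None."""
    if dtype.kind != "O" or dtype.metadata is None:
        return None
    return dtype.metadata.get("element_type")


# vlen dtypes carry their element type in the dtype metadata
vlen_str = np.dtype("O", metadata={"element_type": str})
assert check_vlen_dtype(vlen_str) is str  # identity check: unambiguous
assert check_vlen_dtype(np.dtype("f8")) is None  # plain dtypes map to None
```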
- ref_date = nanosecond_precision_timestamp(ref_date) + ref_date = nanosecond_precision_timestamp(ref_date_str) except ValueError: # ValueError is raised by pd.Timestamp for non-ISO timestamp # strings, in which case we fall back to using cftime @@ -315,7 +317,7 @@ def decode_cf_datetime( cftime.num2date """ num_dates = np.asarray(num_dates) - flat_num_dates = num_dates.ravel() + flat_num_dates = ravel(num_dates) if calendar is None: calendar = "standard" @@ -348,7 +350,7 @@ def decode_cf_datetime( else: dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) - return dates.reshape(num_dates.shape) + return reshape(dates, num_dates.shape) def to_timedelta_unboxed(value, **kwargs): @@ -369,8 +371,8 @@ def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray: """ num_timedeltas = np.asarray(num_timedeltas) units = _netcdf_to_numpy_timeunit(units) - result = to_timedelta_unboxed(num_timedeltas.ravel(), unit=units) - return result.reshape(num_timedeltas.shape) + result = to_timedelta_unboxed(ravel(num_timedeltas), unit=units) + return reshape(result, num_timedeltas.shape) def _unit_timedelta_cftime(units: str) -> timedelta: @@ -428,7 +430,7 @@ def infer_datetime_units(dates) -> str: 'hours', 'minutes' or 'seconds' (the first one that can evenly divide all unique time deltas in `dates`) """ - dates = np.asarray(dates).ravel() + dates = ravel(np.asarray(dates)) if np.asarray(dates).dtype == "datetime64[ns]": dates = to_datetime_unboxed(dates) dates = dates[pd.notnull(dates)] @@ -456,7 +458,7 @@ def infer_timedelta_units(deltas) -> str: {'days', 'hours', 'minutes' 'seconds'} (the first one that can evenly divide all unique time deltas in `deltas`) """ - deltas = to_timedelta_unboxed(np.asarray(deltas).ravel()) + deltas = to_timedelta_unboxed(ravel(np.asarray(deltas))) unique_timedeltas = np.unique(deltas[pd.notnull(deltas)]) return _infer_time_units_from_diff(unique_timedeltas) @@ -471,6 +473,7 @@ def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray: # TODO: the strict enforcement of nanosecond precision datetime values can # be relaxed when addressing GitHub issue #7493. new = np.empty(times.shape, dtype="M8[ns]") + dt: pd.Timestamp | Literal["NaT"] for i, t in np.ndenumerate(times): try: # Use pandas.Timestamp in place of datetime.datetime, because @@ -643,7 +646,7 @@ def encode_datetime(d): except TypeError: return np.nan if d is None else cftime.date2num(d, units, calendar) - return np.array([encode_datetime(d) for d in dates.ravel()]).reshape(dates.shape) + return reshape(np.array([encode_datetime(d) for d in ravel(dates)]), dates.shape) def cast_to_int_if_safe(num) -> np.ndarray: @@ -753,7 +756,7 @@ def _eagerly_encode_cf_datetime( # Wrap the dates in a DatetimeIndex to do the subtraction to ensure # an OverflowError is raised if the ref_date is too far away from # dates to be encoded (GH 2272). 
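For orientation, a hedged round-trip through the two helpers these hunks touch: `encode_cf_datetime` returns a `(numbers, units, calendar)` triple that `decode_cf_datetime` inverts, with the ravel/reshape duck ops handling non-1D inputs internally. Both names are exposed by `xarray.coding.times`:

```python
import numpy as np

from xarray.coding.times import decode_cf_datetime, encode_cf_datetime

times = np.array(["2000-01-01", "2000-01-03"], dtype="datetime64[ns]")
num, units, calendar = encode_cf_datetime(times, units="days since 2000-01-01")
# num == [0, 2] as int64; units/calendar echo the encoding actually used
roundtripped = decode_cf_datetime(num, units, calendar)
assert (roundtripped == times).all()
```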
- dates_as_index = pd.DatetimeIndex(dates.ravel()) + dates_as_index = pd.DatetimeIndex(ravel(dates)) time_deltas = dates_as_index - ref_date # retrieve needed units to faithfully encode to int64 @@ -791,7 +794,7 @@ def _eagerly_encode_cf_datetime( floor_division = True num = _division(time_deltas, time_delta, floor_division) - num = num.values.reshape(dates.shape) + num = reshape(num.values, dates.shape) except (OutOfBoundsDatetime, OverflowError, ValueError): num = _encode_datetime_with_cftime(dates, units, calendar) @@ -879,7 +882,7 @@ def _eagerly_encode_cf_timedelta( units = data_units time_delta = _time_units_to_timedelta64(units) - time_deltas = pd.TimedeltaIndex(timedeltas.ravel()) + time_deltas = pd.TimedeltaIndex(ravel(timedeltas)) # retrieve needed units to faithfully encode to int64 needed_units = data_units @@ -911,7 +914,7 @@ def _eagerly_encode_cf_timedelta( floor_division = True num = _division(time_deltas, time_delta, floor_division) - num = num.values.reshape(timedeltas.shape) + num = reshape(num.values, timedeltas.shape) if dtype is not None: num = _cast_to_dtype_if_safe(num, dtype) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index d31cb6e626a..8a3afe650f2 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -516,9 +516,22 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: dims, data, attrs, encoding = unpack_for_encoding(variable) pop_to(encoding, attrs, "_Unsigned") - signed_dtype = np.dtype(f"i{data.dtype.itemsize}") + # we need the on-disk type here + # trying to get it from encoding, resort to an int with the same precision as data.dtype if not available + signed_dtype = np.dtype(encoding.get("dtype", f"i{data.dtype.itemsize}")) if "_FillValue" in attrs: - new_fill = signed_dtype.type(attrs["_FillValue"]) + try: + # user provided the on-disk signed fill + new_fill = signed_dtype.type(attrs["_FillValue"]) + except OverflowError: + # user provided the in-memory unsigned fill, convert to signed type + unsigned_dtype = np.dtype(f"u{signed_dtype.itemsize}") + # use view here to prevent OverflowError + new_fill = ( + np.array(attrs["_FillValue"], dtype=unsigned_dtype) + .view(signed_dtype) + .item() + ) attrs["_FillValue"] = new_fill data = duck_array_ops.astype(duck_array_ops.around(data), signed_dtype) @@ -535,10 +548,11 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: if unsigned == "true": unsigned_dtype = np.dtype(f"u{data.dtype.itemsize}") transform = partial(np.asarray, dtype=unsigned_dtype) - data = lazy_elemwise_func(data, transform, unsigned_dtype) if "_FillValue" in attrs: - new_fill = unsigned_dtype.type(attrs["_FillValue"]) - attrs["_FillValue"] = new_fill + new_fill = np.array(attrs["_FillValue"], dtype=data.dtype) + # use view here to prevent OverflowError + attrs["_FillValue"] = new_fill.view(unsigned_dtype).item() + data = lazy_elemwise_func(data, transform, unsigned_dtype) elif data.dtype.kind == "u": if unsigned == "false": signed_dtype = np.dtype(f"i{data.dtype.itemsize}") @@ -663,7 +677,7 @@ def encode(self): raise NotImplementedError def decode(self, variable: Variable, name: T_Name = None) -> Variable: - if variable.dtype == object and variable.encoding.get("dtype", False) == str: + if variable.dtype.kind == "O" and variable.encoding.get("dtype", False) is str: variable = variable.astype(variable.encoding["dtype"]) return variable else: diff --git a/xarray/conventions.py b/xarray/conventions.py index 6eff45c5b2d..d572b215d2d 100644 --- 
a/xarray/conventions.py +++ b/xarray/conventions.py @@ -2,7 +2,7 @@ from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, MutableMapping -from typing import TYPE_CHECKING, Any, Literal, Union +from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union import numpy as np import pandas as pd @@ -273,7 +273,7 @@ def decode_cf_variable( var = strings.CharacterArrayCoder().decode(var, name=name) var = strings.EncodedStringCoder().decode(var) - if original_dtype == object: + if original_dtype.kind == "O": var = variables.ObjectVLenStringCoder().decode(var) original_dtype = var.dtype @@ -384,16 +384,26 @@ def _update_bounds_encoding(variables: T_Variables) -> None: bounds_encoding.setdefault("calendar", encoding["calendar"]) +T = TypeVar("T") + + +def _item_or_default(obj: Mapping[Any, T] | T, key: Hashable, default: T) -> T: + """ + Return item by key if obj is mapping and key is present, else return default value. + """ + return obj.get(key, default) if isinstance(obj, Mapping) else obj + + def decode_cf_variables( variables: T_Variables, attributes: T_Attrs, - concat_characters: bool = True, - mask_and_scale: bool = True, - decode_times: bool = True, + concat_characters: bool | Mapping[str, bool] = True, + mask_and_scale: bool | Mapping[str, bool] = True, + decode_times: bool | Mapping[str, bool] = True, decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, - use_cftime: bool | None = None, - decode_timedelta: bool | None = None, + use_cftime: bool | Mapping[str, bool] | None = None, + decode_timedelta: bool | Mapping[str, bool] | None = None, ) -> tuple[T_Variables, T_Attrs, set[Hashable]]: """ Decode several CF encoded variables. @@ -431,7 +441,7 @@ def stackable(dim: Hashable) -> bool: if k in drop_variables: continue stack_char_dim = ( - concat_characters + _item_or_default(concat_characters, k, True) and v.dtype == "S1" and v.ndim > 0 and stackable(v.dims[-1]) @@ -440,12 +450,12 @@ def stackable(dim: Hashable) -> bool: new_vars[k] = decode_cf_variable( k, v, - concat_characters=concat_characters, - mask_and_scale=mask_and_scale, - decode_times=decode_times, + concat_characters=_item_or_default(concat_characters, k, True), + mask_and_scale=_item_or_default(mask_and_scale, k, True), + decode_times=_item_or_default(decode_times, k, True), stack_char_dim=stack_char_dim, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + use_cftime=_item_or_default(use_cftime, k, None), + decode_timedelta=_item_or_default(decode_timedelta, k, None), ) except Exception as e: raise type(e)(f"Failed to decode variable {k!r}: {e}") from e diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 96f860b3209..acc534d8b52 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -2316,19 +2316,6 @@ def cumprod( class DatasetGroupByAggregations: _obj: Dataset - def _reduce_without_squeeze_warn( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, - ) -> Dataset: - raise NotImplementedError() - def reduce( self, func: Callable[..., Any], @@ -2436,7 +2423,7 @@ def count( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.count, dim=dim, numeric_only=False, @@ -2532,7 +2519,7 @@ def all( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( 
duck_array_ops.array_all, dim=dim, numeric_only=False, @@ -2628,7 +2615,7 @@ def any( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.array_any, dim=dim, numeric_only=False, @@ -2741,7 +2728,7 @@ def max( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.max, dim=dim, skipna=skipna, @@ -2855,7 +2842,7 @@ def min( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.min, dim=dim, skipna=skipna, @@ -2971,7 +2958,7 @@ def mean( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -3105,7 +3092,7 @@ def prod( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -3240,7 +3227,7 @@ def sum( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -3372,7 +3359,7 @@ def std( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, @@ -3504,7 +3491,7 @@ def var( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, @@ -3606,7 +3593,7 @@ def median( Data variables: da (labels) float64 24B nan 2.0 1.5 """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.median, dim=dim, skipna=skipna, @@ -3705,7 +3692,7 @@ def cumsum( Data variables: da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 nan """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.cumsum, dim=dim, skipna=skipna, @@ -3804,7 +3791,7 @@ def cumprod( Data variables: da (time) float64 48B 1.0 2.0 3.0 0.0 4.0 nan """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.cumprod, dim=dim, skipna=skipna, @@ -3817,19 +3804,6 @@ def cumprod( class DatasetResampleAggregations: _obj: Dataset - def _reduce_without_squeeze_warn( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, - ) -> Dataset: - raise NotImplementedError() - def reduce( self, func: Callable[..., Any], @@ -3937,7 +3911,7 @@ def count( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.count, dim=dim, numeric_only=False, @@ -4033,7 +4007,7 @@ def all( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.array_all, dim=dim, numeric_only=False, @@ -4129,7 +4103,7 @@ def any( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.array_any, dim=dim, numeric_only=False, @@ -4242,7 +4216,7 @@ def max( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.max, dim=dim, skipna=skipna, @@ -4356,7 +4330,7 @@ def min( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.min, dim=dim, skipna=skipna, @@ -4472,7 +4446,7 @@ def mean( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -4606,7 +4580,7 @@ def prod( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -4741,7 
+4715,7 @@ def sum( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -4873,7 +4847,7 @@ def std( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, @@ -5005,7 +4979,7 @@ def var( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, @@ -5107,7 +5081,7 @@ def median( Data variables: da (time) float64 24B 1.0 2.0 nan """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.median, dim=dim, skipna=skipna, @@ -5206,7 +5180,7 @@ def cumsum( Data variables: da (time) float64 48B 1.0 2.0 5.0 5.0 2.0 nan """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.cumsum, dim=dim, skipna=skipna, @@ -5305,7 +5279,7 @@ def cumprod( Data variables: da (time) float64 48B 1.0 2.0 6.0 0.0 2.0 nan """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.cumprod, dim=dim, skipna=skipna, @@ -5318,19 +5292,6 @@ def cumprod( class DataArrayGroupByAggregations: _obj: DataArray - def _reduce_without_squeeze_warn( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, - ) -> DataArray: - raise NotImplementedError() - def reduce( self, func: Callable[..., Any], @@ -5432,7 +5393,7 @@ def count( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.count, dim=dim, keep_attrs=keep_attrs, @@ -5521,7 +5482,7 @@ def all( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.array_all, dim=dim, keep_attrs=keep_attrs, @@ -5610,7 +5571,7 @@ def any( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.array_any, dim=dim, keep_attrs=keep_attrs, @@ -5714,7 +5675,7 @@ def max( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.max, dim=dim, skipna=skipna, @@ -5819,7 +5780,7 @@ def min( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.min, dim=dim, skipna=skipna, @@ -5926,7 +5887,7 @@ def mean( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -6049,7 +6010,7 @@ def prod( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -6173,7 +6134,7 @@ def sum( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -6294,7 +6255,7 @@ def std( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, @@ -6415,7 +6376,7 @@ def var( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, @@ -6509,7 +6470,7 @@ def median( Coordinates: * labels (labels) object 24B 'a' 'b' 'c' """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.median, dim=dim, skipna=skipna, @@ -6604,7 +6565,7 @@ def cumsum( * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 
2001-06-30 labels (time) DataArray: - raise NotImplementedError() - def reduce( self, func: Callable[..., Any], @@ -6825,7 +6773,7 @@ def count( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.count, dim=dim, keep_attrs=keep_attrs, @@ -6914,7 +6862,7 @@ def all( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.array_all, dim=dim, keep_attrs=keep_attrs, @@ -7003,7 +6951,7 @@ def any( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.array_any, dim=dim, keep_attrs=keep_attrs, @@ -7107,7 +7055,7 @@ def max( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.max, dim=dim, skipna=skipna, @@ -7212,7 +7160,7 @@ def min( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.min, dim=dim, skipna=skipna, @@ -7319,7 +7267,7 @@ def mean( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -7442,7 +7390,7 @@ def prod( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -7566,7 +7514,7 @@ def sum( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -7687,7 +7635,7 @@ def std( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, @@ -7808,7 +7756,7 @@ def var( **kwargs, ) else: - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, @@ -7902,7 +7850,7 @@ def median( Coordinates: * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ - return self._reduce_without_squeeze_warn( + return self.reduce( duck_array_ops.median, dim=dim, skipna=skipna, @@ -7997,7 +7945,7 @@ def cumsum( labels (time) T_DataArray: ... + def __add__(self, other: T_DA) -> T_DA: ... @overload def __add__(self, other: VarCompatible) -> Self: ... - def __add__(self, other: VarCompatible) -> Self | T_DataArray: + def __add__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.add) @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... + def __sub__(self, other: T_DA) -> T_DA: ... @overload def __sub__(self, other: VarCompatible) -> Self: ... - def __sub__(self, other: VarCompatible) -> Self | T_DataArray: + def __sub__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.sub) @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... + def __mul__(self, other: T_DA) -> T_DA: ... @overload def __mul__(self, other: VarCompatible) -> Self: ... - def __mul__(self, other: VarCompatible) -> Self | T_DataArray: + def __mul__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.mul) @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... + def __pow__(self, other: T_DA) -> T_DA: ... @overload def __pow__(self, other: VarCompatible) -> Self: ... - def __pow__(self, other: VarCompatible) -> Self | T_DataArray: + def __pow__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.pow) @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... + def __truediv__(self, other: T_DA) -> T_DA: ... @overload def __truediv__(self, other: VarCompatible) -> Self: ... 
- def __truediv__(self, other: VarCompatible) -> Self | T_DataArray: + def __truediv__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.truediv) @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... + def __floordiv__(self, other: T_DA) -> T_DA: ... @overload def __floordiv__(self, other: VarCompatible) -> Self: ... - def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray: + def __floordiv__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.floordiv) @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... + def __mod__(self, other: T_DA) -> T_DA: ... @overload def __mod__(self, other: VarCompatible) -> Self: ... - def __mod__(self, other: VarCompatible) -> Self | T_DataArray: + def __mod__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.mod) @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... + def __and__(self, other: T_DA) -> T_DA: ... @overload def __and__(self, other: VarCompatible) -> Self: ... - def __and__(self, other: VarCompatible) -> Self | T_DataArray: + def __and__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.and_) @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... + def __xor__(self, other: T_DA) -> T_DA: ... @overload def __xor__(self, other: VarCompatible) -> Self: ... - def __xor__(self, other: VarCompatible) -> Self | T_DataArray: + def __xor__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.xor) @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... + def __or__(self, other: T_DA) -> T_DA: ... @overload def __or__(self, other: VarCompatible) -> Self: ... - def __or__(self, other: VarCompatible) -> Self | T_DataArray: + def __or__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.or_) @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... + def __lshift__(self, other: T_DA) -> T_DA: ... @overload def __lshift__(self, other: VarCompatible) -> Self: ... - def __lshift__(self, other: VarCompatible) -> Self | T_DataArray: + def __lshift__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.lshift) @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... + def __rshift__(self, other: T_DA) -> T_DA: ... @overload def __rshift__(self, other: VarCompatible) -> Self: ... - def __rshift__(self, other: VarCompatible) -> Self | T_DataArray: + def __rshift__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.rshift) @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... + def __lt__(self, other: T_DA) -> T_DA: ... @overload def __lt__(self, other: VarCompatible) -> Self: ... - def __lt__(self, other: VarCompatible) -> Self | T_DataArray: + def __lt__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.lt) @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... + def __le__(self, other: T_DA) -> T_DA: ... @overload def __le__(self, other: VarCompatible) -> Self: ... - def __le__(self, other: VarCompatible) -> Self | T_DataArray: + def __le__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.le) @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... + def __gt__(self, other: T_DA) -> T_DA: ... @overload def __gt__(self, other: VarCompatible) -> Self: ... 
- def __gt__(self, other: VarCompatible) -> Self | T_DataArray: + def __gt__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.gt) @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... + def __ge__(self, other: T_DA) -> T_DA: ... @overload def __ge__(self, other: VarCompatible) -> Self: ... - def __ge__(self, other: VarCompatible) -> Self | T_DataArray: + def __ge__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.ge) @overload # type:ignore[override] - def __eq__(self, other: T_DataArray) -> T_DataArray: ... + def __eq__(self, other: T_DA) -> T_DA: ... @overload def __eq__(self, other: VarCompatible) -> Self: ... - def __eq__(self, other: VarCompatible) -> Self | T_DataArray: + def __eq__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, nputils.array_eq) @overload # type:ignore[override] - def __ne__(self, other: T_DataArray) -> T_DataArray: ... + def __ne__(self, other: T_DA) -> T_DA: ... @overload def __ne__(self, other: VarCompatible) -> Self: ... - def __ne__(self, other: VarCompatible) -> Self | T_DataArray: + def __ne__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, nputils.array_ne) # When __eq__ is defined but __hash__ is not, then an object is unhashable, @@ -770,96 +770,96 @@ class DatasetGroupByOpsMixin: __slots__ = () def _binary_op( - self, other: GroupByCompatible, f: Callable, reflexive: bool = False + self, other: Dataset | DataArray, f: Callable, reflexive: bool = False ) -> Dataset: raise NotImplementedError - def __add__(self, other: GroupByCompatible) -> Dataset: + def __add__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other: GroupByCompatible) -> Dataset: + def __sub__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other: GroupByCompatible) -> Dataset: + def __mul__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other: GroupByCompatible) -> Dataset: + def __pow__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other: GroupByCompatible) -> Dataset: + def __truediv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other: GroupByCompatible) -> Dataset: + def __floordiv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other: GroupByCompatible) -> Dataset: + def __mod__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other: GroupByCompatible) -> Dataset: + def __and__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other: GroupByCompatible) -> Dataset: + def __xor__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other: GroupByCompatible) -> Dataset: + def __or__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.or_) - def __lshift__(self, other: GroupByCompatible) -> Dataset: + def __lshift__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.lshift) - def __rshift__(self, other: GroupByCompatible) -> Dataset: + def __rshift__(self, other: Dataset | DataArray) -> Dataset: return 
self._binary_op(other, operator.rshift) - def __lt__(self, other: GroupByCompatible) -> Dataset: + def __lt__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other: GroupByCompatible) -> Dataset: + def __le__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other: GroupByCompatible) -> Dataset: + def __gt__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other: GroupByCompatible) -> Dataset: + def __ge__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] + def __eq__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] + def __ne__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) # When __eq__ is defined but __hash__ is not, then an object is unhashable, # and it should be declared as follows: __hash__: None # type:ignore[assignment] - def __radd__(self, other: GroupByCompatible) -> Dataset: + def __radd__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other: GroupByCompatible) -> Dataset: + def __rsub__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other: GroupByCompatible) -> Dataset: + def __rmul__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other: GroupByCompatible) -> Dataset: + def __rpow__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other: GroupByCompatible) -> Dataset: + def __rtruediv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: + def __rfloordiv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other: GroupByCompatible) -> Dataset: + def __rmod__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other: GroupByCompatible) -> Dataset: + def __rand__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other: GroupByCompatible) -> Dataset: + def __rxor__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other: GroupByCompatible) -> Dataset: + def __ror__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/core/common.py b/xarray/core/common.py index 28ffc80e4e7..1e9c8ed8e29 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -52,6 +52,7 @@ T_Variable, ) from xarray.core.variable import Variable + from xarray.groupers import Resampler DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]] @@ -876,16 +877,14 @@ def rolling_exp( def _resample( self, resample_cls: type[T_Resample], - 
indexer: Mapping[Any, str] | None, + indexer: Mapping[Hashable, str | Resampler] | None, skipna: bool | None, closed: SideOptions | None, label: SideOptions | None, - base: int | None, offset: pd.Timedelta | datetime.timedelta | str | None, origin: str | DatetimeLike, - loffset: datetime.timedelta | str | None, restore_coord_dims: bool | None, - **indexer_kwargs: str, + **indexer_kwargs: str | Resampler, ) -> T_Resample: """Returns a Resample object for performing resampling operations. @@ -905,16 +904,6 @@ def _resample( Side of each interval to treat as closed. label : {"left", "right"}, optional Side of each interval to use for labeling. - base : int, optional - For frequencies that evenly subdivide 1 day, the "origin" of the - aggregated intervals. For example, for "24H" frequency, base could - range from 0 through 23. - - .. deprecated:: 2023.03.0 - Following pandas, the ``base`` parameter is deprecated in favor - of the ``origin`` and ``offset`` parameters, and will be removed - in a future version of xarray. - origin : {'epoch', 'start', 'start_day', 'end', 'end_day'}, pd.Timestamp, datetime.datetime, np.datetime64, or cftime.datetime, default 'start_day' The datetime on which to adjust the grouping. The timezone of origin must match the timezone of the index. @@ -927,15 +916,6 @@ def _resample( - 'end_day': `origin` is the ceiling midnight of the last day offset : pd.Timedelta, datetime.timedelta, or str, default is None An offset timedelta added to the origin. - loffset : timedelta or str, optional - Offset used to adjust the resampled time labels. Some pandas date - offset strings are supported. - - .. deprecated:: 2023.03.0 - Following pandas, the ``loffset`` parameter is deprecated in favor - of using time offset arithmetic, and will be removed in a future - version of xarray. - restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. @@ -1068,20 +1048,8 @@ def _resample( from xarray.core.dataarray import DataArray from xarray.core.groupby import ResolvedGrouper - from xarray.core.groupers import TimeResampler from xarray.core.resample import RESAMPLE_DIM - - # note: the second argument (now 'skipna') use to be 'dim' - if ( - (skipna is not None and not isinstance(skipna, bool)) - or ("how" in indexer_kwargs and "how" not in self.dims) - or ("dim" in indexer_kwargs and "dim" not in self.dims) - ): - raise TypeError( - "resample() no longer supports the `how` or " - "`dim` arguments. 
Instead call methods on resample " - "objects, e.g., data.resample(time='1D').mean()" - ) + from xarray.groupers import Resampler, TimeResampler indexer = either_dict_or_kwargs(indexer, indexer_kwargs, "resample") if len(indexer) != 1: @@ -1092,21 +1060,18 @@ def _resample( dim_coord = self[dim] group = DataArray( - dim_coord, - coords=dim_coord.coords, - dims=dim_coord.dims, - name=RESAMPLE_DIM, + dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM ) - grouper = TimeResampler( - freq=freq, - closed=closed, - label=label, - origin=origin, - offset=offset, - loffset=loffset, - base=base, - ) + grouper: Resampler + if isinstance(freq, str): + grouper = TimeResampler( + freq=freq, closed=closed, label=label, origin=origin, offset=offset + ) + elif isinstance(freq, Resampler): + grouper = freq + else: + raise ValueError("freq must be a str or a Resampler object") rgrouper = ResolvedGrouper(grouper, group, self) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f418d3821c2..5d21d0836b9 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1064,7 +1064,7 @@ def apply_ufunc( supported: >>> magnitude(3, 4) - 5.0 + np.float64(5.0) >>> magnitude(3, np.array([0, 4])) array([3., 5.]) >>> magnitude(array, 0) @@ -1587,15 +1587,6 @@ def cross( array([-3, 6, -3]) Dimensions without coordinates: dim_0 - Vector cross-product with 2 dimensions, returns in the perpendicular - direction: - - >>> a = xr.DataArray([1, 2]) - >>> b = xr.DataArray([4, 5]) - >>> xr.cross(a, b, dim="dim_0") - Size: 8B - array(-3) - Vector cross-product with 3 dimensions but zeros at the last axis yields the same results as with 2 dimensions: @@ -1606,42 +1597,6 @@ def cross( array([ 0, 0, -3]) Dimensions without coordinates: dim_0 - One vector with dimension 2: - - >>> a = xr.DataArray( - ... [1, 2], - ... dims=["cartesian"], - ... coords=dict(cartesian=(["cartesian"], ["x", "y"])), - ... ) - >>> b = xr.DataArray( - ... [4, 5, 6], - ... dims=["cartesian"], - ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), - ... ) - >>> xr.cross(a, b, dim="cartesian") - Size: 24B - array([12, -6, -3]) - Coordinates: - * cartesian (cartesian) >> a = xr.DataArray( - ... [1, 2], - ... dims=["cartesian"], - ... coords=dict(cartesian=(["cartesian"], ["x", "z"])), - ... ) - >>> b = xr.DataArray( - ... [4, 5, 6], - ... dims=["cartesian"], - ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), - ... 
) - >>> xr.cross(a, b, dim="cartesian") - Size: 24B - array([-10, 2, 5]) - Coordinates: - * cartesian (cartesian) tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]: """All the logic for creating a new DataArray""" @@ -199,7 +206,11 @@ def _infer_coords_and_dims( def _check_data_shape( data: Any, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None + ), dims: str | Iterable[Hashable] | None, ) -> Any: if data is dtypes.NA: @@ -413,7 +424,11 @@ class DataArray( def __init__( self, data: Any = dtypes.NA, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None + ) = None, dims: str | Iterable[Hashable] | None = None, name: Hashable | None = None, attrs: Mapping | None = None, @@ -965,7 +980,7 @@ def indexes(self) -> Indexes: return self.xindexes.to_pandas_indexes() @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: """Mapping of :py:class:`~xarray.indexes.Index` objects used for label based indexing. """ @@ -1337,7 +1352,7 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: @_deprecate_positional_args("v2023.10.0") def chunk( self, - chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + chunks: T_ChunksFreq = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) *, name_prefix: str = "xarray-", token: str | None = None, @@ -1345,7 +1360,7 @@ def chunk( inline_array: bool = False, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, - **chunks_kwargs: Any, + **chunks_kwargs: T_ChunkDimFreq, ) -> Self: """Coerce this array's data into a dask arrays with the given chunks. @@ -1357,11 +1372,13 @@ def chunk( sizes along that dimension will not be updated; non-dask arrays will be converted into dask arrays with a single block. + Along datetime-like dimensions, a pandas frequency string is also accepted. + Parameters ---------- - chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + chunks : int, "auto", tuple of int or mapping of hashable to int or a pandas frequency string, optional Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or - ``{"x": 5, "y": 5}``. + ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": "YE"}``. name_prefix : str, optional Prefix for the name of the new dask array. token : str, optional @@ -1396,29 +1413,30 @@ def chunk( xarray.unify_chunks dask.array.from_array """ + chunk_mapping: T_ChunksFreq if chunks is None: warnings.warn( "None value for 'chunks' is deprecated. " "It will raise an error in the future. Use instead '{}'", category=FutureWarning, ) - chunks = {} + chunk_mapping = {} if isinstance(chunks, (float, str, int)): # ignoring type; unclear why it won't accept a Literal into the value. - chunks = dict.fromkeys(self.dims, chunks) + chunk_mapping = dict.fromkeys(self.dims, chunks) elif isinstance(chunks, (tuple, list)): utils.emit_user_level_warning( "Supplying chunks as dimension-order tuples is deprecated. " "It will raise an error in the future. 
Instead use a dict with dimension names as keys.", category=DeprecationWarning, ) - chunks = dict(zip(self.dims, chunks)) + chunk_mapping = dict(zip(self.dims, chunks)) else: - chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + chunk_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") ds = self._to_temp_dataset().chunk( - chunks, + chunk_mapping, name_prefix=name_prefix, token=token, lock=lock, @@ -3004,7 +3022,7 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") - level_number = idx._get_level_number(level) + level_number = idx._get_level_number(level) # type: ignore[attr-defined] variables = idx.levels[level_number] variable_dim = idx.names[level_number] @@ -3838,7 +3856,7 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame: "pandas objects. Requires 2 or fewer dimensions." ) indexes = [self.get_index(dim) for dim in self.dims] - return constructor(self.values, *indexes) + return constructor(self.values, *indexes) # type: ignore[operator] def to_dataframe( self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None @@ -6680,26 +6698,34 @@ def interp_calendar( """ return interp_calendar(self, target, dim=dim) + @_deprecate_positional_args("v2024.07.0") def groupby( self, - group: Hashable | DataArray | IndexVariable, - squeeze: bool | None = None, + group: ( + Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None + ) = None, + *, + squeeze: Literal[False] = False, restore_coord_dims: bool = False, + **groupers: Grouper, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. Parameters ---------- - group : Hashable, DataArray or IndexVariable + group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper Array whose unique values should be used to group this array. If a - Hashable, must be the name of a coordinate contained in this dataarray. - squeeze : bool, default: True - If "group" is a dimension of any arrays in this dataset, `squeeze` - controls whether the subarrays have a dimension of length 1 along - that dimension or if the dimension is squeezed out. + Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, + must map an existing variable name to a :py:class:`Grouper` instance. + squeeze : False + This argument is deprecated. restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. + **groupers : Mapping of str to Grouper or Resampler + Mapping of variable name to group by to :py:class:`Grouper` or :py:class:`Resampler` object. + One of ``group`` or ``groupers`` must be provided. + Only a single ``grouper`` is allowed at present. Returns ------- @@ -6729,6 +6755,15 @@ def groupby( * time (time) datetime64[ns] 15kB 2000-01-01 2000-01-02 ... 2004-12-31 dayofyear (time) int64 15kB 1 2 3 4 5 6 7 8 ... 360 361 362 363 364 365 366 + Use a ``Grouper`` object to be more explicit + + >>> da.coords["dayofyear"] = da.time.dt.dayofyear + >>> da.groupby(dayofyear=xr.groupers.UniqueGrouper()).mean() + Size: 3kB + array([ 730.8, 731.8, 732.8, ..., 1093.8, 1094.8, 1095.5]) + Coordinates: + * dayofyear (dayofyear) int64 3kB 1 2 3 4 5 6 7 ... 
361 362 363 364 365 366 + See Also -------- :ref:`groupby` @@ -6753,27 +6788,48 @@ def groupby( ResolvedGrouper, _validate_groupby_squeeze, ) - from xarray.core.groupers import UniqueGrouper + from xarray.groupers import UniqueGrouper _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + + if isinstance(group, Mapping): + groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore + group = None + + grouper: Grouper + if group is not None: + if groupers: + raise ValueError( + "Providing a combination of `group` and **groupers is not supported." + ) + grouper = UniqueGrouper() + else: + if len(groupers) > 1: + raise ValueError("grouping by multiple variables is not supported yet.") + if not groupers: + raise ValueError("Either `group` or `**groupers` must be provided.") + group, grouper = next(iter(groupers.items())) + + rgrouper = ResolvedGrouper(grouper, group, self) + return DataArrayGroupBy( self, (rgrouper,), - squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) + @_deprecate_positional_args("v2024.07.0") def groupby_bins( self, group: Hashable | DataArray | IndexVariable, - bins: ArrayLike, + bins: Bins, right: bool = True, labels: ArrayLike | Literal[False] | None = None, precision: int = 3, include_lowest: bool = False, - squeeze: bool | None = None, + squeeze: Literal[False] = False, restore_coord_dims: bool = False, + duplicates: Literal["raise", "drop"] = "raise", ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6803,13 +6859,13 @@ def groupby_bins( The precision at which to store and display the bins labels. include_lowest : bool, default: False Whether the first interval should be left-inclusive or not. - squeeze : bool, default: True - If "group" is a dimension of any arrays in this dataset, `squeeze` - controls whether the subarrays have a dimension of length 1 along - that dimension or if the dimension is squeezed out. + squeeze : False + This argument is deprecated. restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. + duplicates : {"raise", "drop"}, default: "raise" + If bin edges are not unique, raise ValueError or drop non-uniques. 
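A hedged sketch of the Grouper-based spelling this change enables: `groupby_bins` now forwards `right`/`labels`/`precision`/`include_lowest` directly to a `BinGrouper` rather than a `cut_kwargs` dict, so the two calls below should be equivalent (the equivalence is an expectation drawn from the diff, not a guarantee):

```python
import numpy as np

import xarray as xr
from xarray.groupers import BinGrouper

da = xr.DataArray(np.arange(6.0), dims="x", coords={"x": np.arange(6)})

classic = da.groupby_bins("x", bins=3).mean()  # keyword-based API
explicit = da.groupby(x=BinGrouper(bins=3)).mean()  # explicit Grouper object

assert classic.equals(explicit)  # both reduce over an "x_bins" dimension
```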
Returns ------- @@ -6837,24 +6893,21 @@ def groupby_bins( ResolvedGrouper, _validate_groupby_squeeze, ) - from xarray.core.groupers import BinGrouper + from xarray.groupers import BinGrouper _validate_groupby_squeeze(squeeze) grouper = BinGrouper( bins=bins, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, + right=right, + labels=labels, + precision=precision, + include_lowest=include_lowest, ) rgrouper = ResolvedGrouper(grouper, group, self) return DataArrayGroupBy( self, (rgrouper,), - squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) @@ -7187,18 +7240,18 @@ def coarsen( coord_func=coord_func, ) + @_deprecate_positional_args("v2024.07.0") def resample( self, - indexer: Mapping[Any, str] | None = None, + indexer: Mapping[Hashable, str | Resampler] | None = None, + *, skipna: bool | None = None, closed: SideOptions | None = None, label: SideOptions | None = None, - base: int | None = None, offset: pd.Timedelta | datetime.timedelta | str | None = None, origin: str | DatetimeLike = "start_day", - loffset: datetime.timedelta | str | None = None, restore_coord_dims: bool | None = None, - **indexer_kwargs: str, + **indexer_kwargs: str | Resampler, ) -> DataArrayResample: """Returns a Resample object for performing resampling operations. @@ -7218,10 +7271,6 @@ def resample( Side of each interval to treat as closed. label : {"left", "right"}, optional Side of each interval to use for labeling. - base : int, optional - For frequencies that evenly subdivide 1 day, the "origin" of the - aggregated intervals. For example, for "24H" frequency, base could - range from 0 through 23. origin : {'epoch', 'start', 'start_day', 'end', 'end_day'}, pd.Timestamp, datetime.datetime, np.datetime64, or cftime.datetime, default 'start_day' The datetime on which to adjust the grouping. The timezone of origin must match the timezone of the index. @@ -7234,15 +7283,6 @@ def resample( - 'end_day': `origin` is the ceiling midnight of the last day offset : pd.Timedelta, datetime.timedelta, or str, default is None An offset timedelta added to the origin. - loffset : timedelta or str, optional - Offset used to adjust the resampled time labels. Some pandas date - offset strings are supported. - - .. deprecated:: 2023.03.0 - Following pandas, the ``loffset`` parameter is deprecated in favor - of using time offset arithmetic, and will be removed in a future - version of xarray. - restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. @@ -7291,28 +7331,7 @@ def resample( 0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323, 0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355, 0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387, - 0.96774194, 1. , 1.03225806, 1.06451613, 1.09677419, - 1.12903226, 1.16129032, 1.19354839, 1.22580645, 1.25806452, - 1.29032258, 1.32258065, 1.35483871, 1.38709677, 1.41935484, - 1.4516129 , 1.48387097, 1.51612903, 1.5483871 , 1.58064516, - 1.61290323, 1.64516129, 1.67741935, 1.70967742, 1.74193548, - 1.77419355, 1.80645161, 1.83870968, 1.87096774, 1.90322581, - 1.93548387, 1.96774194, 2. , 2.03448276, 2.06896552, - 2.10344828, 2.13793103, 2.17241379, 2.20689655, 2.24137931, - 2.27586207, 2.31034483, 2.34482759, 2.37931034, 2.4137931 , - 2.44827586, 2.48275862, 2.51724138, 2.55172414, 2.5862069 , - 2.62068966, 2.65517241, 2.68965517, 2.72413793, 2.75862069, - 2.79310345, 2.82758621, 2.86206897, 2.89655172, 2.93103448, - 2.96551724, 3. 
, 3.03225806, 3.06451613, 3.09677419, - 3.12903226, 3.16129032, 3.19354839, 3.22580645, 3.25806452, - ... - 7.87096774, 7.90322581, 7.93548387, 7.96774194, 8. , - 8.03225806, 8.06451613, 8.09677419, 8.12903226, 8.16129032, - 8.19354839, 8.22580645, 8.25806452, 8.29032258, 8.32258065, - 8.35483871, 8.38709677, 8.41935484, 8.4516129 , 8.48387097, - 8.51612903, 8.5483871 , 8.58064516, 8.61290323, 8.64516129, - 8.67741935, 8.70967742, 8.74193548, 8.77419355, 8.80645161, - 8.83870968, 8.87096774, 8.90322581, 8.93548387, 8.96774194, + 0.96774194, 1. , ..., 9. , 9.03333333, 9.06666667, 9.1 , 9.13333333, 9.16666667, 9.2 , 9.23333333, 9.26666667, 9.3 , 9.33333333, 9.36666667, 9.4 , 9.43333333, 9.46666667, @@ -7342,19 +7361,7 @@ def resample( nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 3., 3., 3., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, 5., 5., 5., nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - 6., 6., 6., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, 7., 7., 7., nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, 8., 8., 8., nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, - nan, 9., 9., 9., nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, ..., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 10., 10., 10., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, @@ -7380,10 +7387,8 @@ def resample( skipna=skipna, closed=closed, label=label, - base=base, offset=offset, origin=origin, - loffset=loffset, restore_coord_dims=restore_coord_dims, **indexer_kwargs, ) @@ -7456,3 +7461,20 @@ def to_dask_dataframe( # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor["DataArray"]) + + def drop_attrs(self, *, deep: bool = True) -> Self: + """ + Removes all attributes from the DataArray. + + Parameters + ---------- + deep : bool, default True + Removes attributes from coordinates. 
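A hedged usage sketch for the `drop_attrs` method being added here; per the docstring above, the default `deep=True` strips attributes from coordinates as well as from the array itself:

```python
import xarray as xr

da = xr.DataArray(
    [1, 2],
    dims="x",
    coords={"x": ("x", [10, 20], {"axis": "X"})},
    attrs={"units": "m"},
)

bare = da.drop_attrs()  # deep=True by default
assert bare.attrs == {}
# per the docstring above, coordinate attrs (da.x.attrs) are dropped as well
```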
+ + Returns + ------- + DataArray + """ + return ( + self._to_temp_dataset().drop_attrs(deep=deep).pipe(self._from_temp_dataset) + ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 50cfc7b0c29..cad2f00ccc1 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -29,9 +29,9 @@ # remove once numpy 2.0 is the oldest supported version try: - from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] + from numpy.exceptions import RankWarning except ImportError: - from numpy import RankWarning + from numpy import RankWarning # type: ignore[no-redef,attr-defined,unused-ignore] import pandas as pd @@ -88,11 +88,12 @@ from xarray.core.missing import get_clean_interp_index from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import ( + Bins, NetcdfWriteModes, QuantileMethods, Self, T_ChunkDim, - T_Chunks, + T_ChunksFreq, T_DataArray, T_DataArrayOrSet, T_Dataset, @@ -160,9 +161,11 @@ QueryParserOptions, ReindexMethodOptions, SideOptions, + T_ChunkDimFreq, T_Xarray, ) from xarray.core.weighted import DatasetWeighted + from xarray.groupers import Grouper, Resampler from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -281,18 +284,17 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): def _maybe_chunk( - name, - var, - chunks, + name: Hashable, + var: Variable, + chunks: Mapping[Any, T_ChunkDim] | None, token=None, lock=None, - name_prefix="xarray-", - overwrite_encoded_chunks=False, - inline_array=False, + name_prefix: str = "xarray-", + overwrite_encoded_chunks: bool = False, + inline_array: bool = False, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, -): - +) -> Variable: from xarray.namedarray.daskmanager import DaskManager if chunks is not None: @@ -2646,14 +2648,14 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: def chunk( self, - chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + chunks: T_ChunksFreq = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, inline_array: bool = False, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, - **chunks_kwargs: T_ChunkDim, + **chunks_kwargs: T_ChunkDimFreq, ) -> Self: """Coerce all arrays in this dataset into dask arrays with the given chunks. @@ -2665,11 +2667,13 @@ def chunk( sizes along that dimension will not be updated; non-dask arrays will be converted into dask arrays with a single block. + Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted. + Parameters ---------- - chunks : int, tuple of int, "auto" or mapping of hashable to int, optional + chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or - ``{"x": 5, "y": 5}``. + ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``. name_prefix : str, default: "xarray-" Prefix for the name of any new dask arrays. token : str, optional @@ -2704,6 +2708,9 @@ def chunk( xarray.unify_chunks dask.array.from_array """ + from xarray.core.dataarray import DataArray + from xarray.groupers import TimeResampler + if chunks is None and not chunks_kwargs: warnings.warn( "None value for 'chunks' is deprecated. 
" @@ -2729,6 +2736,42 @@ def chunk( f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}" ) + def _resolve_frequency( + name: Hashable, resampler: TimeResampler + ) -> tuple[int, ...]: + variable = self._variables.get(name, None) + if variable is None: + raise ValueError( + f"Cannot chunk by resampler {resampler!r} for virtual variables." + ) + elif not _contains_datetime_like_objects(variable): + raise ValueError( + f"chunks={resampler!r} only supported for datetime variables. " + f"Received variable {name!r} with dtype {variable.dtype!r} instead." + ) + + assert variable.ndim == 1 + chunks: tuple[int, ...] = tuple( + DataArray( + np.ones(variable.shape, dtype=int), + dims=(name,), + coords={name: variable}, + ) + .resample({name: resampler}) + .sum() + .data.tolist() + ) + return chunks + + chunks_mapping_ints: Mapping[Any, T_ChunkDim] = { + name: ( + _resolve_frequency(name, chunks) + if isinstance(chunks, TimeResampler) + else chunks + ) + for name, chunks in chunks_mapping.items() + } + chunkmanager = guess_chunkmanager(chunked_array_type) if from_array_kwargs is None: from_array_kwargs = {} @@ -2737,7 +2780,7 @@ def chunk( k: _maybe_chunk( k, v, - chunks_mapping, + chunks_mapping_ints, token, lock, name_prefix, @@ -4156,15 +4199,15 @@ def interp_like( kwargs = {} # pick only dimension coordinates with a single index - coords = {} + coords: dict[Hashable, Variable] = {} other_indexes = other.xindexes for dim in self.dims: other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore") if len(other_dim_coords) == 1: coords[dim] = other_dim_coords[dim] - numeric_coords: dict[Hashable, pd.Index] = {} - object_coords: dict[Hashable, pd.Index] = {} + numeric_coords: dict[Hashable, Variable] = {} + object_coords: dict[Hashable, Variable] = {} for k, v in coords.items(): if v.dtype.kind in "uifcMm": numeric_coords[k] = v @@ -6369,7 +6412,7 @@ def dropna( Data variables: temperature (time, location) float64 64B 23.4 24.1 nan ... 24.2 20.5 25.3 - # Drop NaN values from the dataset + Drop NaN values from the dataset >>> dataset.dropna(dim="time") Size: 80B @@ -6380,7 +6423,7 @@ def dropna( Data variables: temperature (time, location) float64 48B 23.4 24.1 21.8 24.2 20.5 25.3 - # Drop labels with any NAN values + Drop labels with any NaN values >>> dataset.dropna(dim="time", how="any") Size: 80B @@ -6391,7 +6434,7 @@ def dropna( Data variables: temperature (time, location) float64 48B 23.4 24.1 21.8 24.2 20.5 25.3 - # Drop labels with all NAN values + Drop labels with all NAN values >>> dataset.dropna(dim="time", how="all") Size: 104B @@ -6402,7 +6445,7 @@ def dropna( Data variables: temperature (time, location) float64 64B 23.4 24.1 nan ... 24.2 20.5 25.3 - # Drop labels with less than 2 non-NA values + Drop labels with less than 2 non-NA values >>> dataset.dropna(dim="time", thresh=2) Size: 80B @@ -6416,6 +6459,11 @@ def dropna( Returns ------- Dataset + + See Also + -------- + DataArray.dropna + pandas.DataFrame.dropna """ # TODO: consider supporting multiple dimensions? Or not, given that # there are some ugly edge cases, e.g., pandas's dropna differs @@ -6539,7 +6587,13 @@ def interpolate_na( limit: int | None = None, use_coordinate: bool | Hashable = True, max_gap: ( - int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta + int + | float + | str + | pd.Timedelta + | np.timedelta64 + | datetime.timedelta + | None ) = None, **kwargs: Any, ) -> Self: @@ -6573,7 +6627,8 @@ def interpolate_na( or None for no limit. 
This filling is done regardless of the size of the gap in the data. To only interpolate over gaps less than a given length, see ``max_gap``. - max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None + max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta \ + or None, default: None Maximum size of gap, a continuous sequence of NaNs, that will be filled. Use None for no limit. When interpolating along a datetime64 dimension and ``use_coordinate=True``, ``max_gap`` can be one of the following: @@ -9715,7 +9770,7 @@ def eval( c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 """ - return pd.eval( + return pd.eval( # type: ignore[return-value] statement, resolvers=[self], target=self, @@ -10254,19 +10309,25 @@ def interp_calendar( """ return interp_calendar(self, target, dim=dim) + @_deprecate_positional_args("v2024.07.0") def groupby( self, - group: Hashable | DataArray | IndexVariable, - squeeze: bool | None = None, + group: ( + Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None + ) = None, + *, + squeeze: Literal[False] = False, restore_coord_dims: bool = False, + **groupers: Grouper, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. Parameters ---------- - group : Hashable, DataArray or IndexVariable + group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper Array whose unique values should be used to group this array. If a - string, must be the name of a variable contained in this dataset. + Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, + must map an existing variable name to a :py:class:`Grouper` instance. squeeze : bool, default: True If "group" is a dimension of any arrays in this dataset, `squeeze` controls whether the subarrays have a dimension of length 1 along @@ -10274,6 +10335,10 @@ def groupby( restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. + **groupers : Mapping of str to Grouper or Resampler + Mapping of variable name to group by to :py:class:`Grouper` or :py:class:`Resampler` object. + One of ``group`` or ``groupers`` must be provided. + Only a single ``grouper`` is allowed at present. Returns ------- @@ -10305,28 +10370,46 @@ def groupby( ResolvedGrouper, _validate_groupby_squeeze, ) - from xarray.core.groupers import UniqueGrouper + from xarray.groupers import UniqueGrouper _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + + if isinstance(group, Mapping): + groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore + group = None + + if group is not None: + if groupers: + raise ValueError( + "Providing a combination of `group` and **groupers is not supported." 
+ ) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + else: + if len(groupers) > 1: + raise ValueError("Grouping by multiple variables is not supported yet.") + elif not groupers: + raise ValueError("Either `group` or `**groupers` must be provided.") + for group, grouper in groupers.items(): + rgrouper = ResolvedGrouper(grouper, group, self) return DatasetGroupBy( self, (rgrouper,), - squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) + @_deprecate_positional_args("v2024.07.0") def groupby_bins( self, group: Hashable | DataArray | IndexVariable, - bins: ArrayLike, + bins: Bins, right: bool = True, labels: ArrayLike | None = None, precision: int = 3, include_lowest: bool = False, - squeeze: bool | None = None, + squeeze: Literal[False] = False, restore_coord_dims: bool = False, + duplicates: Literal["raise", "drop"] = "raise", ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. @@ -10356,13 +10439,13 @@ def groupby_bins( The precision at which to store and display the bins labels. include_lowest : bool, default: False Whether the first interval should be left-inclusive or not. - squeeze : bool, default: True - If "group" is a dimension of any arrays in this dataset, `squeeze` - controls whether the subarrays have a dimension of length 1 along - that dimension or if the dimension is squeezed out. + squeeze : False + This argument is deprecated. restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. + duplicates : {"raise", "drop"}, default: "raise" + If bin edges are not unique, raise ValueError or drop non-uniques. Returns ------- @@ -10390,24 +10473,21 @@ def groupby_bins( ResolvedGrouper, _validate_groupby_squeeze, ) - from xarray.core.groupers import BinGrouper + from xarray.groupers import BinGrouper _validate_groupby_squeeze(squeeze) grouper = BinGrouper( bins=bins, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, + right=right, + labels=labels, + precision=precision, + include_lowest=include_lowest, ) rgrouper = ResolvedGrouper(grouper, group, self) return DatasetGroupBy( self, (rgrouper,), - squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) @@ -10585,18 +10665,18 @@ def coarsen( coord_func=coord_func, ) + @_deprecate_positional_args("v2024.07.0") def resample( self, - indexer: Mapping[Any, str] | None = None, + indexer: Mapping[Any, str | Resampler] | None = None, + *, skipna: bool | None = None, closed: SideOptions | None = None, label: SideOptions | None = None, - base: int | None = None, offset: pd.Timedelta | datetime.timedelta | str | None = None, origin: str | DatetimeLike = "start_day", - loffset: datetime.timedelta | str | None = None, restore_coord_dims: bool | None = None, - **indexer_kwargs: str, + **indexer_kwargs: str | Resampler, ) -> DatasetResample: """Returns a Resample object for performing resampling operations. @@ -10616,10 +10696,6 @@ def resample( Side of each interval to treat as closed. label : {"left", "right"}, optional Side of each interval to use for labeling. - base : int, optional - For frequencies that evenly subdivide 1 day, the "origin" of the - aggregated intervals. For example, for "24H" frequency, base could - range from 0 through 23. origin : {'epoch', 'start', 'start_day', 'end', 'end_day'}, pd.Timestamp, datetime.datetime, np.datetime64, or cftime.datetime, default 'start_day' The datetime on which to adjust the grouping. 
The timezone of origin must match the timezone of the index. @@ -10632,15 +10708,6 @@ def resample( - 'end_day': `origin` is the ceiling midnight of the last day offset : pd.Timedelta, datetime.timedelta, or str, default is None An offset timedelta added to the origin. - loffset : timedelta or str, optional - Offset used to adjust the resampled time labels. Some pandas date - offset strings are supported. - - .. deprecated:: 2023.03.0 - Following pandas, the ``loffset`` parameter is deprecated in favor - of using time offset arithmetic, and will be removed in a future - version of xarray. - restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. @@ -10673,10 +10740,50 @@ def resample( skipna=skipna, closed=closed, label=label, - base=base, offset=offset, origin=origin, - loffset=loffset, restore_coord_dims=restore_coord_dims, **indexer_kwargs, ) + + def drop_attrs(self, *, deep: bool = True) -> Self: + """ + Removes all attributes from the Dataset and its variables. + + Parameters + ---------- + deep : bool, default True + Removes attributes from all variables. + + Returns + ------- + Dataset + """ + # Remove attributes from the dataset + self = self._replace(attrs={}) + + if not deep: + return self + + # Remove attributes from each variable in the dataset + for var in self.variables: + # variables don't have a `._replace` method, so we copy and then remove + # attrs. If we added a `._replace` method, we could use that instead. + if var not in self.indexes: + self[var] = self[var].copy() + self[var].attrs = {} + + new_idx_variables = {} + # Not sure this is the most elegant way of doing this, but it works. + # (Should we have a more general "map over all variables, including + # indexes" approach?) + for idx, idx_vars in self.xindexes.group_by_index(): + # copy each coordinate variable of an index and drop their attrs + temp_idx_variables = {k: v.copy() for k, v in idx_vars.items()} + for v in temp_idx_variables.values(): + v.attrs = {} + # re-wrap the index object in new coordinate variables + new_idx_variables.update(idx.create_variables(temp_idx_variables)) + self = self.assign(new_idx_variables) + + return self diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 65ff8667cb7..6289146308e 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1102,9 +1102,14 @@ def from_dict( else: obj = cls(name=name, data=root_data, parent=None, children=None) + def depth(item) -> int: + pathstr, _ = item + return len(NodePath(pathstr).parts) + if d: # Populate tree with children determined from data_objects mapping - for path, data in d.items(): + # Sort keys by depth so as to insert nodes from root first (see GH issue #9276) + for path, data in sorted(d.items(), key=depth): # Create and set new node node_name = NodePath(path).name if isinstance(data, DataTree): @@ -1546,7 +1551,12 @@ def to_netcdf( ``dask.delayed.Delayed`` object that can be computed later. Currently, ``compute=False`` is not supported. kwargs : - Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` + Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` + + Note + ---- + Due to file format specifications the on-disk root group name + is always ``"/"`` overriding any given ``DataTree`` root node name. """ from xarray.core.datatree_io import _datatree_to_netcdf @@ -1602,6 +1612,11 @@ def to_zarr( supported. 
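The root-group note added to ``to_netcdf`` above (and to ``to_zarr`` just below) has a simple practical consequence: an in-memory root node name is normalized to ``"/"`` on disk. A sketch under assumed paths, requiring a netCDF backend:

```python
import xarray as xr
from xarray.core.datatree import DataTree

tree = DataTree.from_dict(
    {
        "/": xr.Dataset({"a": ("x", [1, 2])}),
        "/child": xr.Dataset({"b": ("y", [3.0])}),
    },
    name="my-root",  # in-memory only; not representable in the file
)
tree.to_netcdf("tree.nc")  # hypothetical output path
# On disk the root group is addressed as "/", so "my-root" does not round-trip.
```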
kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` + + Note + ---- + Due to file format specifications the on-disk root group name + is always ``"/"`` overriding any given ``DataTree`` root node name. """ from xarray.core.datatree_io import _datatree_to_zarr diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index c8b4fa88409..b0361ef0f0f 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Callable, Generic +from typing import Callable, Generic, cast import numpy as np import pandas as pd @@ -45,7 +45,7 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ) -> T_ExtensionArray: - return type(arrays[0])._concat_same_type(arrays) + return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] @implements(np.where) @@ -57,9 +57,9 @@ def __extension_duck_array__where( and isinstance(y, pd.Categorical) and x.dtype != y.dtype ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) - y = y.add_categories(set(x.categories).difference(set(y.categories))) - return pd.Series(x).where(condition, pd.Series(y)).array + x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment] + y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment] + return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) class PandasExtensionArray(Generic[T_ExtensionArray]): @@ -116,7 +116,7 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: if is_extension_array_dtype(item): return type(self)(item) if np.isscalar(item): - return type(self)(type(self.array)([item])) + return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed return item def __setitem__(self, key, val): diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 42e7f01a526..9b0758d030b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -38,7 +38,6 @@ FrozenMappingWarningOnValuesAccess, contains_only_chunked_or_numpy, either_dict_or_kwargs, - emit_user_level_warning, hashable, is_scalar, maybe_wrap_array, @@ -53,9 +52,9 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.groupers import Grouper - from xarray.core.types import GroupIndex, GroupKey, T_GroupIndices + from xarray.core.types import GroupIndex, GroupIndices, GroupKey from xarray.core.utils import Frozen + from xarray.groupers import Grouper def check_reduce_dims(reduce_dims, dimensions): @@ -66,34 +65,12 @@ def check_reduce_dims(reduce_dims, dimensions): raise ValueError( f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' " f"to reduce over all dimensions or one or more of {dimensions!r}." 
- f" Try passing .groupby(..., squeeze=False)" ) -def _maybe_squeeze_indices( - indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool -): - from xarray.core.groupers import UniqueGrouper - - is_unique_grouper = isinstance(grouper.grouper, UniqueGrouper) - can_squeeze = is_unique_grouper and grouper.grouper.can_squeeze - if squeeze in [None, True] and can_squeeze: - if isinstance(indices, slice): - if indices.stop - indices.start == 1: - if (squeeze is None and warn) or squeeze is True: - emit_user_level_warning( - "The `squeeze` kwarg to GroupBy is being removed." - "Pass .groupby(..., squeeze=False) to disable squeezing," - " which is the new default, and to silence this warning." - ) - - indices = indices.start - return indices - - -def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices: +def _codes_to_group_indices(inverse: np.ndarray, N: int) -> GroupIndices: assert inverse.ndim == 1 - groups: T_GroupIndices = [[] for _ in range(N)] + groups: GroupIndices = tuple([] for _ in range(N)) for n, g in enumerate(inverse): if g >= 0: groups[g].append(n) @@ -226,7 +203,7 @@ def attrs(self) -> dict: def __getitem__(self, key): if isinstance(key, tuple): - key = key[0] + (key,) = key return self.values[key] def to_index(self) -> pd.Index: @@ -296,16 +273,16 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): obj: T_DataWithCoords # returned by factorize: - codes: DataArray = field(init=False) - full_index: pd.Index = field(init=False) - group_indices: T_GroupIndices = field(init=False) - unique_coord: Variable | _DummyGroup = field(init=False) + codes: DataArray = field(init=False, repr=False) + full_index: pd.Index = field(init=False, repr=False) + group_indices: GroupIndices = field(init=False, repr=False) + unique_coord: Variable | _DummyGroup = field(init=False, repr=False) # _ensure_1d: - group1d: T_Group = field(init=False) - stacked_obj: T_DataWithCoords = field(init=False) - stacked_dim: Hashable | None = field(init=False) - inserted_dims: list[Hashable] = field(init=False) + group1d: T_Group = field(init=False, repr=False) + stacked_obj: T_DataWithCoords = field(init=False, repr=False) + stacked_dim: Hashable | None = field(init=False, repr=False) + inserted_dims: list[Hashable] = field(init=False, repr=False) def __post_init__(self) -> None: # This copy allows the BinGrouper.factorize() method @@ -355,11 +332,11 @@ def factorize(self) -> None: if encoded.group_indices is not None: self.group_indices = encoded.group_indices else: - self.group_indices = [ + self.group_indices = tuple( g for g in _codes_to_group_indices(self.codes.data, len(self.full_index)) if g - ] + ) if encoded.unique_coord is None: unique_values = self.full_index[np.unique(encoded.codes)] self.unique_coord = Variable( @@ -369,17 +346,15 @@ def factorize(self) -> None: self.unique_coord = encoded.unique_coord -def _validate_groupby_squeeze(squeeze: bool | None) -> None: +def _validate_groupby_squeeze(squeeze: Literal[False]) -> None: # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the # consequences hidden enough (strings evaluate as true) to warrant # checking here. # A future version could make squeeze kwarg only, but would face # backward-compat issues. 
- if squeeze is not None and not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be None, True or False, but {squeeze} was supplied" - ) + if squeeze is not False: + raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.") def _resolve_group( @@ -466,7 +441,6 @@ class GroupBy(Generic[T_Xarray]): "_unique_coord", "_dims", "_sizes", - "_squeeze", # Save unstacked object for flox "_original_obj", "_original_group", @@ -475,12 +449,11 @@ class GroupBy(Generic[T_Xarray]): ) _obj: T_Xarray groupers: tuple[ResolvedGrouper] - _squeeze: bool | None _restore_coord_dims: bool _original_obj: T_Xarray _original_group: T_Group - _group_indices: T_GroupIndices + _group_indices: GroupIndices _codes: DataArray _group_dim: Hashable @@ -492,7 +465,6 @@ def __init__( self, obj: T_Xarray, groupers: tuple[ResolvedGrouper], - squeeze: bool | None = False, restore_coord_dims: bool = True, ) -> None: """Create a GroupBy object @@ -517,7 +489,6 @@ def __init__( # specification for the groupby operation self._obj = grouper.stacked_obj self._restore_coord_dims = restore_coord_dims - self._squeeze = squeeze # These should generalize to multiple groupers self._group_indices = grouper.group_indices @@ -542,14 +513,8 @@ def sizes(self) -> Mapping[Hashable, int]: """ if self._sizes is None: (grouper,) = self.groupers - index = _maybe_squeeze_indices( - self._group_indices[0], - self._squeeze, - grouper, - warn=True, - ) + index = self._group_indices[0] self._sizes = self._obj.isel({self._group_dim: index}).sizes - return self._sizes def map( @@ -582,11 +547,7 @@ def groups(self) -> dict[GroupKey, GroupIndex]: # provided to mimic pandas.groupby if self._groups is None: (grouper,) = self.groupers - squeezed_indices = ( - _maybe_squeeze_indices(ind, self._squeeze, grouper, warn=idx > 0) - for idx, ind in enumerate(self._group_indices) - ) - self._groups = dict(zip(grouper.unique_coord.values, squeezed_indices)) + self._groups = dict(zip(grouper.unique_coord.values, self._group_indices)) return self._groups def __getitem__(self, key: GroupKey) -> T_Xarray: @@ -594,10 +555,7 @@ def __getitem__(self, key: GroupKey) -> T_Xarray: Get DataArray or Dataset corresponding to a particular group label. """ (grouper,) = self.groupers - index = _maybe_squeeze_indices( - self.groups[key], self._squeeze, grouper, warn=True - ) - return self._obj.isel({self._group_dim: index}) + return self._obj.isel({self._group_dim: self.groups[key]}) def __len__(self) -> int: (grouper,) = self.groupers @@ -616,17 +574,14 @@ def __repr__(self) -> str: ", ".join(format_array_flat(grouper.full_index, 30).split()), ) - def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]: + def _iter_grouped(self) -> Iterator[T_Xarray]: """Iterate over each element in this group""" (grouper,) = self.groupers for idx, indices in enumerate(self._group_indices): - indices = _maybe_squeeze_indices( - indices, self._squeeze, grouper, warn=warn_squeeze and idx == 0 - ) yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): - from xarray.core.groupers import BinGrouper + from xarray.groupers import BinGrouper (grouper,) = self.groupers if self._group_dim in applied_example.dims: @@ -740,7 +695,7 @@ def _maybe_restore_empty_groups(self, combined): """Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index. 
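The "empty groups" situation this helper handles is easiest to see with binning, where bins that capture no data still appear in the reduced output. A hypothetical illustration:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4.0), dims="x", coords={"x": [0, 1, 2, 3]})

# the (10, 20] bin is empty, but the reduced result restores the full
# index of bins and fills the empty one with NaN
result = da.groupby_bins("x", bins=[0, 2, 10, 20]).mean()
print(result.x_bins.values)  # (0, 2], (2, 10], (10, 20]
```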
""" - from xarray.core.groupers import BinGrouper, TimeResampler + from xarray.groupers import BinGrouper, TimeResampler (grouper,) = self.groupers if ( @@ -776,7 +731,7 @@ def _flox_reduce( from flox.xarray import xarray_reduce from xarray.core.dataset import Dataset - from xarray.core.groupers import BinGrouper + from xarray.groupers import BinGrouper obj = self._original_obj (grouper,) = self.groupers @@ -809,14 +764,6 @@ def _flox_reduce( # set explicitly to avoid unnecessarily accumulating count kwargs["min_count"] = 0 - # weird backcompat - # reducing along a unique indexed dimension with squeeze=True - # should raise an error - if (dim is None or dim == name) and name in obj.xindexes: - index = obj.indexes[name] - if index.is_unique and self._squeeze: - raise ValueError(f"cannot reduce over dimensions {name!r}") - unindexed_dims: tuple[Hashable, ...] = tuple() if isinstance(grouper.group, _DummyGroup) and not isbin: unindexed_dims = (name,) @@ -1141,23 +1088,17 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic): def dims(self) -> tuple[Hashable, ...]: if self._dims is None: (grouper,) = self.groupers - index = _maybe_squeeze_indices( - self._group_indices[0], self._squeeze, grouper, warn=True - ) + index = self._group_indices[0] self._dims = self._obj.isel({self._group_dim: index}).dims - return self._dims - def _iter_grouped_shortcut(self, warn_squeeze=True): + def _iter_grouped_shortcut(self): """Fast version of `_iter_grouped` that yields Variables without metadata """ var = self._obj.variable (grouper,) = self.groupers for idx, indices in enumerate(self._group_indices): - indices = _maybe_squeeze_indices( - indices, self._squeeze, grouper, warn=warn_squeeze and idx == 0 - ) yield var[{self._group_dim: indices}] def _concat_shortcut(self, applied, dim, positions=None): @@ -1237,24 +1178,7 @@ def map( applied : DataArray The result of splitting, applying and combining this array. """ - return self._map_maybe_warn( - func, args, warn_squeeze=True, shortcut=shortcut, **kwargs - ) - - def _map_maybe_warn( - self, - func: Callable[..., DataArray], - args: tuple[Any, ...] = (), - *, - warn_squeeze: bool = True, - shortcut: bool | None = None, - **kwargs: Any, - ) -> DataArray: - grouped = ( - self._iter_grouped_shortcut(warn_squeeze) - if shortcut - else self._iter_grouped(warn_squeeze) - ) + grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped() applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) return self._combine(applied, shortcut=shortcut) @@ -1356,68 +1280,6 @@ def reduce_array(ar: DataArray) -> DataArray: return self.map(reduce_array, shortcut=shortcut) - def _reduce_without_squeeze_warn( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, - ) -> DataArray: - """Reduce the items in this group by applying `func` along some - dimension(s). - - Parameters - ---------- - func : callable - Function which can be called in the form - `func(x, axis=axis, **kwargs)` to return the result of collapsing - an np.ndarray over an integer valued axis. - dim : "...", str, Iterable of Hashable or None, optional - Dimension(s) over which to apply `func`. If None, apply over the - groupby dimension, if "..." apply over all dimensions. - axis : int or sequence of int, optional - Axis(es) over which to apply `func`. 
Only one of the 'dimension' - and 'axis' arguments can be supplied. If neither are supplied, then - `func` is calculated over all dimension for each group item. - keep_attrs : bool, optional - If True, the datasets's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. - **kwargs : dict - Additional keyword arguments passed on to `func`. - - Returns - ------- - reduced : Array - Array with summarized data and the indicated dimension(s) - removed. - """ - if dim is None: - dim = [self._group_dim] - - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - - def reduce_array(ar: DataArray) -> DataArray: - return ar.reduce( - func=func, - dim=dim, - axis=axis, - keep_attrs=keep_attrs, - keepdims=keepdims, - **kwargs, - ) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="The `squeeze` kwarg") - check_reduce_dims(dim, self.dims) - - return self._map_maybe_warn(reduce_array, shortcut=shortcut, warn_squeeze=False) - # https://github.com/python/mypy/issues/9031 class DataArrayGroupBy( # type: ignore[misc] @@ -1436,12 +1298,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): def dims(self) -> Frozen[Hashable, int]: if self._dims is None: (grouper,) = self.groupers - index = _maybe_squeeze_indices( - self._group_indices[0], - self._squeeze, - grouper, - warn=True, - ) + index = self._group_indices[0] self._dims = self._obj.isel({self._group_dim: index}).dims return FrozenMappingWarningOnValuesAccess(self._dims) @@ -1482,18 +1339,8 @@ def map( applied : Dataset The result of splitting, applying and combining this dataset. """ - return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) - - def _map_maybe_warn( - self, - func: Callable[..., Dataset], - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - warn_squeeze: bool = False, - **kwargs: Any, - ) -> Dataset: # ignore shortcut if set (for now) - applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped(warn_squeeze)) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) return self._combine(applied) def apply(self, func, args=(), shortcut=None, **kwargs): @@ -1588,68 +1435,6 @@ def reduce_dataset(ds: Dataset) -> Dataset: return self.map(reduce_dataset) - def _reduce_without_squeeze_warn( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, - ) -> Dataset: - """Reduce the items in this group by applying `func` along some - dimension(s). - - Parameters - ---------- - func : callable - Function which can be called in the form - `func(x, axis=axis, **kwargs)` to return the result of collapsing - an np.ndarray over an integer valued axis. - dim : ..., str, Iterable of Hashable or None, optional - Dimension(s) over which to apply `func`. By default apply over the - groupby dimension, with "..." apply over all dimensions. - axis : int or sequence of int, optional - Axis(es) over which to apply `func`. Only one of the 'dimension' - and 'axis' arguments can be supplied. If neither are supplied, then - `func` is calculated over all dimension for each group item. - keep_attrs : bool, optional - If True, the datasets's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. 
- **kwargs : dict - Additional keyword arguments passed on to `func`. - - Returns - ------- - reduced : Dataset - Array with summarized data and the indicated dimension(s) - removed. - """ - if dim is None: - dim = [self._group_dim] - - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - - def reduce_dataset(ds: Dataset) -> Dataset: - return ds.reduce( - func=func, - dim=dim, - axis=axis, - keep_attrs=keep_attrs, - keepdims=keepdims, - **kwargs, - ) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="The `squeeze` kwarg") - check_reduce_dims(dim, self.dims) - - return self._map_maybe_warn(reduce_dataset, warn_squeeze=False) - def assign(self, **kwargs: Any) -> Dataset: """Assign data variables by group. diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f25c0ecf936..9d8a68edbf3 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -451,7 +451,7 @@ def safe_cast_to_index(array: Any) -> pd.Index: elif isinstance(array, PandasIndexingAdapter): index = array.array else: - kwargs: dict[str, str] = {} + kwargs: dict[str, Any] = {} if hasattr(array, "dtype"): if array.dtype.kind == "O": kwargs["dtype"] = "object" @@ -551,7 +551,7 @@ def as_scalar(value: np.ndarray): return value[()] if value.dtype.kind in "mM" else value.item() -def get_indexer_nd(index, labels, method=None, tolerance=None): +def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.ndarray: """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels """ @@ -740,7 +740,7 @@ def isel( # scalar indexer: drop index return None - return self._replace(self.index[indxr]) + return self._replace(self.index[indxr]) # type: ignore[index] def sel( self, labels: dict[Any, Any], method=None, tolerance=None @@ -898,36 +898,45 @@ def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal" ) -def remove_unused_levels_categories(index: pd.Index) -> pd.Index: +T_PDIndex = TypeVar("T_PDIndex", bound=pd.Index) + + +def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex: """ Remove unused levels from MultiIndex and unused categories from CategoricalIndex """ if isinstance(index, pd.MultiIndex): - index = index.remove_unused_levels() + new_index = cast(pd.MultiIndex, index.remove_unused_levels()) # if it contains CategoricalIndex, we need to remove unused categories # manually. See https://github.com/pandas-dev/pandas/issues/30846 - if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels): levels = [] - for i, level in enumerate(index.levels): + for i, level in enumerate(new_index.levels): if isinstance(level, pd.CategoricalIndex): - level = level[index.codes[i]].remove_unused_categories() + level = level[new_index.codes[i]].remove_unused_categories() else: - level = level[index.codes[i]] + level = level[new_index.codes[i]] levels.append(level) # TODO: calling from_array() reorders MultiIndex levels. It would # be best to avoid this, if possible, e.g., by using # MultiIndex.remove_unused_levels() (which does not reorder) on the # part of the MultiIndex that is not categorical, or by fixing this # upstream in pandas. 
- index = pd.MultiIndex.from_arrays(levels, names=index.names) - elif isinstance(index, pd.CategoricalIndex): - index = index.remove_unused_categories() + new_index = pd.MultiIndex.from_arrays(levels, names=new_index.names) + return cast(T_PDIndex, new_index) + + if isinstance(index, pd.CategoricalIndex): + return index.remove_unused_categories() # type: ignore[attr-defined] + return index class PandasMultiIndex(PandasIndex): """Wrap a pandas.MultiIndex as an xarray compatible index.""" + index: pd.MultiIndex + dim: Hashable + coord_dtype: Any level_coords_dtype: dict[str, Any] __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") @@ -1063,8 +1072,8 @@ def from_variables_maybe_expand( The index and its corresponding coordinates may be created along a new dimension. """ names: list[Hashable] = [] - codes: list[list[int]] = [] - levels: list[list[int]] = [] + codes: list[Iterable[int]] = [] + levels: list[Iterable[Any]] = [] level_variables: dict[Any, Variable] = {} _check_dim_compat({**current_variables, **variables}) @@ -1134,7 +1143,7 @@ def reorder_levels( its corresponding coordinates. """ - index = self.index.reorder_levels(level_variables.keys()) + index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys())) level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} return self._replace(index, level_coords_dtype=level_coords_dtype) @@ -1147,13 +1156,13 @@ def create_variables( variables = {} index_vars: IndexVars = {} - for name in (self.dim,) + self.index.names: + for name in (self.dim,) + tuple(self.index.names): if name == self.dim: level = None dtype = None else: level = name - dtype = self.level_coords_dtype[name] + dtype = self.level_coords_dtype[name] # type: ignore[index] # TODO: are Hashables ok? var = variables.get(name, None) if var is not None: @@ -1163,7 +1172,7 @@ def create_variables( attrs = {} encoding = {} - data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) # type: ignore[arg-type] # TODO: are Hashables ok? index_vars[name] = IndexVariable( self.dim, data, @@ -1186,6 +1195,8 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: new_index = None scalar_coord_values = {} + indexer: int | slice | np.ndarray | Variable | DataArray + # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): label_values = {} @@ -1212,7 +1223,7 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: ) scalar_coord_values.update(label_values) # GH2619. Raise a KeyError if nothing is chosen - if indexer.dtype.kind == "b" and indexer.sum() == 0: + if indexer.dtype.kind == "b" and indexer.sum() == 0: # type: ignore[union-attr] raise KeyError(f"{labels} not found") # assume one label value given for the multi-index "array" (dimension) @@ -1600,9 +1611,7 @@ def group_by_index( """Returns a list of unique indexes and their corresponding coordinates.""" index_coords = [] - - for i in self._id_index: - index = self._id_index[i] + for i, index in self._id_index.items(): coords = {k: self._variables[k] for k in self._id_coord_names[i]} index_coords.append((index, coords)) @@ -1640,26 +1649,28 @@ def copy_indexes( in this dict. 
""" - new_indexes = {} - new_index_vars = {} + new_indexes: dict[Hashable, T_PandasOrXarrayIndex] = {} + new_index_vars: dict[Hashable, Variable] = {} - idx: T_PandasOrXarrayIndex + xr_idx: Index + new_idx: T_PandasOrXarrayIndex for idx, coords in self.group_by_index(): if isinstance(idx, pd.Index): convert_new_idx = True dim = next(iter(coords.values())).dims[0] if isinstance(idx, pd.MultiIndex): - idx = PandasMultiIndex(idx, dim) + xr_idx = PandasMultiIndex(idx, dim) else: - idx = PandasIndex(idx, dim) + xr_idx = PandasIndex(idx, dim) else: convert_new_idx = False + xr_idx = idx - new_idx = idx._copy(deep=deep, memo=memo) - idx_vars = idx.create_variables(coords) + new_idx = xr_idx._copy(deep=deep, memo=memo) # type: ignore[assignment] + idx_vars = xr_idx.create_variables(coords) if convert_new_idx: - new_idx = cast(PandasIndex, new_idx).index + new_idx = new_idx.index # type: ignore[attr-defined] new_indexes.update({k: new_idx for k in coords}) new_index_vars.update(idx_vars) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 06e7efdbb48..19937270268 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -34,6 +34,7 @@ from numpy.typing import DTypeLike from xarray.core.indexes import Index + from xarray.core.types import Self from xarray.core.variable import Variable from xarray.namedarray._typing import _Shape, duckarray from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -1656,6 +1657,9 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "_dtype") + array: pd.Index + _dtype: np.dtype + def __init__(self, array: pd.Index, dtype: DTypeLike = None): from xarray.core.indexes import safe_cast_to_index @@ -1792,7 +1796,7 @@ def transpose(self, order) -> pd.Index: def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" - def copy(self, deep: bool = True) -> PandasIndexingAdapter: + def copy(self, deep: bool = True) -> Self: # Not the same as just writing `self.array.copy(deep=deep)`, as # shallow copies of the underlying numpy.ndarrays become deep ones # upon pickling @@ -1810,11 +1814,14 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): This allows creating one instance for each multi-index level while preserving indexing efficiency (memoized + might reuse another instance with the same multi-index). - """ __slots__ = ("array", "_dtype", "level", "adapter") + array: pd.MultiIndex + _dtype: np.dtype + level: str | None + def __init__( self, array: pd.MultiIndex, @@ -1910,7 +1917,7 @@ def _repr_html_(self) -> str: array_repr = short_array_repr(self._get_array_subset()) return f"
<pre>{escape(array_repr)}</pre>
" - def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter: + def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 45abc70c0d3..bfbad72649a 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -315,7 +315,9 @@ def interp_na( use_coordinate: bool | str = True, method: InterpOptions = "linear", limit: int | None = None, - max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None, + max_gap: ( + int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta | None + ) = None, keep_attrs: bool | None = None, **kwargs, ): diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 2493c08ca8e..92d30e1d31b 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -28,14 +28,18 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +from typing import Any try: # requires numpy>=2.0 from numpy import isdtype # type: ignore[attr-defined,unused-ignore] except ImportError: import numpy as np + from numpy.typing import DTypeLike - dtype_kinds = { + kind_mapping = { "bool": np.bool_, "signed integer": np.signedinteger, "unsigned integer": np.unsignedinteger, @@ -45,16 +49,25 @@ "numeric": np.number, } - def isdtype(dtype, kind): + def isdtype( + dtype: np.dtype[Any] | type[Any], kind: DTypeLike | tuple[DTypeLike, ...] + ) -> bool: kinds = kind if isinstance(kind, tuple) else (kind,) + str_kinds = {k for k in kinds if isinstance(k, str)} + type_kinds = {k.type for k in kinds if isinstance(k, np.dtype)} - unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds] - if unknown_dtypes: - raise ValueError(f"unknown dtype kinds: {unknown_dtypes}") + if unknown_kind_types := set(kinds) - str_kinds - type_kinds: + raise TypeError( + f"kind must be str, np.dtype or a tuple of these, got {unknown_kind_types}" + ) + if unknown_kinds := {k for k in str_kinds if k not in kind_mapping}: + raise ValueError( + f"unknown kind: {unknown_kinds}, must be a np.dtype or one of {list(kind_mapping)}" + ) # verified the dtypes already, no need to check again - translated_kinds = [dtype_kinds[kind] for kind in kinds] + translated_kinds = {kind_mapping[k] for k in str_kinds} | type_kinds if isinstance(dtype, np.generic): - return any(isinstance(dtype, kind) for kind in translated_kinds) + return isinstance(dtype, translated_kinds) else: - return any(np.issubdtype(dtype, kind) for kind in translated_kinds) + return any(np.issubdtype(dtype, k) for k in translated_kinds) diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index c09dd82b074..ae4febd6beb 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -41,8 +41,6 @@ import pandas as pd from packaging.version import Version -from xarray.coding import cftime_offsets - def count_not_none(*args) -> int: """Compute the number of non-None arguments. @@ -75,26 +73,6 @@ def __repr__(self) -> str: NoDefault = Literal[_NoDefault.no_default] # For typing following pandas -def _convert_base_to_offset(base, freq, index): - """Required until we officially deprecate the base argument to resample. This - translates a provided `base` argument to an `offset` argument, following logic - from pandas. 
- """ - from xarray.coding.cftimeindex import CFTimeIndex - - if isinstance(index, pd.DatetimeIndex): - freq = cftime_offsets._new_to_legacy_freq(freq) - freq = pd.tseries.frequencies.to_offset(freq) - if isinstance(freq, pd.offsets.Tick): - return pd.Timedelta(base * freq.nanos // freq.n) - elif isinstance(index, CFTimeIndex): - freq = cftime_offsets.to_offset(freq) - if isinstance(freq, cftime_offsets.Tick): - return base * freq.as_timedelta() // freq.n - else: - raise ValueError("Can only resample using a DatetimeIndex or CFTimeIndex.") - - def nanosecond_precision_timestamp(*args, **kwargs) -> pd.Timestamp: """Return a nanosecond-precision Timestamp object. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index ec86f2a283f..86b550466da 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -15,7 +15,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.groupers import RESAMPLE_DIM +from xarray.groupers import RESAMPLE_DIM class Resample(GroupBy[T_Xarray]): @@ -286,21 +286,9 @@ def map( applied : DataArray The result of splitting, applying and combining this array. """ - return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) - - def _map_maybe_warn( - self, - func: Callable[..., Any], - args: tuple[Any, ...] = (), - shortcut: bool | None = False, - warn_squeeze: bool = True, - **kwargs: Any, - ) -> DataArray: # TODO: the argument order for Resample doesn't match that for its parent, # GroupBy - combined = super()._map_maybe_warn( - func, shortcut=shortcut, args=args, warn_squeeze=warn_squeeze, **kwargs - ) + combined = super().map(func, shortcut=shortcut, args=args, **kwargs) # If the aggregation function didn't drop the original resampling # dimension, then we need to do so before we can rename the proxy @@ -380,18 +368,8 @@ def map( applied : Dataset The result of splitting, applying and combining this dataset. """ - return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) - - def _map_maybe_warn( - self, - func: Callable[..., Any], - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - warn_squeeze: bool = True, - **kwargs: Any, - ) -> Dataset: # ignore shortcut if set (for now) - applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped(warn_squeeze)) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) combined = self._combine(applied) # If the aggregation function didn't drop the original resampling @@ -466,27 +444,6 @@ def reduce( **kwargs, ) - def _reduce_without_squeeze_warn( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, - ) -> Dataset: - return super()._reduce_without_squeeze_warn( - func=func, - dim=dim, - axis=axis, - keep_attrs=keep_attrs, - keepdims=keepdims, - shortcut=shortcut, - **kwargs, - ) - def asfreq(self) -> Dataset: """Return values of original object at the new up-sampling frequency; essentially a re-index with new times set to NaN. diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 216bd8fca6b..caa2d166d24 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -66,20 +66,22 @@ class CFTimeGrouper: single method, the only one required for resampling in xarray. 
It cannot be used in a call to groupby like a pandas.Grouper object can.""" + freq: BaseCFTimeOffset + closed: SideOptions + label: SideOptions + loffset: str | datetime.timedelta | BaseCFTimeOffset | None + origin: str | CFTimeDatetime + offset: datetime.timedelta | None + def __init__( self, freq: str | BaseCFTimeOffset, closed: SideOptions | None = None, label: SideOptions | None = None, - loffset: str | datetime.timedelta | BaseCFTimeOffset | None = None, origin: str | CFTimeDatetime = "start_day", - offset: str | datetime.timedelta | None = None, + offset: str | datetime.timedelta | BaseCFTimeOffset | None = None, ): - self.offset: datetime.timedelta | None - self.closed: SideOptions - self.label: SideOptions self.freq = to_offset(freq) - self.loffset = loffset self.origin = origin if isinstance(self.freq, (MonthEnd, QuarterEnd, YearEnd)): @@ -120,10 +122,10 @@ def __init__( if offset is not None: try: self.offset = _convert_offset_to_timedelta(offset) - except (ValueError, AttributeError) as error: + except (ValueError, TypeError) as error: raise ValueError( f"offset must be a datetime.timedelta object or an offset string " - f"that can be converted to a timedelta. Got {offset} instead." + f"that can be converted to a timedelta. Got {type(offset)} instead." ) from error else: self.offset = None @@ -141,22 +143,6 @@ def first_items(self, index: CFTimeIndex): datetime_bins, labels = _get_time_bins( index, self.freq, self.closed, self.label, self.origin, self.offset ) - if self.loffset is not None: - if not isinstance( - self.loffset, (str, datetime.timedelta, BaseCFTimeOffset) - ): - # BaseCFTimeOffset is not public API so we do not include it in - # the error message for now. - raise ValueError( - f"`loffset` must be a str or datetime.timedelta object. " - f"Got {self.loffset}." - ) - - if isinstance(self.loffset, datetime.timedelta): - labels = labels + self.loffset - else: - labels = labels + to_offset(self.loffset) - # check binner fits data if index[0] < datetime_bins[0]: raise ValueError("Value falls before first bin") @@ -250,12 +236,12 @@ def _get_time_bins( def _adjust_bin_edges( - datetime_bins: np.ndarray, + datetime_bins: CFTimeIndex, freq: BaseCFTimeOffset, closed: SideOptions, index: CFTimeIndex, - labels: np.ndarray, -): + labels: CFTimeIndex, +) -> tuple[CFTimeIndex, CFTimeIndex]: """This is required for determining the bin edges resampling with month end, quarter end, and year end frequencies. 
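For cftime data, the same public ``offset`` argument feeds the stricter conversion below; both strings and ``datetime.timedelta`` objects are accepted. A sketch with synthetic data (requires cftime):

```python
import datetime
import xarray as xr

times = xr.date_range("2000-01-01", periods=48, freq="h", calendar="noleap")
da = xr.DataArray(range(48), dims="time", coords={"time": times})

# shift the 12-hourly bin edges by 6 hours
da.resample(time="12h", offset=datetime.timedelta(hours=6)).mean()
```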
@@ -499,10 +485,11 @@ def _convert_offset_to_timedelta( ) -> datetime.timedelta: if isinstance(offset, datetime.timedelta): return offset - elif isinstance(offset, (str, Tick)): - return to_offset(offset).as_timedelta() - else: - raise ValueError + if isinstance(offset, (str, Tick)): + timedelta_cftime_offset = to_offset(offset) + if isinstance(timedelta_cftime_offset, Tick): + return timedelta_cftime_offset.as_timedelta() + raise TypeError(f"Expected timedelta, str or Tick, got {type(offset)}") def _ceil_via_cftimeindex(date: CFTimeDatetime, freq: str | BaseCFTimeOffset): diff --git a/xarray/core/types.py b/xarray/core/types.py index 588d15dc35d..591320d26da 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -28,10 +28,11 @@ else: Self: Any = None -if TYPE_CHECKING: - from numpy._typing import _SupportsDType - from numpy.typing import ArrayLike +from numpy._typing import _SupportsDType +from numpy.typing import ArrayLike + +if TYPE_CHECKING: from xarray.backends.common import BackendEntrypoint from xarray.core.alignment import Aligner from xarray.core.common import AbstractArray, DataWithCoords @@ -41,11 +42,12 @@ from xarray.core.indexes import Index, Indexes from xarray.core.utils import Frozen from xarray.core.variable import Variable + from xarray.groupers import TimeResampler try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray # type: ignore + DaskArray = np.ndarray try: from cubed import Array as CubedArray @@ -177,7 +179,7 @@ def copy( T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) -ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] +ScalarOrArray = Union["ArrayLike", np.generic] VarCompatible = Union["Variable", "ScalarOrArray"] DaCompatible = Union["DataArray", "VarCompatible"] DsCompatible = Union["Dataset", "DaCompatible"] @@ -190,6 +192,8 @@ def copy( # FYI in some cases we don't allow `None`, which this doesn't take account of. # FYI the `str` is for a size string, e.g. "16MB", supported by dask. T_ChunkDim: TypeAlias = Union[str, int, Literal["auto"], None, tuple[int, ...]] +T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim] +T_ChunksFreq: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDimFreq]] # We allow the tuple form of this (though arguably we could transition to named dims only) T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]] T_NormalizedChunks = tuple[tuple[int, ...], ...] @@ -219,6 +223,7 @@ def copy( DatetimeUnitOptions = Literal[ "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as", None ] +NPDatetimeUnitOptions = Literal["D", "h", "m", "s", "ms", "us", "ns"] QueryEngineOptions = Literal["python", "numexpr", None] QueryParserOptions = Literal["pandas", "python"] @@ -293,4 +298,7 @@ def copy( GroupKey = Any GroupIndex = Union[int, slice, list[int]] -T_GroupIndices = list[GroupIndex] +GroupIndices = tuple[GroupIndex, ...] 
+Bins = Union[ + int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index +] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 5cb52cbd25c..c2859632360 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -57,19 +57,12 @@ Mapping, MutableMapping, MutableSet, + Sequence, ValuesView, ) from enum import Enum from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Literal, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, overload import numpy as np import pandas as pd @@ -117,26 +110,27 @@ def wrapper(*args, **kwargs): return wrapper -def get_valid_numpy_dtype(array: np.ndarray | pd.Index): +def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype: """Return a numpy compatible dtype from either a numpy array or a pandas.Index. - Used for wrapping a pandas.Index as an xarray,Variable. + Used for wrapping a pandas.Index as an xarray.Variable. """ if isinstance(array, pd.PeriodIndex): - dtype = np.dtype("O") - elif hasattr(array, "categories"): + return np.dtype("O") + + if hasattr(array, "categories"): # category isn't a real numpy dtype dtype = array.categories.dtype if not is_valid_numpy_dtype(dtype): dtype = np.dtype("O") - elif not is_valid_numpy_dtype(array.dtype): - dtype = np.dtype("O") - else: - dtype = array.dtype + return dtype + + if not is_valid_numpy_dtype(array.dtype): + return np.dtype("O") - return dtype + return array.dtype # type: ignore[return-value] def maybe_coerce_to_str(index, original_coords): @@ -183,18 +177,17 @@ def equivalent(first: T, second: T) -> bool: if isinstance(first, np.ndarray) or isinstance(second, np.ndarray): return duck_array_ops.array_equiv(first, second) if isinstance(first, list) or isinstance(second, list): - return list_equiv(first, second) - return (first == second) or (pd.isnull(first) and pd.isnull(second)) + return list_equiv(first, second) # type: ignore[arg-type] + return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload] -def list_equiv(first, second): - equiv = True +def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool: if len(first) != len(second): return False - else: - for f, s in zip(first, second): - equiv = equiv and equivalent(f, s) - return equiv + for f, s in zip(first, second): + if not equivalent(f, s): + return False + return True def peek_at(iterable: Iterable[T]) -> tuple[T, Iterator[T]]: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f0685882595..828c53e6187 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -8,7 +8,7 @@ from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast +from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast import numpy as np import pandas as pd @@ -63,6 +63,7 @@ PadReflectOptions, QuantileMethods, Self, + T_Chunks, T_DuckArray, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -298,7 +299,7 @@ def as_compatible_data( # we don't want nested self-described arrays if isinstance(data, (pd.Series, pd.DataFrame)): - data = data.values + data = data.values # type: ignore[assignment] if isinstance(data, np.ma.MaskedArray): mask = np.ma.getmaskarray(data) @@ -1504,7 +1505,7 @@ def _unstack_once( # Potentially we could replace `len(other_dims)` with just `-1` other_dims = [d for d in self.dims if d != dim] 
new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes) - new_dims = reordered.dims[: len(other_dims)] + new_dim_names + new_dims = reordered.dims[: len(other_dims)] + tuple(new_dim_names) create_template: Callable if fill_value is dtypes.NA: @@ -2522,7 +2523,7 @@ def _to_dense(self) -> Variable: def chunk( # type: ignore[override] self, - chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + chunks: T_Chunks = {}, name: str | None = None, lock: bool | None = None, inline_array: bool | None = None, diff --git a/xarray/datatree_/docs/source/io.rst b/xarray/datatree_/docs/source/io.rst index 2f2dabf9948..68ab6c18e12 100644 --- a/xarray/datatree_/docs/source/io.rst +++ b/xarray/datatree_/docs/source/io.rst @@ -24,6 +24,12 @@ To open a whole netCDF file as a tree of groups use the :py:func:`open_datatree` To save a DataTree object as a netCDF file containing many groups, use the :py:meth:`DataTree.to_netcdf` method. +.. _netcdf.root_group.note: + +.. note:: + Due to file format specifications the on-disk root group name is always ``"/"``, + overriding any given ``DataTree`` root node name. + .. _netcdf.group.warning: .. warning:: @@ -52,3 +58,8 @@ To save a DataTree object as a zarr store containing many groups, use the :py:me .. note:: Note that perfect round-tripping should always be possible with a zarr store (:ref:`unlike for netCDF files `), as zarr does not support "unused" dimensions. + + For the root group the same restrictions (:ref:`as for netCDF files `) apply. + Due to file format specifications the on-disk root group name is always ``"/"`` + overriding any given ``DataTree`` root node name. + diff --git a/xarray/core/groupers.py b/xarray/groupers.py similarity index 56% rename from xarray/core/groupers.py rename to xarray/groupers.py index 075afd9f62f..becb005b66c 100644 --- a/xarray/core/groupers.py +++ b/xarray/groupers.py @@ -8,9 +8,8 @@ import datetime from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal, cast import numpy as np import pandas as pd @@ -21,8 +20,7 @@ from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper -from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices -from xarray.core.utils import emit_user_level_warning +from xarray.core.types import Bins, DatetimeLike, GroupIndices, SideOptions from xarray.core.variable import Variable __all__ = [ @@ -41,20 +39,25 @@ class EncodedGroups: """ Dataclass for storing intermediate values for GroupBy operation. - Returned by factorize method on Grouper objects. + Returned by the ``factorize`` method on Grouper objects. - Parameters + Attributes ---------- - codes: integer codes for each group - full_index: pandas Index for the group coordinate - group_indices: optional, List of indices of array elements belonging - to each group. Inferred if not provided. - unique_coord: Unique group values present in dataset. Inferred if not provided + codes : DataArray + Same shape as the DataArray to group by. Values consist of a unique integer code for each group. + full_index : pd.Index + Pandas Index for the group coordinate containing unique group labels. + This can differ from ``unique_coord`` in the case of resampling and binning, + where certain groups in the output need not be present in the input. 
+ group_indices : tuple of int or slice or list of int, optional + List of indices of array elements belonging to each group. Inferred if not provided. + unique_coord : Variable, optional + Unique group values present in dataset. Inferred if not provided """ codes: DataArray full_index: pd.Index - group_indices: T_GroupIndices | None = field(default=None) + group_indices: GroupIndices | None = field(default=None) unique_coord: Variable | _DummyGroup | None = field(default=None) def __post_init__(self): @@ -69,33 +72,29 @@ def __post_init__(self): class Grouper(ABC): - """Base class for Grouper objects that allow specializing GroupBy instructions.""" - - @property - def can_squeeze(self) -> bool: - """TODO: delete this when the `squeeze` kwarg is deprecated. Only `UniqueGrouper` - should override it.""" - return False + """Abstract base class for Grouper objects that allow specializing GroupBy instructions.""" @abstractmethod - def factorize(self, group) -> EncodedGroups: + def factorize(self, group: T_Group) -> EncodedGroups: """ - Takes the group, and creates intermediates necessary for GroupBy. - These intermediates are - 1. codes - Same shape as `group` containing a unique integer code for each group. - 2. group_indices - Indexes that let us index out the members of each group. - 3. unique_coord - Unique groups present in the dataset. - 4. full_index - Unique groups in the output. This differs from `unique_coord` in the - case of resampling and binning, where certain groups in the output are not present in - the input. - - Returns an instance of EncodedGroups. + Creates intermediates necessary for GroupBy. + + Parameters + ---------- + group : DataArray + DataArray we are grouping by. + + Returns + ------- + EncodedGroups """ pass class Resampler(Grouper): - """Base class for Grouper objects that allow specializing resampling-type GroupBy instructions. + """ + Abstract base class for Grouper objects that allow specializing resampling-type GroupBy instructions. + Currently only used for TimeResampler, but could be used for SpaceResampler in the future. 
""" @@ -106,31 +105,27 @@ class Resampler(Grouper): class UniqueGrouper(Grouper): """Grouper object for grouping by a categorical variable.""" - _group_as_index: pd.Index | None = None - - @property - def is_unique_and_monotonic(self) -> bool: - if isinstance(self.group, _DummyGroup): - return True - index = self.group_as_index - return index.is_unique and index.is_monotonic_increasing + _group_as_index: pd.Index | None = field(default=None, repr=False) @property def group_as_index(self) -> pd.Index: + """Caches the group DataArray as a pandas Index.""" if self._group_as_index is None: self._group_as_index = self.group.to_index() return self._group_as_index - @property - def can_squeeze(self) -> bool: - """This is a deprecated method and will be removed eventually.""" - is_dimension = self.group.dims == (self.group.name,) - return is_dimension and self.is_unique_and_monotonic - - def factorize(self, group1d) -> EncodedGroups: + def factorize(self, group1d: T_Group) -> EncodedGroups: self.group = group1d - if self.can_squeeze: + index = self.group_as_index + is_unique_and_monotonic = isinstance(self.group, _DummyGroup) or ( + index.is_unique + and (index.is_monotonic_increasing or index.is_monotonic_decreasing) + ) + is_dimension = self.group.dims == (self.group.name,) + can_squeeze = is_dimension and is_unique_and_monotonic + + if can_squeeze: return self._factorize_dummy() else: return self._factorize_unique() @@ -158,8 +153,9 @@ def _factorize_dummy(self) -> EncodedGroups: # no need to factorize # use slices to do views instead of fancy indexing # equivalent to: group_indices = group_indices.reshape(-1, 1) - group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)] + group_indices: GroupIndices = tuple(slice(i, i + 1) for i in range(size)) size_range = np.arange(size) + full_index: pd.Index if isinstance(self.group, _DummyGroup): codes = self.group.to_dataarray().copy(data=size_range) unique_coord = self.group @@ -179,23 +175,71 @@ def _factorize_dummy(self) -> EncodedGroups: @dataclass class BinGrouper(Grouper): - """Grouper object for binning numeric data.""" + """ + Grouper object for binning numeric data. - bins: int | Sequence | pd.IntervalIndex - cut_kwargs: Mapping = field(default_factory=dict) - binned: Any = None - name: Any = None + Attributes + ---------- + bins : int, sequence of scalars, or IntervalIndex + The criteria to bin by. + + * int : Defines the number of equal-width bins in the range of `x`. The + range of `x` is extended by .1% on each side to include the minimum + and maximum values of `x`. + * sequence of scalars : Defines the bin edges allowing for non-uniform + width. No extension of the range of `x` is done. + * IntervalIndex : Defines the exact bins to be used. Note that + IntervalIndex for `bins` must be non-overlapping. + + right : bool, default True + Indicates whether `bins` includes the rightmost edge or not. If + ``right == True`` (the default), then the `bins` ``[1, 2, 3, 4]`` + indicate (1,2], (2,3], (3,4]. This argument is ignored when + `bins` is an IntervalIndex. + labels : array or False, default None + Specifies the labels for the returned bins. Must be the same length as + the resulting bins. If False, returns only integer indicators of the + bins. This affects the type of the output container (see below). + This argument is ignored when `bins` is an IntervalIndex. If True, + raises an error. When `ordered=False`, labels must be provided. + retbins : bool, default False + Whether to return the bins or not. 
Useful when bins is provided + as a scalar. + precision : int, default 3 + The precision at which to store and display the bins labels. + include_lowest : bool, default False + Whether the first interval should be left-inclusive or not. + duplicates : {"raise", "drop"}, default: "raise" + If bin edges are not unique, raise ValueError or drop non-uniques. + """ + + bins: Bins + # The rest are copied from pandas + right: bool = True + labels: Any = None + precision: int = 3 + include_lowest: bool = False + duplicates: Literal["raise", "drop"] = "raise" def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") - def factorize(self, group) -> EncodedGroups: + def factorize(self, group: T_Group) -> EncodedGroups: from xarray.core.dataarray import DataArray - data = group.data - - binned, self.bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True) + data = np.asarray(group.data) # Cast _DummyGroup data to array + + binned, self.bins = pd.cut( # type: ignore [call-overload] + data, + bins=self.bins, + right=self.right, + labels=self.labels, + precision=self.precision, + include_lowest=self.include_lowest, + duplicates=self.duplicates, + retbins=True, + ) binned_codes = binned.codes if (binned_codes == -1).all(): @@ -220,71 +264,66 @@ def factorize(self, group) -> EncodedGroups: ) -@dataclass +@dataclass(repr=False) class TimeResampler(Resampler): - """Grouper object specialized to resampling the time coordinate.""" + """ + Grouper object specialized to resampling the time coordinate. + + Attributes + ---------- + freq : str + Frequency to resample to. See `Pandas frequency + aliases `_ + for a list of possible values. + closed : {"left", "right"}, optional + Side of each interval to treat as closed. + label : {"left", "right"}, optional + Side of each interval to use for labeling. + origin : {'epoch', 'start', 'start_day', 'end', 'end_day'}, pandas.Timestamp, datetime.datetime, numpy.datetime64, or cftime.datetime, default 'start_day' + The datetime on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + + If a datetime is not used, these values are also supported: + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + offset : pd.Timedelta, datetime.timedelta, or str, default is None + An offset timedelta added to the origin. + """ freq: str closed: SideOptions | None = field(default=None) label: SideOptions | None = field(default=None) origin: str | DatetimeLike = field(default="start_day") offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None) - loffset: datetime.timedelta | str | None = field(default=None) - base: int | None = field(default=None) - index_grouper: CFTimeGrouper | pd.Grouper = field(init=False) - group_as_index: pd.Index = field(init=False) - - def __post_init__(self): - if self.loffset is not None: - emit_user_level_warning( - "Following pandas, the `loffset` parameter to resample is deprecated. " - "Switch to updating the resampled dataset time coordinate using " - "time offset arithmetic. 
For example:\n" - " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n" - ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset', - FutureWarning, - ) - - if self.base is not None: - emit_user_level_warning( - "Following pandas, the `base` parameter to resample will be deprecated in " - "a future version of xarray. Switch to using `origin` or `offset` instead.", - FutureWarning, - ) - - if self.base is not None and self.offset is not None: - raise ValueError("base and offset cannot be present at the same time") + index_grouper: CFTimeGrouper | pd.Grouper = field(init=False, repr=False) + group_as_index: pd.Index = field(init=False, repr=False) def _init_properties(self, group: T_Group) -> None: from xarray import CFTimeIndex - from xarray.core.pdcompat import _convert_base_to_offset group_as_index = safe_cast_to_index(group) - - if self.base is not None: - # grouper constructor verifies that grouper.offset is None at this point - offset = _convert_base_to_offset(self.base, self.freq, group_as_index) - else: - offset = self.offset + offset = self.offset if not group_as_index.is_monotonic_increasing: # TODO: sort instead of raising an error - raise ValueError("index must be monotonic for resampling") + raise ValueError("Index must be monotonic for resampling") if isinstance(group_as_index, CFTimeIndex): from xarray.core.resample_cftime import CFTimeGrouper - index_grouper = CFTimeGrouper( + self.index_grouper = CFTimeGrouper( freq=self.freq, closed=self.closed, label=self.label, origin=self.origin, offset=offset, - loffset=self.loffset, ) else: - index_grouper = pd.Grouper( + self.index_grouper = pd.Grouper( # TODO remove once requiring pandas >= 2.2 freq=_new_to_legacy_freq(self.freq), closed=self.closed, @@ -292,7 +331,6 @@ def _init_properties(self, group: T_Group) -> None: origin=self.origin, offset=offset, ) - self.index_grouper = index_grouper self.group_as_index = group_as_index def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: @@ -305,10 +343,13 @@ def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: return full_index, first_items, codes def first_items(self) -> tuple[pd.Series, np.ndarray]: - from xarray import CFTimeIndex + from xarray.coding.cftimeindex import CFTimeIndex + from xarray.core.resample_cftime import CFTimeGrouper - if isinstance(self.group_as_index, CFTimeIndex): - return self.index_grouper.first_items(self.group_as_index) + if isinstance(self.index_grouper, CFTimeGrouper): + return self.index_grouper.first_items( + cast(CFTimeIndex, self.group_as_index) + ) else: s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) grouped = s.groupby(self.index_grouper) @@ -318,18 +359,16 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: # So for _flox_reduce we avoid one reindex and copy by avoiding # _maybe_restore_empty_groups codes = np.repeat(np.arange(len(first_items)), counts) - if self.loffset is not None: - _apply_loffset(self.loffset, first_items) return first_items, codes - def factorize(self, group) -> EncodedGroups: + def factorize(self, group: T_Group) -> EncodedGroups: self._init_properties(group) full_index, first_items, codes_ = self._get_index_and_items() sbins = first_items.values.astype(np.int64) - group_indices: T_GroupIndices = [ - slice(i, j) for i, j in zip(sbins[:-1], sbins[1:]) - ] - group_indices += [slice(sbins[-1], None)] + group_indices: GroupIndices = tuple( + [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + + [slice(sbins[-1], None)] + ) unique_coord = 
Variable( dims=group.name, data=first_items.index, attrs=group.attrs @@ -344,43 +383,6 @@ def factorize(self, group) -> EncodedGroups: ) -def _apply_loffset( - loffset: str | pd.DateOffset | datetime.timedelta | pd.Timedelta, - result: pd.Series | pd.DataFrame, -): - """ - (copied from pandas) - if loffset is set, offset the result index - - This is NOT an idempotent routine, it will be applied - exactly once to the result. - - Parameters - ---------- - result : Series or DataFrame - the result of resample - """ - # pd.Timedelta is a subclass of datetime.timedelta so we do not need to - # include it in instance checks. - if not isinstance(loffset, (str, pd.DateOffset, datetime.timedelta)): - raise ValueError( - f"`loffset` must be a str, pd.DateOffset, datetime.timedelta, or pandas.Timedelta object. " - f"Got {loffset}." - ) - - if isinstance(loffset, str): - loffset = pd.tseries.frequencies.to_offset(loffset) - - needs_offset = ( - isinstance(loffset, (pd.DateOffset, datetime.timedelta)) - and isinstance(result.index, pd.DatetimeIndex) - and len(result.index) > 0 - ) - - if needs_offset: - result.index = result.index + loffset - - def unique_value_groups( ar, sort: bool = True ) -> tuple[np.ndarray | pd.Index, np.ndarray]: diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index b715973814f..c8d6953fc72 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -153,13 +153,16 @@ def __getitem__( ) -> _arrayfunction[Any, _DType_co] | Any: ... @overload - def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]: ... - + def __array__( + self, dtype: None = ..., /, *, copy: None | bool = ... + ) -> np.ndarray[Any, _DType_co]: ... @overload - def __array__(self, dtype: _DType, /) -> np.ndarray[Any, _DType]: ... + def __array__( + self, dtype: _DType, /, *, copy: None | bool = ... + ) -> np.ndarray[Any, _DType]: ... def __array__( - self, dtype: _DType | None = ..., / + self, dtype: _DType | None = ..., /, *, copy: None | bool = ... ) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ... # TODO: Should return the same subclass but with a new dtype generic. 
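
The `__array__` overloads above track NumPy 2.0's extension of the array conversion protocol, which can now pass a `copy` keyword through calls like `np.array(obj, copy=...)`. For reference, here is a minimal sketch (not part of this diff) of a duck array implementing the updated signature; `WrappedArray` is an illustrative name of ours, assuming NumPy >= 2.0 semantics:

```python
from __future__ import annotations

import numpy as np


class WrappedArray:
    """Illustrative duck array matching the updated __array__ protocol."""

    def __init__(self, data: np.ndarray) -> None:
        self._data = data

    def __array__(
        self, dtype: np.dtype | None = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        # Copy eagerly when copy=True; otherwise defer to np.asarray,
        # which may still copy if a dtype conversion is required.
        if copy:
            return np.array(self._data, dtype=dtype, copy=True)
        return np.asarray(self._data, dtype=dtype)


np.asarray(WrappedArray(np.arange(3)))  # invokes __array__ under the hood
```

xarray itself only annotates this protocol in `_typing.py`; the concrete conversion behavior lives in whatever array library is being wrapped.
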
diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index fe47bf50533..e9668d89d94 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -53,7 +53,7 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray - from xarray.core.types import Dims + from xarray.core.types import Dims, T_Chunks from xarray.namedarray._typing import ( Default, _AttrsLike, @@ -748,7 +748,7 @@ def sizes(self) -> dict[_Dim, _IntOrUnknown]: def chunk( self, - chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + chunks: T_Chunks = {}, chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None, from_array_kwargs: Any = None, **chunks_kwargs: Any, @@ -839,7 +839,7 @@ def chunk( ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment] if is_dict_like(chunks): - chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) # type: ignore[assignment] + chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) data_chunked = chunkmanager.from_array(ndata, chunks, **from_array_kwargs) # type: ignore[arg-type] diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 14744d2de6b..963d12fd865 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -21,13 +21,13 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray[Any, Any] # type: ignore[assignment, misc] + DaskArray = np.ndarray[Any, Any] dask_available = module_available("dask") -class DaskManager(ChunkManagerEntrypoint["DaskArray"]): # type: ignore[type-var] +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): array_cls: type[DaskArray] available: bool = dask_available @@ -91,7 +91,7 @@ def array_api(self) -> Any: return da - def reduction( # type: ignore[override] + def reduction( self, arr: T_ChunkedArray, func: Callable[..., Any], @@ -113,7 +113,7 @@ def reduction( # type: ignore[override] keepdims=keepdims, ) # type: ignore[no-untyped-call] - def scan( # type: ignore[override] + def scan( self, func: Callable[..., Any], binop: Callable[..., Any], diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index faf809a8a74..613362ed5d5 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -774,7 +774,7 @@ def _get_largest_lims(self) -> dict[str, tuple[float, float]]: >>> ds = xr.tutorial.scatter_example_dataset(seed=42) >>> fg = ds.plot.scatter(x="A", y="B", hue="y", row="x", col="w") >>> round(fg._get_largest_lims()["x"][0], 3) - -0.334 + np.float64(-0.334) """ lims_largest: dict[str, tuple[float, float]] = dict( x=(np.inf, -np.inf), y=(np.inf, -np.inf), z=(np.inf, -np.inf) @@ -817,7 +817,7 @@ def _set_lims( >>> fg = ds.plot.scatter(x="A", y="B", hue="y", row="x", col="w") >>> fg._set_lims(x=(-0.3, 0.3), y=(0, 2), z=(0, 4)) >>> fg.axs[0, 0].get_xlim(), fg.axs[0, 0].get_ylim() - ((-0.3, 0.3), (0.0, 2.0)) + ((np.float64(-0.3), np.float64(0.3)), (np.float64(0.0), np.float64(2.0))) """ lims_largest = self._get_largest_lims() diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 8789bc2f9c2..a0abe0c8cfd 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -811,11 +811,11 @@ def _update_axes( def _is_monotonic(coord, axis=0): """ >>> _is_monotonic(np.array([0, 1, 2])) - True + np.True_ >>> _is_monotonic(np.array([2, 1, 0])) - True + np.True_ >>> _is_monotonic(np.array([0, 2, 1])) - False + np.False_ """ if coord.shape[axis] < 3: return True diff --git a/xarray/tests/test_backends.py 
b/xarray/tests/test_backends.py index 15485dc178a..4fa8736427d 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -568,7 +568,7 @@ def test_roundtrip_cftime_datetime_data(self) -> None: assert actual.t.encoding["calendar"] == expected_calendar def test_roundtrip_timedelta_data(self) -> None: - time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 expected = Dataset({"td": ("td", time_deltas), "td0": time_deltas[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) @@ -863,14 +863,14 @@ def test_roundtrip_empty_vlen_string_array(self) -> None: # checks preserving vlen dtype for empty arrays GH7862 dtype = create_vlen_dtype(str) original = Dataset({"a": np.array([], dtype=dtype)}) - assert check_vlen_dtype(original["a"].dtype) == str + assert check_vlen_dtype(original["a"].dtype) is str with self.roundtrip(original) as actual: assert_identical(original, actual) if np.issubdtype(actual["a"].dtype, object): # only check metadata for capable backends # eg. NETCDF3 based backends do not roundtrip metadata if actual["a"].dtype.metadata is not None: - assert check_vlen_dtype(actual["a"].dtype) == str + assert check_vlen_dtype(actual["a"].dtype) is str else: assert actual["a"].dtype == np.dtype(" None: assert decoded.variables[k].dtype == actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) + @pytest.mark.parametrize("fillvalue", [np.int8(-1), np.uint8(255), -1, 255]) + def test_roundtrip_unsigned(self, fillvalue): + # regression/numpy2 test for + encoding = { + "_FillValue": fillvalue, + "_Unsigned": "true", + "dtype": "i1", + } + x = np.array([0, 1, 127, 128, 254, np.nan], dtype=np.float32) + decoded = Dataset({"x": ("t", x, {}, encoding)}) + + attributes = { + "_FillValue": fillvalue, + "_Unsigned": "true", + } + # Create unsigned data corresponding to [0, 1, 127, 128, 255] unsigned + sb = np.asarray([0, 1, 127, -128, -2, -1], dtype="i1") + encoded = Dataset({"x": ("t", sb, attributes)}) + + with self.roundtrip(decoded) as actual: + for k in decoded.variables: + assert decoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(decoded, actual, decode_bytes=False) + + with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: + for k in encoded.variables: + assert encoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(encoded, actual, decode_bytes=False) + @staticmethod def _create_cf_dataset(): original = Dataset( @@ -4285,7 +4314,7 @@ def test_roundtrip_coordinates_with_space(self) -> None: def test_roundtrip_numpy_datetime_data(self) -> None: # Override method in DatasetIOBase - remove not applicable # save_kwargs - times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"]) + times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"], unit="ns") expected = Dataset({"t": ("t", times), "t0": times[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) @@ -5598,7 +5627,9 @@ def test_h5netcdf_entrypoint(tmp_path: Path) -> None: @requires_netCDF4 @pytest.mark.parametrize("str_type", (str, np.str_)) -def test_write_file_from_np_str(str_type, tmpdir) -> None: +def test_write_file_from_np_str( + str_type: type[str] | type[np.str_], tmpdir: str +) -> None: # https://github.com/pydata/xarray/pull/5264 scenarios = [str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]] years = range(2015, 2100 + 1) @@ -5609,7 
+5640,7 @@ def test_write_file_from_np_str(str_type, tmpdir) -> None: ) tdf.index.name = "scenario" tdf.columns.name = "year" - tdf = tdf.stack() + tdf = cast(pd.DataFrame, tdf.stack()) tdf.name = "tas" txr = tdf.to_xarray() diff --git a/xarray/tests/test_calendar_ops.py b/xarray/tests/test_calendar_ops.py index 7d229371808..13e9f7a1030 100644 --- a/xarray/tests/test_calendar_ops.py +++ b/xarray/tests/test_calendar_ops.py @@ -1,9 +1,10 @@ from __future__ import annotations import numpy as np +import pandas as pd import pytest -from xarray import DataArray, infer_freq +from xarray import CFTimeIndex, DataArray, infer_freq from xarray.coding.calendar_ops import convert_calendar, interp_calendar from xarray.coding.cftime_offsets import date_range from xarray.testing import assert_identical @@ -286,3 +287,25 @@ def test_interp_calendar_errors(): ValueError, match="Both 'source.x' and 'target' must contain datetime objects." ): interp_calendar(da1, da2, dim="x") + + +@requires_cftime +@pytest.mark.parametrize( + ("source_calendar", "target_calendar", "expected_index"), + [("standard", "noleap", CFTimeIndex), ("all_leap", "standard", pd.DatetimeIndex)], +) +def test_convert_calendar_produces_time_index( + source_calendar, target_calendar, expected_index +): + # https://github.com/pydata/xarray/issues/9138 + time = date_range("2000-01-01", "2002-01-01", freq="D", calendar=source_calendar) + temperature = np.ones(len(time)) + da = DataArray( + data=temperature, + dims=["time"], + coords=dict( + time=time, + ), + ) + converted = da.convert_calendar(target_calendar) + assert isinstance(converted.indexes["time"], expected_index) diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index eabb7d2f4d6..cec4ea4c944 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from itertools import product from typing import Callable, Literal @@ -24,7 +25,6 @@ Tick, YearBegin, YearEnd, - _days_in_month, _legacy_to_new_freq, _new_to_legacy_freq, cftime_range, @@ -511,7 +511,7 @@ def test_Microsecond_multiplied_float_error(): ], ids=_id_func, ) -def test_neg(offset, expected): +def test_neg(offset: BaseCFTimeOffset, expected: BaseCFTimeOffset) -> None: assert -offset == expected @@ -589,22 +589,6 @@ def test_minus_offset_error(a, b): b - a -def test_days_in_month_non_december(calendar): - date_type = get_date_type(calendar) - reference = date_type(1, 4, 1) - assert _days_in_month(reference) == 30 - - -def test_days_in_month_december(calendar): - if calendar == "360_day": - expected = 30 - else: - expected = 31 - date_type = get_date_type(calendar) - reference = date_type(1, 12, 5) - assert _days_in_month(reference) == expected - - @pytest.mark.parametrize( ("initial_date_args", "offset", "expected_date_args"), [ @@ -657,7 +641,7 @@ def test_add_month_end( # Here the days at the end of each month varies based on the calendar used expected_date_args = ( - expected_year_month + (_days_in_month(reference),) + expected_sub_day + expected_year_month + (reference.daysinmonth,) + expected_sub_day ) expected = date_type(*expected_date_args) assert result == expected @@ -694,9 +678,7 @@ def test_add_month_end_onOffset( date_type = get_date_type(calendar) reference_args = initial_year_month + (1,) reference = date_type(*reference_args) - initial_date_args = ( - initial_year_month + (_days_in_month(reference),) + initial_sub_day - ) + initial_date_args = initial_year_month + 
(reference.daysinmonth,) + initial_sub_day initial = date_type(*initial_date_args) result = initial + offset reference_args = expected_year_month + (1,) @@ -704,7 +686,7 @@ def test_add_month_end_onOffset( # Here the days at the end of each month varies based on the calendar used expected_date_args = ( - expected_year_month + (_days_in_month(reference),) + expected_sub_day + expected_year_month + (reference.daysinmonth,) + expected_sub_day ) expected = date_type(*expected_date_args) assert result == expected @@ -756,7 +738,7 @@ def test_add_year_end( # Here the days at the end of each month varies based on the calendar used expected_date_args = ( - expected_year_month + (_days_in_month(reference),) + expected_sub_day + expected_year_month + (reference.daysinmonth,) + expected_sub_day ) expected = date_type(*expected_date_args) assert result == expected @@ -792,9 +774,7 @@ def test_add_year_end_onOffset( date_type = get_date_type(calendar) reference_args = initial_year_month + (1,) reference = date_type(*reference_args) - initial_date_args = ( - initial_year_month + (_days_in_month(reference),) + initial_sub_day - ) + initial_date_args = initial_year_month + (reference.daysinmonth,) + initial_sub_day initial = date_type(*initial_date_args) result = initial + offset reference_args = expected_year_month + (1,) @@ -802,7 +782,7 @@ def test_add_year_end_onOffset( # Here the days at the end of each month varies based on the calendar used expected_date_args = ( - expected_year_month + (_days_in_month(reference),) + expected_sub_day + expected_year_month + (reference.daysinmonth,) + expected_sub_day ) expected = date_type(*expected_date_args) assert result == expected @@ -854,7 +834,7 @@ def test_add_quarter_end( # Here the days at the end of each month varies based on the calendar used expected_date_args = ( - expected_year_month + (_days_in_month(reference),) + expected_sub_day + expected_year_month + (reference.daysinmonth,) + expected_sub_day ) expected = date_type(*expected_date_args) assert result == expected @@ -890,9 +870,7 @@ def test_add_quarter_end_onOffset( date_type = get_date_type(calendar) reference_args = initial_year_month + (1,) reference = date_type(*reference_args) - initial_date_args = ( - initial_year_month + (_days_in_month(reference),) + initial_sub_day - ) + initial_date_args = initial_year_month + (reference.daysinmonth,) + initial_sub_day initial = date_type(*initial_date_args) result = initial + offset reference_args = expected_year_month + (1,) @@ -900,7 +878,7 @@ def test_add_quarter_end_onOffset( # Here the days at the end of each month varies based on the calendar used expected_date_args = ( - expected_year_month + (_days_in_month(reference),) + expected_sub_day + expected_year_month + (reference.daysinmonth,) + expected_sub_day ) expected = date_type(*expected_date_args) assert result == expected @@ -957,7 +935,7 @@ def test_onOffset_month_or_quarter_or_year_end( date_type = get_date_type(calendar) reference_args = year_month_args + (1,) reference = date_type(*reference_args) - date_args = year_month_args + (_days_in_month(reference),) + sub_day_args + date_args = year_month_args + (reference.daysinmonth,) + sub_day_args date = date_type(*date_args) result = offset.onOffset(date) assert result @@ -1005,7 +983,7 @@ def test_rollforward(calendar, offset, initial_date_args, partial_expected_date_ elif isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)): reference_args = partial_expected_date_args + (1,) reference = date_type(*reference_args) - expected_date_args = 
partial_expected_date_args + (_days_in_month(reference),) + expected_date_args = partial_expected_date_args + (reference.daysinmonth,) else: expected_date_args = partial_expected_date_args expected = date_type(*expected_date_args) @@ -1056,7 +1034,7 @@ def test_rollback(calendar, offset, initial_date_args, partial_expected_date_arg elif isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)): reference_args = partial_expected_date_args + (1,) reference = date_type(*reference_args) - expected_date_args = partial_expected_date_args + (_days_in_month(reference),) + expected_date_args = partial_expected_date_args + (reference.daysinmonth,) else: expected_date_args = partial_expected_date_args expected = date_type(*expected_date_args) @@ -1787,3 +1765,69 @@ def test_date_range_no_freq(start, end, periods): expected = pd.date_range(start=start, end=end, periods=periods) np.testing.assert_array_equal(result, expected) + + +@pytest.mark.parametrize( + "offset", + [ + MonthBegin(n=1), + MonthEnd(n=1), + QuarterBegin(n=1), + QuarterEnd(n=1), + YearBegin(n=1), + YearEnd(n=1), + ], + ids=lambda x: f"{x}", +) +@pytest.mark.parametrize("has_year_zero", [False, True]) +def test_offset_addition_preserves_has_year_zero(offset, has_year_zero): + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="this date/calendar/year zero") + datetime = cftime.DatetimeGregorian(-1, 12, 31, has_year_zero=has_year_zero) + + result = datetime + offset + assert result.has_year_zero == datetime.has_year_zero + if has_year_zero: + assert result.year == 0 + else: + assert result.year == 1 + + +@pytest.mark.parametrize( + "offset", + [ + MonthBegin(n=1), + MonthEnd(n=1), + QuarterBegin(n=1), + QuarterEnd(n=1), + YearBegin(n=1), + YearEnd(n=1), + ], + ids=lambda x: f"{x}", +) +@pytest.mark.parametrize("has_year_zero", [False, True]) +def test_offset_subtraction_preserves_has_year_zero(offset, has_year_zero): + datetime = cftime.DatetimeGregorian(1, 1, 1, has_year_zero=has_year_zero) + result = datetime - offset + assert result.has_year_zero == datetime.has_year_zero + if has_year_zero: + assert result.year == 0 + else: + assert result.year == -1 + + +@pytest.mark.parametrize("has_year_zero", [False, True]) +def test_offset_day_option_end_accounts_for_has_year_zero(has_year_zero): + offset = MonthEnd(n=1) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="this date/calendar/year zero") + datetime = cftime.DatetimeGregorian(-1, 1, 31, has_year_zero=has_year_zero) + + result = datetime + offset + assert result.has_year_zero == datetime.has_year_zero + if has_year_zero: + assert result.day == 28 + else: + assert result.day == 29 diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index f6eb15fa373..116487e2bcf 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -679,11 +679,11 @@ def test_indexing_in_series_loc(series, index, scalar_args, range_args): @requires_cftime def test_indexing_in_series_iloc(series, index): - expected = 1 - assert series.iloc[0] == expected + expected1 = 1 + assert series.iloc[0] == expected1 - expected = pd.Series([1, 2], index=index[:2]) - assert series.iloc[:2].equals(expected) + expected2 = pd.Series([1, 2], index=index[:2]) + assert series.iloc[:2].equals(expected2) @requires_cftime @@ -696,27 +696,27 @@ def test_series_dropna(index): @requires_cftime def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args): - expected = pd.Series([1], name=index[0]) + expected_s = 
pd.Series([1], name=index[0]) for arg in scalar_args: - result = df.loc[arg] - assert result.equals(expected) + result_s = df.loc[arg] + assert result_s.equals(expected_s) - expected = pd.DataFrame([1, 2], index=index[:2]) + expected_df = pd.DataFrame([1, 2], index=index[:2]) for arg in range_args: - result = df.loc[arg] - assert result.equals(expected) + result_df = df.loc[arg] + assert result_df.equals(expected_df) @requires_cftime def test_indexing_in_dataframe_iloc(df, index): - expected = pd.Series([1], name=index[0]) - result = df.iloc[0] - assert result.equals(expected) - assert result.equals(expected) + expected_s = pd.Series([1], name=index[0]) + result_s = df.iloc[0] + assert result_s.equals(expected_s) + assert result_s.equals(expected_s) - expected = pd.DataFrame([1, 2], index=index[:2]) - result = df.iloc[:2] - assert result.equals(expected) + expected_df = pd.DataFrame([1, 2], index=index[:2]) + result_df = df.iloc[:2] + assert result_df.equals(expected_df) @requires_cftime @@ -957,17 +957,17 @@ def test_cftimeindex_shift(index, freq) -> None: @requires_cftime -def test_cftimeindex_shift_invalid_n() -> None: +def test_cftimeindex_shift_invalid_periods() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): - index.shift("a", "D") + index.shift("a", "D") # type: ignore[arg-type] @requires_cftime def test_cftimeindex_shift_invalid_freq() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): - index.shift(1, 1) + index.shift(1, 1) # type: ignore[arg-type] @requires_cftime diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 98d4377706c..081970108a0 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -10,7 +10,7 @@ import xarray as xr from xarray.coding.cftime_offsets import _new_to_legacy_freq -from xarray.core.pdcompat import _convert_base_to_offset +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core.resample_cftime import CFTimeGrouper cftime = pytest.importorskip("cftime") @@ -60,10 +60,8 @@ def compare_against_pandas( freq, closed=None, label=None, - base=None, offset=None, origin=None, - loffset=None, ) -> None: if isinstance(origin, tuple): origin_pandas = pd.Timestamp(datetime.datetime(*origin)) @@ -77,8 +75,6 @@ def compare_against_pandas( time=freq, closed=closed, label=label, - base=base, - loffset=loffset, offset=offset, origin=origin_pandas, ).mean() @@ -88,8 +84,6 @@ def compare_against_pandas( time=freq, closed=closed, label=label, - base=base, - loffset=loffset, origin=origin_cftime, offset=offset, ).mean() @@ -98,8 +92,6 @@ def compare_against_pandas( time=freq, closed=closed, label=label, - base=base, - loffset=loffset, origin=origin_cftime, offset=offset, ).mean() @@ -116,14 +108,11 @@ def da(index) -> xr.DataArray: ) -@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") @pytest.mark.parametrize("freqs", FREQS, ids=lambda x: "{}->{}".format(*x)) @pytest.mark.parametrize("closed", [None, "left", "right"]) @pytest.mark.parametrize("label", [None, "left", "right"]) -@pytest.mark.parametrize( - ("base", "offset"), [(24, None), (31, None), (None, "5s")], ids=lambda x: f"{x}" -) -def test_resample(freqs, closed, label, base, offset) -> None: +@pytest.mark.parametrize("offset", [None, "5s"], ids=lambda x: f"{x}") +def test_resample(freqs, closed, label, offset) -> None: initial_freq, resample_freq = freqs if ( resample_freq == "4001D" @@ -136,7 +125,6 @@ def 
test_resample(freqs, closed, label, base, offset) -> None: "result as pandas for earlier pandas versions." ) start = "2000-01-01T12:07:01" - loffset = "12h" origin = "start" datetime_index = pd.date_range( @@ -146,18 +134,15 @@ def test_resample(freqs, closed, label, base, offset) -> None: da_datetimeindex = da(datetime_index) da_cftimeindex = da(cftime_index) - with pytest.warns(FutureWarning, match="`loffset` parameter"): - compare_against_pandas( - da_datetimeindex, - da_cftimeindex, - resample_freq, - closed=closed, - label=label, - base=base, - offset=offset, - origin=origin, - loffset=loffset, - ) + compare_against_pandas( + da_datetimeindex, + da_cftimeindex, + resample_freq, + closed=closed, + label=label, + offset=offset, + origin=origin, + ) @pytest.mark.parametrize( @@ -181,30 +166,22 @@ def test_closed_label_defaults(freq, expected) -> None: @pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex") -@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") @pytest.mark.parametrize( "calendar", ["gregorian", "noleap", "all_leap", "360_day", "julian"] ) def test_calendars(calendar: str) -> None: # Limited testing for non-standard calendars - freq, closed, label, base = "8001min", None, None, 17 - loffset = datetime.timedelta(hours=12) + freq, closed, label = "8001min", None, None xr_index = xr.cftime_range( start="2004-01-01T12:07:01", periods=7, freq="3D", calendar=calendar ) pd_index = pd.date_range(start="2004-01-01T12:07:01", periods=7, freq="3D") - da_cftime = ( - da(xr_index) - .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) - .mean() - ) - da_datetime = ( - da(pd_index) - .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) - .mean() - ) + da_cftime = da(xr_index).resample(time=freq, closed=closed, label=label).mean() + da_datetime = da(pd_index).resample(time=freq, closed=closed, label=label).mean() # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass - da_cftime["time"] = da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() + new_pd_index = da_cftime.xindexes["time"].to_pandas_index() + assert isinstance(new_pd_index, CFTimeIndex) # shouldn't that be a pd.Index? 
+ da_cftime["time"] = new_pd_index.to_datetimeindex() xr.testing.assert_identical(da_cftime, da_datetime) @@ -214,7 +191,6 @@ class DateRangeKwargs(TypedDict): freq: str -@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") @pytest.mark.parametrize("closed", ["left", "right"]) @pytest.mark.parametrize( "origin", @@ -239,20 +215,12 @@ def test_origin(closed, origin) -> None: ) -@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") -def test_base_and_offset_error(): - cftime_index = xr.cftime_range("2000", periods=5) - da_cftime = da(cftime_index) - with pytest.raises(ValueError, match="base and offset cannot"): - da_cftime.resample(time="2D", base=3, offset="5s") - - @pytest.mark.parametrize("offset", ["foo", "5MS", 10]) -def test_invalid_offset_error(offset) -> None: +def test_invalid_offset_error(offset: str | int) -> None: cftime_index = xr.cftime_range("2000", periods=5) da_cftime = da(cftime_index) with pytest.raises(ValueError, match="offset must be"): - da_cftime.resample(time="2D", offset=offset) + da_cftime.resample(time="2D", offset=offset) # type: ignore[arg-type] def test_timedelta_offset() -> None: @@ -265,44 +233,3 @@ def test_timedelta_offset() -> None: timedelta_result = da_cftime.resample(time="2D", offset=timedelta).mean() string_result = da_cftime.resample(time="2D", offset=string).mean() xr.testing.assert_identical(timedelta_result, string_result) - - -@pytest.mark.parametrize("loffset", ["MS", "12h", datetime.timedelta(hours=-12)]) -def test_resample_loffset_cftimeindex(loffset) -> None: - datetimeindex = pd.date_range("2000-01-01", freq="6h", periods=10) - da_datetimeindex = xr.DataArray(np.arange(10), [("time", datetimeindex)]) - - cftimeindex = xr.cftime_range("2000-01-01", freq="6h", periods=10) - da_cftimeindex = xr.DataArray(np.arange(10), [("time", cftimeindex)]) - - with pytest.warns(FutureWarning, match="`loffset` parameter"): - result = da_cftimeindex.resample(time="24h", loffset=loffset).mean() - expected = da_datetimeindex.resample(time="24h", loffset=loffset).mean() - - result["time"] = result.xindexes["time"].to_pandas_index().to_datetimeindex() - xr.testing.assert_identical(result, expected) - - -@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") -def test_resample_invalid_loffset_cftimeindex() -> None: - times = xr.cftime_range("2000-01-01", freq="6h", periods=10) - da = xr.DataArray(np.arange(10), [("time", times)]) - - with pytest.raises(ValueError): - da.resample(time="24h", loffset=1) # type: ignore - - -@pytest.mark.parametrize(("base", "freq"), [(1, "10s"), (17, "3h"), (15, "5us")]) -def test__convert_base_to_offset(base, freq): - # Verify that the cftime_offset adapted version of _convert_base_to_offset - # produces the same result as the pandas version. 
- datetimeindex = pd.date_range("2000", periods=2) - cftimeindex = xr.cftime_range("2000", periods=2) - pandas_result = _convert_base_to_offset(base, freq, datetimeindex) - cftime_result = _convert_base_to_offset(base, freq, cftimeindex) - assert pandas_result.to_pytimedelta() == cftime_result - - -def test__convert_base_to_offset_invalid_index(): - with pytest.raises(ValueError, match="Can only resample"): - _convert_base_to_offset(1, "12h", pd.Index([0])) diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 51f63ea72dd..17179a44a8a 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -21,13 +21,13 @@ def test_vlen_dtype() -> None: dtype = strings.create_vlen_dtype(str) - assert dtype.metadata["element_type"] == str + assert dtype.metadata["element_type"] is str assert strings.is_unicode_dtype(dtype) assert not strings.is_bytes_dtype(dtype) assert strings.check_vlen_dtype(dtype) is str dtype = strings.create_vlen_dtype(bytes) - assert dtype.metadata["element_type"] == bytes + assert dtype.metadata["element_type"] is bytes assert not strings.is_unicode_dtype(dtype) assert strings.is_bytes_dtype(dtype) assert strings.check_vlen_dtype(dtype) is bytes diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index d568bdc3268..35df5bf41d4 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -3,6 +3,7 @@ import warnings from datetime import timedelta from itertools import product +from typing import Literal import numpy as np import pandas as pd @@ -44,6 +45,8 @@ FirstElementAccessibleArray, arm_xfail, assert_array_equal, + assert_duckarray_allclose, + assert_duckarray_equal, assert_no_warnings, has_cftime, requires_cftime, @@ -142,15 +145,16 @@ def test_cf_datetime(num_dates, units, calendar) -> None: # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 assert (abs_diff <= np.timedelta64(1, "s")).all() - encoded, _, _ = encode_cf_datetime(actual, units, calendar) + encoded1, _, _ = encode_cf_datetime(actual, units, calendar) + + assert_array_equal(num_dates, np.round(encoded1, 1)) - assert_array_equal(num_dates, np.round(encoded, 1)) if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: # verify that wrapping with a pandas.Index works # note that it *does not* currently work to put # non-datetime64 compatible dates into a pandas.Index - encoded, _, _ = encode_cf_datetime(pd.Index(actual), units, calendar) - assert_array_equal(num_dates, np.round(encoded, 1)) + encoded2, _, _ = encode_cf_datetime(pd.Index(actual), units, calendar) + assert_array_equal(num_dates, np.round(encoded2, 1)) @requires_cftime @@ -625,10 +629,10 @@ def test_cf_timedelta_2d() -> None: @pytest.mark.parametrize( ["deltas", "expected"], [ - (pd.to_timedelta(["1 day", "2 days"]), "days"), - (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), - (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), - (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), + (pd.to_timedelta(["1 day", "2 days"]), "days"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), # type: 
ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 ], ) def test_infer_timedelta_units(deltas, expected) -> None: @@ -893,10 +897,10 @@ def test_time_units_with_timezone_roundtrip(calendar) -> None: ) if calendar in _STANDARD_CALENDARS: - np.testing.assert_array_equal(result_num_dates, expected_num_dates) + assert_duckarray_equal(result_num_dates, expected_num_dates) else: # cftime datetime arithmetic is not quite exact. - np.testing.assert_allclose(result_num_dates, expected_num_dates) + assert_duckarray_allclose(result_num_dates, expected_num_dates) assert result_units == expected_units assert result_calendar == calendar @@ -1235,7 +1239,7 @@ def test_contains_cftime_lazy() -> None: ) def test_roundtrip_datetime64_nanosecond_precision( timestr: str, - timeunit: str, + timeunit: Literal["ns", "us"], dtype: np.typing.DTypeLike, fill_value: int | float | None, use_encoding: bool, @@ -1431,8 +1435,8 @@ def test_roundtrip_float_times() -> None: def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype) -> None: import dask.array - times = pd.date_range(start="1700", freq=freq, periods=3) - times = dask.array.from_array(times, chunks=1) + times_pd = pd.date_range(start="1700", freq=freq, periods=3) + times = dask.array.from_array(times_pd, chunks=1) encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( times, units, None, dtype ) @@ -1482,8 +1486,8 @@ def test_encode_cf_datetime_cftime_datetime_via_dask(units, dtype) -> None: import dask.array calendar = "standard" - times = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) - times = dask.array.from_array(times, chunks=1) + times_idx = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) + times = dask.array.from_array(times_idx, chunks=1) encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( times, units, None, dtype ) @@ -1555,11 +1559,13 @@ def test_encode_cf_datetime_casting_overflow_error(use_cftime, use_dask, dtype) @pytest.mark.parametrize( ("units", "dtype"), [("days", np.dtype("int32")), (None, None)] ) -def test_encode_cf_timedelta_via_dask(units, dtype) -> None: +def test_encode_cf_timedelta_via_dask( + units: str | None, dtype: np.dtype | None +) -> None: import dask.array - times = pd.timedelta_range(start="0D", freq="D", periods=3) - times = dask.array.from_array(times, chunks=1) + times_pd = pd.timedelta_range(start="0D", freq="D", periods=3) + times = dask.array.from_array(times_pd, chunks=1) encoded_times, encoding_units = encode_cf_timedelta(times, units, dtype) assert is_duck_dask_array(encoded_times) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index b000de311af..8b480b02472 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -118,10 +118,8 @@ def test_apply_identity() -> None: assert_identical(variable, apply_identity(variable)) assert_identical(data_array, apply_identity(data_array)) assert_identical(data_array, apply_identity(data_array.groupby("x"))) - assert_identical(data_array, apply_identity(data_array.groupby("x", squeeze=False))) assert_identical(dataset, apply_identity(dataset)) assert_identical(dataset, apply_identity(dataset.groupby("x"))) - assert_identical(dataset, apply_identity(dataset.groupby("x", squeeze=False))) def add(a, b): @@ -521,10 +519,8 @@ def func(x): assert_identical(stacked_variable, stack_negative(variable)) assert_identical(stacked_data_array, stack_negative(data_array)) assert_identical(stacked_dataset, 
stack_negative(dataset)) - with pytest.warns(UserWarning, match="The `squeeze` kwarg"): - assert_identical(stacked_data_array, stack_negative(data_array.groupby("x"))) - with pytest.warns(UserWarning, match="The `squeeze` kwarg"): - assert_identical(stacked_dataset, stack_negative(dataset.groupby("x"))) + assert_identical(stacked_data_array, stack_negative(data_array.groupby("x"))) + assert_identical(stacked_dataset, stack_negative(dataset.groupby("x"))) def original_and_stack_negative(obj): def func(x): @@ -551,13 +547,11 @@ def func(x): assert_identical(dataset, out0) assert_identical(stacked_dataset, out1) - with pytest.warns(UserWarning, match="The `squeeze` kwarg"): - out0, out1 = original_and_stack_negative(data_array.groupby("x")) + out0, out1 = original_and_stack_negative(data_array.groupby("x")) assert_identical(data_array, out0) assert_identical(stacked_data_array, out1) - with pytest.warns(UserWarning, match="The `squeeze` kwarg"): - out0, out1 = original_and_stack_negative(dataset.groupby("x")) + out0, out1 = original_and_stack_negative(dataset.groupby("x")) assert_identical(dataset, out0) assert_identical(stacked_dataset, out1) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0c570de3b52..8b2a7ec5d28 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal import numpy as np import pandas as pd @@ -474,7 +474,7 @@ def data(self, request) -> Dataset: "dim3" ) - def rectify_dim_order(self, data, dataset) -> Dataset: + def rectify_dim_order(self, data: Dataset, dataset) -> Dataset: # return a new dataset with all variable dimensions transposed into # the order in which they are found in `data` return Dataset( @@ -487,11 +487,13 @@ def rectify_dim_order(self, data, dataset) -> Dataset: @pytest.mark.parametrize( "dim,data", [["dim1", True], ["dim2", False]], indirect=["data"] ) - def test_concat_simple(self, data, dim, coords) -> None: + def test_concat_simple(self, data: Dataset, dim, coords) -> None: datasets = [g for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim, coords=coords)) - def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: + def test_concat_merge_variables_present_in_some_datasets( + self, data: Dataset + ) -> None: # coordinates present in some datasets but not others ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) ds2 = Dataset(data_vars={"a": ("y", [0.2])}, coords={"z": 0.2}) @@ -515,7 +517,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: assert_identical(expected, actual) @pytest.mark.parametrize("data", [False], indirect=["data"]) - def test_concat_2(self, data) -> None: + def test_concat_2(self, data: Dataset) -> None: dim = "dim2" datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] @@ -524,7 +526,9 @@ def test_concat_2(self, data) -> None: @pytest.mark.parametrize("coords", ["different", "minimal", "all"]) @pytest.mark.parametrize("dim", ["dim1", "dim2"]) - def test_concat_coords_kwarg(self, data, dim, coords) -> None: + def test_concat_coords_kwarg( + self, data: Dataset, dim: str, coords: Literal["all", "minimal", "different"] + ) -> None: data = data.copy(deep=True) # make sure the coords argument behaves as expected 
data.coords["extra"] = ("dim4", np.arange(3)) @@ -538,7 +542,7 @@ def test_concat_coords_kwarg(self, data, dim, coords) -> None: else: assert_equal(data["extra"], actual["extra"]) - def test_concat(self, data) -> None: + def test_concat(self, data: Dataset) -> None: split_data = [ data.isel(dim1=slice(3)), data.isel(dim1=3), @@ -546,7 +550,7 @@ def test_concat(self, data) -> None: ] assert_identical(data, concat(split_data, "dim1")) - def test_concat_dim_precedence(self, data) -> None: + def test_concat_dim_precedence(self, data: Dataset) -> None: # verify that the dim argument takes precedence over # concatenating dataset variables of the same name dim = (2 * data["dim1"]).rename("dim1") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index fdfea3c3fe8..ea518b6d677 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -119,7 +119,7 @@ def test_incompatible_attributes(self) -> None: Variable( ["t"], pd.date_range("2000-01-01", periods=3), {"units": "foobar"} ), - Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), + Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 Variable(["t"], [0, 1, 2], {"add_offset": 0}, {"add_offset": 2}), Variable(["t"], [0, 1, 2], {"_FillValue": 0}, {"_FillValue": 2}), ] @@ -564,10 +564,10 @@ def test_encode_cf_variable_with_vlen_dtype() -> None: ) encoded_v = conventions.encode_cf_variable(v) assert encoded_v.data.dtype.kind == "O" - assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) is str # empty array v = Variable(["x"], np.array([], dtype=coding.strings.create_vlen_dtype(str))) encoded_v = conventions.encode_cf_variable(v) assert encoded_v.data.dtype.kind == "O" - assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) is str diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 659c7c168a5..69f1a377513 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -15,9 +15,9 @@ # remove once numpy 2.0 is the oldest supported version try: - from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] + from numpy.exceptions import RankWarning except ImportError: - from numpy import RankWarning + from numpy import RankWarning # type: ignore[no-redef,attr-defined,unused-ignore] import xarray as xr from xarray import ( @@ -341,12 +341,15 @@ def test_constructor(self) -> None: assert_identical(expected, actual) # list coords, w dims - coords1 = [["a", "b"], [-1, -2, -3]] + coords1: list[Any] = [["a", "b"], [-1, -2, -3]] actual = DataArray(data, coords1, ["x", "y"]) assert_identical(expected, actual) # pd.Index coords, w dims - coords2 = [pd.Index(["a", "b"], name="A"), pd.Index([-1, -2, -3], name="B")] + coords2: list[pd.Index] = [ + pd.Index(["a", "b"], name="A"), + pd.Index([-1, -2, -3], name="B"), + ] actual = DataArray(data, coords2, ["x", "y"]) assert_identical(expected, actual) @@ -424,7 +427,7 @@ def test_constructor_invalid(self) -> None: DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) def test_constructor_from_self_described(self) -> None: - data = [[-0.1, 21], [0, 2]] + data: list[list[float]] = [[-0.1, 21], [0, 2]] expected = DataArray( data, coords={"x": ["a", "b"], "y": [-1, -2]}, @@ -2488,7 +2491,7 @@ def 
test_stack_unstack(self) -> None: # test GH3000 a = orig[:0, :1].stack(new_dim=("x", "y")).indexes["new_dim"] b = pd.MultiIndex( - levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], + levels=[pd.Index([], dtype=np.int64), pd.Index([0], dtype=np.int64)], codes=[[], []], names=["x", "y"], ) @@ -2980,6 +2983,11 @@ def test_assign_attrs(self) -> None: assert_identical(new_actual, expected) assert actual.attrs == {"a": 1, "b": 2} + def test_drop_attrs(self) -> None: + # Mostly tested in test_dataset.py, but adding a very small test here + da = DataArray([], attrs=dict(a=1, b=2)) + assert da.drop_attrs().attrs == {} + @pytest.mark.parametrize( "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] ) @@ -3326,28 +3334,28 @@ def test_broadcast_coordinates(self) -> None: def test_to_pandas(self) -> None: # 0d - actual = DataArray(42).to_pandas() + actual_xr = DataArray(42).to_pandas() expected = np.array(42) - assert_array_equal(actual, expected) + assert_array_equal(actual_xr, expected) # 1d values = np.random.randn(3) index = pd.Index(["a", "b", "c"], name="x") da = DataArray(values, coords=[index]) - actual = da.to_pandas() - assert_array_equal(actual.values, values) - assert_array_equal(actual.index, index) - assert_array_equal(actual.index.name, "x") + actual_s = da.to_pandas() + assert_array_equal(np.asarray(actual_s.values), values) + assert_array_equal(actual_s.index, index) + assert_array_equal(actual_s.index.name, "x") # 2d values = np.random.randn(3, 2) da = DataArray( values, coords=[("x", ["a", "b", "c"]), ("y", [0, 1])], name="foo" ) - actual = da.to_pandas() - assert_array_equal(actual.values, values) - assert_array_equal(actual.index, ["a", "b", "c"]) - assert_array_equal(actual.columns, [0, 1]) + actual_df = da.to_pandas() + assert_array_equal(np.asarray(actual_df.values), values) + assert_array_equal(actual_df.index, ["a", "b", "c"]) + assert_array_equal(actual_df.columns, [0, 1]) # roundtrips for shape in [(3,), (3, 4)]: @@ -3364,24 +3372,24 @@ def test_to_dataframe(self) -> None: arr_np = np.random.randn(3, 4) arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") - expected = arr.to_series() - actual = arr.to_dataframe()["foo"] - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.name, actual.name) - assert_array_equal(expected.index.values, actual.index.values) + expected_s = arr.to_series() + actual_s = arr.to_dataframe()["foo"] + assert_array_equal(np.asarray(expected_s.values), np.asarray(actual_s.values)) + assert_array_equal(np.asarray(expected_s.name), np.asarray(actual_s.name)) + assert_array_equal(expected_s.index.values, actual_s.index.values) - actual = arr.to_dataframe(dim_order=["A", "B"])["foo"] - assert_array_equal(arr_np.transpose().reshape(-1), actual.values) + actual_s = arr.to_dataframe(dim_order=["A", "B"])["foo"] + assert_array_equal(arr_np.transpose().reshape(-1), np.asarray(actual_s.values)) # regression test for coords with different dimensions arr.coords["C"] = ("B", [-1, -2, -3]) - expected = arr.to_series().to_frame() - expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 - expected = expected[["C", "foo"]] - actual = arr.to_dataframe() - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.columns.values, actual.columns.values) - assert_array_equal(expected.index.values, actual.index.values) + expected_df = arr.to_series().to_frame() + expected_df["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected_df = expected_df[["C", "foo"]] + actual_df = 
arr.to_dataframe() + assert_array_equal(np.asarray(expected_df.values), np.asarray(actual_df.values)) + assert_array_equal(expected_df.columns.values, actual_df.columns.values) + assert_array_equal(expected_df.index.values, actual_df.index.values) with pytest.raises(ValueError, match="does not match the set of dimensions"): arr.to_dataframe(dim_order=["B", "A", "C"]) @@ -3402,11 +3410,13 @@ def test_to_dataframe_multiindex(self) -> None: arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo") actual = arr.to_dataframe() - assert_array_equal(actual["foo"].values, arr_np.flatten()) - assert_array_equal(actual.index.names, list("ABC")) - assert_array_equal(actual.index.levels[0], [1, 2]) - assert_array_equal(actual.index.levels[1], ["a", "b"]) - assert_array_equal(actual.index.levels[2], [5, 6, 7]) + index_pd = actual.index + assert isinstance(index_pd, pd.MultiIndex) + assert_array_equal(np.asarray(actual["foo"].values), arr_np.flatten()) + assert_array_equal(index_pd.names, list("ABC")) + assert_array_equal(index_pd.levels[0], [1, 2]) + assert_array_equal(index_pd.levels[1], ["a", "b"]) + assert_array_equal(index_pd.levels[2], [5, 6, 7]) def test_to_dataframe_0length(self) -> None: # regression test for #3008 @@ -3426,10 +3436,10 @@ def test_to_dataframe_0length(self) -> None: def test_to_dask_dataframe(self) -> None: arr_np = np.arange(3 * 4).reshape(3, 4) arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") - expected = arr.to_series() + expected_s = arr.to_series() actual = arr.to_dask_dataframe()["foo"] - assert_array_equal(actual.values, expected.values) + assert_array_equal(actual.values, np.asarray(expected_s.values)) actual = arr.to_dask_dataframe(dim_order=["A", "B"])["foo"] assert_array_equal(arr_np.transpose().reshape(-1), actual.values) @@ -3437,13 +3447,15 @@ def test_to_dask_dataframe(self) -> None: # regression test for coords with different dimensions arr.coords["C"] = ("B", [-1, -2, -3]) - expected = arr.to_series().to_frame() - expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 - expected = expected[["C", "foo"]] + expected_df = arr.to_series().to_frame() + expected_df["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected_df = expected_df[["C", "foo"]] actual = arr.to_dask_dataframe()[["C", "foo"]] - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.columns.values, actual.columns.values) + assert_array_equal(expected_df.values, np.asarray(actual.values)) + assert_array_equal( + expected_df.columns.values, np.asarray(actual.columns.values) + ) with pytest.raises(ValueError, match="does not match the set of dimensions"): arr.to_dask_dataframe(dim_order=["B", "A", "C"]) @@ -3459,8 +3471,8 @@ def test_to_pandas_name_matches_coordinate(self) -> None: # coordinate with same name as array arr = DataArray([1, 2, 3], dims="x", name="x") series = arr.to_series() - assert_array_equal([1, 2, 3], series.values) - assert_array_equal([0, 1, 2], series.index.values) + assert_array_equal([1, 2, 3], list(series.values)) + assert_array_equal([0, 1, 2], list(series.index.values)) assert "x" == series.name assert "x" == series.index.name @@ -3510,9 +3522,9 @@ def test_from_multiindex_series_sparse(self) -> None: import sparse idx = pd.MultiIndex.from_product([np.arange(3), np.arange(5)], names=["a", "b"]) - series = pd.Series(np.random.RandomState(0).random(len(idx)), index=idx).sample( - n=5, random_state=3 - ) + series: pd.Series = pd.Series( + np.random.RandomState(0).random(len(idx)), index=idx + ).sample(n=5, random_state=3) 
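For context on the `drop_attrs` API these tests cover, a minimal usage sketch (assumes a build of xarray that includes the `drop_attrs` additions from this diff):

```python
import xarray as xr

# drop_attrs returns a new object with attrs cleared; by default it also
# clears attrs on variables and coordinates (deep=True).
da = xr.DataArray([1, 2, 3], dims="x", attrs={"units": "m"})
ds = da.to_dataset(name="foo").assign_attrs(source="example")

assert ds.drop_attrs().attrs == {}
assert ds.drop_attrs()["foo"].attrs == {}
# deep=False only clears the top-level attrs:
assert ds.drop_attrs(deep=False)["foo"].attrs == {"units": "m"}
```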
@@ -3510,9 +3522,9 @@ def test_from_multiindex_series_sparse(self) -> None:
         import sparse

         idx = pd.MultiIndex.from_product([np.arange(3), np.arange(5)], names=["a", "b"])
-        series = pd.Series(np.random.RandomState(0).random(len(idx)), index=idx).sample(
-            n=5, random_state=3
-        )
+        series: pd.Series = pd.Series(
+            np.random.RandomState(0).random(len(idx)), index=idx
+        ).sample(n=5, random_state=3)

         dense = DataArray.from_series(series, sparse=False)
         expected_coords = sparse.COO.from_numpy(dense.data, np.nan).coords
@@ -3539,7 +3551,7 @@ def test_nbytes_does_not_load_data(self) -> None:

     def test_to_and_from_empty_series(self) -> None:
         # GH697
-        expected = pd.Series([], dtype=np.float64)
+        expected: pd.Series[Any] = pd.Series([], dtype=np.float64)
         da = DataArray.from_series(expected)
         assert len(da) == 0
         actual = da.to_series()
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index f6829861776..f859ff38ccb 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -17,9 +17,9 @@

 # remove once numpy 2.0 is the oldest supported version
 try:
-    from numpy.exceptions import RankWarning  # type: ignore[attr-defined,unused-ignore]
+    from numpy.exceptions import RankWarning
 except ImportError:
-    from numpy import RankWarning
+    from numpy import RankWarning  # type: ignore[no-redef,attr-defined,unused-ignore]

 import xarray as xr
 from xarray import (
@@ -39,7 +39,9 @@
 from xarray.core.common import duck_array_ops, full_like
 from xarray.core.coordinates import Coordinates, DatasetCoordinates
 from xarray.core.indexes import Index, PandasIndex
+from xarray.core.types import ArrayLike
 from xarray.core.utils import is_scalar
+from xarray.groupers import TimeResampler
 from xarray.namedarray.pycompat import array_type, integer_types
 from xarray.testing import _assert_internal_invariants
 from xarray.tests import (
@@ -84,7 +86,9 @@
 try:
     from numpy import trapezoid  # type: ignore[attr-defined,unused-ignore]
 except ImportError:
-    from numpy import trapz as trapezoid
+    from numpy import (  # type: ignore[arg-type,no-redef,attr-defined,unused-ignore]
+        trapz as trapezoid,
+    )

 sparse_array_type = array_type("sparse")
@@ -580,6 +584,7 @@ def test_constructor_pandas_single(self) -> None:
             pandas_obj = a.to_pandas()
             ds_based_on_pandas = Dataset(pandas_obj)  # type: ignore  # TODO: improve typing of __init__
             for dim in ds_based_on_pandas.data_vars:
+                assert isinstance(dim, int)
                 assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim])

     def test_constructor_compat(self) -> None:
@@ -1192,6 +1197,54 @@ def get_dask_names(ds):
         ):
             data.chunk({"foo": 10})

+    @requires_dask
+    @pytest.mark.parametrize(
+        "calendar",
+        (
+            "standard",
+            pytest.param(
+                "gregorian",
+                marks=pytest.mark.skipif(not has_cftime, reason="needs cftime"),
+            ),
+        ),
+    )
+    @pytest.mark.parametrize("freq", ["D", "W", "5ME", "YE"])
+    def test_chunk_by_frequency(self, freq, calendar) -> None:
+        import dask.array
+
+        N = 365 * 2
+        ds = Dataset(
+            {
+                "pr": ("time", dask.array.random.random((N), chunks=(20))),
+                "pr2d": (("x", "time"), dask.array.random.random((10, N), chunks=(20))),
+                "ones": ("time", np.ones((N,))),
+            },
+            coords={
+                "time": xr.date_range(
+                    "2001-01-01", periods=N, freq="D", calendar=calendar
+                )
+            },
+        )
+        rechunked = ds.chunk(x=2, time=TimeResampler(freq))
+        expected = tuple(ds.ones.resample(time=freq).sum().data.tolist())
+        assert rechunked.chunksizes["time"] == expected
+        assert rechunked.chunksizes["x"] == (2,) * 5
+
+        rechunked = ds.chunk({"x": 2, "time": TimeResampler(freq)})
+        assert rechunked.chunksizes["time"] == expected
+        assert rechunked.chunksizes["x"] == (2,) * 5
+
+    def test_chunk_by_frequency_errors(self):
+        ds = Dataset({"foo": ("x", [1, 2, 3])})
+        with pytest.raises(ValueError, match="virtual variable"):
+            ds.chunk(x=TimeResampler("YE"))
+        ds["x"] = ("x", [1, 2, 3])
+        with pytest.raises(ValueError, match="datetime variables"):
+            ds.chunk(x=TimeResampler("YE"))
+        ds["x"] = ("x", xr.date_range("2001-01-01", periods=3, freq="D"))
+        with pytest.raises(ValueError, match="Invalid frequency"):
+            ds.chunk(x=TimeResampler("foo"))
+
     @requires_dask
     def test_dask_is_lazy(self) -> None:
         store = InaccessibleVariableDataStore()
@@ -1694,7 +1747,7 @@ def test_sel_categorical_error(self) -> None:
         with pytest.raises(ValueError):
             ds.sel(ind="bar", method="nearest")
         with pytest.raises(ValueError):
-            ds.sel(ind="bar", tolerance="nearest")
+            ds.sel(ind="bar", tolerance="nearest")  # type: ignore[arg-type]

     def test_categorical_index(self) -> None:
         cat = pd.CategoricalIndex(
@@ -2044,9 +2097,9 @@ def test_to_pandas(self) -> None:
         y = np.random.randn(10)
         t = list("abcdefghij")
         ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)})
-        actual = ds.to_pandas()
-        expected = ds.to_dataframe()
-        assert expected.equals(actual), (expected, actual)
+        actual_df = ds.to_pandas()
+        expected_df = ds.to_dataframe()
+        assert expected_df.equals(actual_df), (expected_df, actual_df)

         # 2D -> error
         x2d = np.random.randn(10, 10)
@@ -3618,6 +3671,7 @@ def test_reset_index_drop_convert(
     def test_reorder_levels(self) -> None:
         ds = create_test_multiindex()
         mindex = ds["x"].to_index()
+        assert isinstance(mindex, pd.MultiIndex)
         midx = mindex.reorder_levels(["level_2", "level_1"])
         midx_coords = Coordinates.from_pandas_multiindex(midx, "x")
         expected = Dataset({}, coords=midx_coords)
@@ -3943,7 +3997,9 @@ def test_to_stacked_array_dtype_dims(self) -> None:
         D = xr.Dataset({"a": a, "b": b})
         sample_dims = ["x"]
         y = D.to_stacked_array("features", sample_dims)
-        assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype
+        mindex = y.xindexes["features"].to_pandas_index()
+        assert isinstance(mindex, pd.MultiIndex)
+        assert mindex.levels[1].dtype == D.y.dtype
         assert y.dims == ("x", "features")

     def test_to_stacked_array_to_unstacked_dataset(self) -> None:
@@ -4114,9 +4170,9 @@ def test_virtual_variables_default_coords(self) -> None:
     def test_virtual_variables_time(self) -> None:
         # access virtual variables
         data = create_test_data()
-        assert_array_equal(
-            data["time.month"].values, data.variables["time"].to_index().month
-        )
+        index = data.variables["time"].to_index()
+        assert isinstance(index, pd.DatetimeIndex)
+        assert_array_equal(data["time.month"].values, index.month)
         assert_array_equal(data["time.season"].values, "DJF")
         # test virtual variable math
         assert_array_equal(data["time.dayofyear"] + 1, 2 + np.arange(20))
@@ -4450,6 +4506,54 @@ def test_assign_attrs(self) -> None:
         assert_identical(new_actual, expected)
         assert actual.attrs == dict(a=1, b=2)

+    def test_drop_attrs(self) -> None:
+        # Simple example
+        ds = Dataset().assign_attrs(a=1, b=2)
+        original = ds.copy()
+        expected = Dataset()
+        result = ds.drop_attrs()
+        assert_identical(result, expected)
+
+        # Doesn't change original
+        assert_identical(ds, original)
+
+        # Example with variables and coords with attrs, and a multiindex. (arguably
+        # should have used a canonical dataset with all the features we should
+        # support...)
+        var = Variable("x", [1, 2, 3], attrs=dict(x=1, y=2))
+        idx = IndexVariable("y", [1, 2, 3], attrs=dict(c=1, d=2))
+        mx = xr.Coordinates.from_pandas_multiindex(
+            pd.MultiIndex.from_tuples([(1, 2), (3, 4)], names=["d", "e"]), "z"
+        )
+        ds = Dataset(dict(var1=var), coords=dict(y=idx, z=mx)).assign_attrs(a=1, b=2)
+        assert ds.attrs != {}
+        assert ds["var1"].attrs != {}
+        assert ds["y"].attrs != {}
+        assert ds.coords["y"].attrs != {}
+
+        original = ds.copy(deep=True)
+        result = ds.drop_attrs()
+
+        assert result.attrs == {}
+        assert result["var1"].attrs == {}
+        assert result["y"].attrs == {}
+        assert list(result.data_vars) == list(ds.data_vars)
+        assert list(result.coords) == list(ds.coords)
+
+        # Doesn't change original
+        assert_identical(ds, original)
+        # Specifically test that the attrs on the coords are still there. (The index
+        # can't currently contain `attrs`, so we can't test those.)
+        assert ds.coords["y"].attrs != {}
+
+        # Test for deep=False
+        result_shallow = ds.drop_attrs(deep=False)
+        assert result_shallow.attrs == {}
+        assert result_shallow["var1"].attrs != {}
+        assert result_shallow["y"].attrs != {}
+        assert list(result.data_vars) == list(ds.data_vars)
+        assert list(result.coords) == list(ds.coords)
+
     def test_assign_multiindex_level(self) -> None:
         data = create_test_multiindex()
         with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "):
@@ -4757,20 +4861,20 @@ def test_to_and_from_dataframe(self) -> None:
         # check pathological cases
         df = pd.DataFrame([1])
-        actual = Dataset.from_dataframe(df)
-        expected = Dataset({0: ("index", [1])}, {"index": [0]})
-        assert_identical(expected, actual)
+        actual_ds = Dataset.from_dataframe(df)
+        expected_ds = Dataset({0: ("index", [1])}, {"index": [0]})
+        assert_identical(expected_ds, actual_ds)

         df = pd.DataFrame()
-        actual = Dataset.from_dataframe(df)
-        expected = Dataset(coords={"index": []})
-        assert_identical(expected, actual)
+        actual_ds = Dataset.from_dataframe(df)
+        expected_ds = Dataset(coords={"index": []})
+        assert_identical(expected_ds, actual_ds)

         # GH697
         df = pd.DataFrame({"A": []})
-        actual = Dataset.from_dataframe(df)
-        expected = Dataset({"A": DataArray([], dims=("index",))}, {"index": []})
-        assert_identical(expected, actual)
+        actual_ds = Dataset.from_dataframe(df)
+        expected_ds = Dataset({"A": DataArray([], dims=("index",))}, {"index": []})
+        assert_identical(expected_ds, actual_ds)

         # regression test for GH278
         # use int64 to ensure consistent results for the pandas .equals method
@@ -4809,7 +4913,7 @@ def test_from_dataframe_categorical_index(self) -> None:
     def test_from_dataframe_categorical_index_string_categories(self) -> None:
         cat = pd.CategoricalIndex(
             pd.Categorical.from_codes(
-                np.array([1, 1, 0, 2]),
+                np.array([1, 1, 0, 2], dtype=np.int64),  # type: ignore[arg-type]
                 categories=pd.Index(["foo", "bar", "baz"], dtype="string"),
             )
         )
@@ -4894,7 +4998,7 @@ def test_from_dataframe_unsorted_levels(self) -> None:
     def test_from_dataframe_non_unique_columns(self) -> None:
         # regression test for GH449
         df = pd.DataFrame(np.zeros((2, 2)))
-        df.columns = ["foo", "foo"]
+        df.columns = ["foo", "foo"]  # type: ignore[assignment]
         with pytest.raises(ValueError, match=r"non-unique columns"):
             Dataset.from_dataframe(df)
@@ -7183,6 +7287,7 @@ def test_cumulative_integrate(dask) -> None:
 @pytest.mark.parametrize("which_datetime", ["np", "cftime"])
 def test_trapezoid_datetime(dask, which_datetime) -> None:
     rs = np.random.RandomState(42)
+    coord: ArrayLike
     if which_datetime == "np":
         coord = np.array(
             [
diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py
index f7cff17bab5..c875322b9c5 100644
--- a/xarray/tests/test_datatree.py
+++ b/xarray/tests/test_datatree.py
@@ -561,6 +561,30 @@ def test_roundtrip_unnamed_root(self, simple_datatree):
         roundtrip = DataTree.from_dict(dt.to_dict())
         assert roundtrip.equals(dt)

+    def test_insertion_order(self):
+        # regression test for GH issue #9276
+        reversed = DataTree.from_dict(
+            {
+                "/Homer/Lisa": xr.Dataset({"age": 8}),
+                "/Homer/Bart": xr.Dataset({"age": 10}),
+                "/Homer": xr.Dataset({"age": 39}),
+                "/": xr.Dataset({"age": 83}),
+            }
+        )
+        expected = DataTree.from_dict(
+            {
+                "/": xr.Dataset({"age": 83}),
+                "/Homer": xr.Dataset({"age": 39}),
+                "/Homer/Lisa": xr.Dataset({"age": 8}),
+                "/Homer/Bart": xr.Dataset({"age": 10}),
+            }
+        )
+        assert reversed.equals(expected)
+
+        # Check that Bart and Lisa's order is still preserved within the group,
+        # despite 'Bart' coming before 'Lisa' when sorted alphabetically
+        assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"]
+

 class TestDatasetView:
     def test_view_contents(self):
@@ -593,7 +617,7 @@ def test_methods(self):
         ds = create_test_data()
         dt: DataTree = DataTree(data=ds)
         assert ds.mean().identical(dt.ds.mean())
-        assert type(dt.ds.mean()) == xr.Dataset
+        assert isinstance(dt.ds.mean(), xr.Dataset)

     def test_arithmetic(self, create_test_datatree):
         dt = create_test_datatree()
diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py
index afcf10ec125..3bbae55b105 100644
--- a/xarray/tests/test_duck_array_ops.py
+++ b/xarray/tests/test_duck_array_ops.py
@@ -347,7 +347,7 @@ def construct_dataarray(dim_num, dtype, contains_nan, dask):
         array = rng.randint(0, 10, size=shapes).astype(dtype)
     elif np.issubdtype(dtype, np.bool_):
         array = rng.randint(0, 1, size=shapes).astype(dtype)
-    elif dtype == str:
+    elif dtype is str:
         array = rng.choice(["a", "b", "c", "d"], size=shapes)
     else:
         raise ValueError
diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py
index 9d0eb81bace..0bd8abc3a70 100644
--- a/xarray/tests/test_formatting.py
+++ b/xarray/tests/test_formatting.py
@@ -118,9 +118,9 @@ def test_format_items(self) -> None:
                 np.arange(4) * np.timedelta64(500, "ms"),
                 "00:00:00 00:00:00.500000 00:00:01 00:00:01.500000",
             ),
-            (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"),
+            (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"),  # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956
             (
-                pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]),
+                pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]),  # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956
                 "1 days 01:00:00 1 days 00:00:00 0 days 00:00:00",
             ),
             ([1, 2, 3], "1 2 3"),
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index f0a0fd14d9d..6c9254966d9 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import datetime
 import operator
 import warnings
 from unittest import mock
@@ -14,6 +13,7 @@
 from xarray import DataArray, Dataset, Variable
 from xarray.core.groupby import _consolidate_slices
 from xarray.core.types import InterpOptions
+from xarray.groupers import BinGrouper, EncodedGroups, Grouper, UniqueGrouper
 from xarray.tests import (
     InaccessibleArray,
     assert_allclose,
@@ -22,6 +22,7 @@
     create_test_data,
     has_cftime,
     has_flox,
+    requires_cftime,
     requires_dask,
     requires_flox,
     requires_scipy,
@@ -62,59 +63,44 @@ def test_consolidate_slices() -> None:

 @pytest.mark.filterwarnings("ignore:return type")
-def test_groupby_dims_property(dataset, recwarn) -> None:
-    # dims is sensitive to squeeze, always warn
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert dataset.groupby("x").dims == dataset.isel(x=1).dims
-        assert dataset.groupby("y").dims == dataset.isel(y=1).dims
-    # in pytest-8, pytest.warns() no longer clears all warnings
-    recwarn.clear()
+def test_groupby_dims_property(dataset) -> None:
+    with pytest.warns(FutureWarning, match="The return type of"):
+        assert dataset.groupby("x").dims == dataset.isel(x=[1]).dims
+    with pytest.warns(FutureWarning, match="The return type of"):
+        assert dataset.groupby("y").dims == dataset.isel(y=[1]).dims

-    # when squeeze=False, no warning should be raised
-    assert tuple(dataset.groupby("x", squeeze=False).dims) == tuple(
-        dataset.isel(x=slice(1, 2)).dims
-    )
-    assert tuple(dataset.groupby("y", squeeze=False).dims) == tuple(
-        dataset.isel(y=slice(1, 2)).dims
-    )
-    assert len(recwarn) == 0
+    assert tuple(dataset.groupby("x").dims) == tuple(dataset.isel(x=slice(1, 2)).dims)
+    assert tuple(dataset.groupby("y").dims) == tuple(dataset.isel(y=slice(1, 2)).dims)

     dataset = dataset.drop_vars(["cat"])
     stacked = dataset.stack({"xy": ("x", "y")})
-    assert tuple(stacked.groupby("xy", squeeze=False).dims) == tuple(
-        stacked.isel(xy=[0]).dims
-    )
-    assert len(recwarn) == 0
+    assert tuple(stacked.groupby("xy").dims) == tuple(stacked.isel(xy=[0]).dims)

 def test_groupby_sizes_property(dataset) -> None:
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes
+    assert dataset.groupby("x").sizes == dataset.isel(x=[1]).sizes
+    assert dataset.groupby("y").sizes == dataset.isel(y=[1]).sizes

     dataset = dataset.drop_vars("cat")
     stacked = dataset.stack({"xy": ("x", "y")})
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes
+    assert stacked.groupby("xy").sizes == stacked.isel(xy=[0]).sizes

 def test_multi_index_groupby_map(dataset) -> None:
     # regression test for GH873
     ds = dataset.isel(z=1, drop=True)[["foo"]]
     expected = 2 * ds
-    # The function in `map` may be sensitive to squeeze, always warn
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        actual = (
-            ds.stack(space=["x", "y"])
-            .groupby("space")
-            .map(lambda x: 2 * x)
-            .unstack("space")
-        )
+    actual = (
+        ds.stack(space=["x", "y"])
+        .groupby("space")
+        .map(lambda x: 2 * x)
+        .unstack("space")
+    )
     assert_equal(expected, actual)

-def test_reduce_numeric_only(dataset) -> None:
-    gb = dataset.groupby("x", squeeze=False)
+@pytest.mark.parametrize("grouper", [dict(group="x"), dict(x=UniqueGrouper())])
+def test_reduce_numeric_only(dataset, grouper: dict) -> None:
+    gb = dataset.groupby(**grouper)
     with xr.set_options(use_flox=False):
         expected = gb.sum()
     with xr.set_options(use_flox=True):
@@ -221,8 +207,7 @@ def func(arg1, arg2, arg3=0):
     array = xr.DataArray([1, 1, 1], [("x", [1, 2, 3])])
     expected = xr.DataArray([3, 3, 3], [("x", [1, 2, 3])])
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        actual = array.groupby("x").map(func, args=(1,), arg3=1)
+    actual = array.groupby("x").map(func, args=(1,), arg3=1)
     assert_identical(expected, actual)

@@ -232,9 +217,7 @@ def func(arg1, arg2, arg3=0):
     dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
     expected = xr.Dataset({"foo": ("x", [3, 3, 3])}, {"x": [1, 2, 3]})
-    # The function in `map` may be sensitive to squeeze, always warn
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        actual = dataset.groupby("x").map(func, args=(1,), arg3=1)
+    actual = dataset.groupby("x").map(func, args=(1,), arg3=1)
     assert_identical(expected, actual)

@@ -548,10 +531,8 @@ def test_da_groupby_assign_coords() -> None:
     actual = xr.DataArray(
         [[3, 4, 5], [6, 7, 8]], dims=["y", "x"], coords={"y": range(2), "x": range(3)}
     )
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        actual1 = actual.groupby("x").assign_coords({"y": [-1, -2]})
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        actual2 = actual.groupby("x").assign_coords(y=[-1, -2])
+    actual1 = actual.groupby("x").assign_coords({"y": [-1, -2]})
+    actual2 = actual.groupby("x").assign_coords(y=[-1, -2])
     expected = xr.DataArray(
         [[3, 4, 5], [6, 7, 8]], dims=["y", "x"], coords={"y": [-1, -2], "x": range(3)}
     )
@@ -707,11 +688,10 @@ def test_groupby_reduce_dimension_error(array) -> None:
     with pytest.raises(ValueError, match=r"cannot reduce over dimensions"):
         grouped.mean(("x", "y", "asd"))

-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert_identical(array.mean("x"), grouped.reduce(np.mean, "x"))
-        assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"]))
+    assert_identical(array.mean("x"), grouped.reduce(np.mean, "x"))
+    assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"]))

-    grouped = array.groupby("y", squeeze=False)
+    grouped = array.groupby("y")
     assert_identical(array, grouped.mean())

     assert_identical(array.mean("x"), grouped.reduce(np.mean, "x"))
@@ -720,7 +700,7 @@ def test_groupby_reduce_dimension_error(array) -> None:

 def test_groupby_multiple_string_args(array) -> None:
     with pytest.raises(TypeError):
-        array.groupby("x", "y")
+        array.groupby("x", squeeze="y")

 def test_groupby_bins_timeseries() -> None:
@@ -734,7 +714,7 @@ def test_groupby_bins_timeseries() -> None:
     expected = xr.DataArray(
         96 * np.ones((14,)),
         dims=["time_bins"],
-        coords={"time_bins": pd.cut(time_bins, time_bins).categories},
+        coords={"time_bins": pd.cut(time_bins, time_bins).categories},  # type: ignore[arg-type]
     ).to_dataset(name="val")
     assert_identical(actual, expected)

@@ -752,34 +732,19 @@ def test_groupby_none_group_name() -> None:

 def test_groupby_getitem(dataset) -> None:
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert_identical(dataset.sel(x="a"), dataset.groupby("x")["a"])
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert_identical(dataset.sel(z=1), dataset.groupby("z")[1])
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert_identical(dataset.foo.sel(x="a"), dataset.foo.groupby("x")["a"])
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert_identical(dataset.foo.sel(z=1), dataset.foo.groupby("z")[1])
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert_identical(dataset.cat.sel(y=1), dataset.cat.groupby("y")[1])
-
-    assert_identical(dataset.sel(x=["a"]), dataset.groupby("x", squeeze=False)["a"])
-    assert_identical(dataset.sel(z=[1]), dataset.groupby("z", squeeze=False)[1])
-
-    assert_identical(
-        dataset.foo.sel(x=["a"]), dataset.foo.groupby("x", squeeze=False)["a"]
-    )
-    assert_identical(dataset.foo.sel(z=[1]), dataset.foo.groupby("z", squeeze=False)[1])
-
-    assert_identical(dataset.cat.sel(y=[1]), dataset.cat.groupby("y", squeeze=False)[1])
+
+    assert_identical(dataset.sel(x=["a"]), dataset.groupby("x")["a"])
+    assert_identical(dataset.sel(z=[1]), dataset.groupby("z")[1])
+    assert_identical(dataset.foo.sel(x=["a"]), dataset.foo.groupby("x")["a"])
+    assert_identical(dataset.foo.sel(z=[1]), dataset.foo.groupby("z")[1])
+    assert_identical(dataset.cat.sel(y=[1]), dataset.cat.groupby("y")[1])
+
     with pytest.raises(
         NotImplementedError, match="Cannot broadcast 1d-only pandas categorical array."
     ):
-        dataset.groupby("boo", squeeze=False)
+        dataset.groupby("boo")

     dataset = dataset.drop_vars(["cat"])
-    actual = (
-        dataset.groupby("boo", squeeze=False)["f"].unstack().transpose("x", "y", "z")
-    )
+    actual = dataset.groupby("boo")["f"].unstack().transpose("x", "y", "z")
     expected = dataset.sel(y=[1], z=[1, 2]).transpose("x", "y", "z")
     assert_identical(expected, actual)

@@ -789,7 +754,7 @@ def test_groupby_dataset() -> None:
         {"z": (["x", "y"], np.random.randn(3, 5))},
         {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)},
     )
-    groupby = data.groupby("x", squeeze=False)
+    groupby = data.groupby("x")
     assert len(groupby) == 3
     expected_groups = {"a": slice(0, 1), "b": slice(1, 2), "c": slice(2, 3)}
     assert groupby.groups == expected_groups
@@ -806,55 +771,25 @@ def identity(x):
         return x

     for k in ["x", "c", "y"]:
-        actual2 = data.groupby(k, squeeze=False).map(identity)
+        actual2 = data.groupby(k).map(identity)
         assert_equal(data, actual2)

-def test_groupby_dataset_squeeze_None() -> None:
-    """Delete when removing squeeze."""
-    data = Dataset(
-        {"z": (["x", "y"], np.random.randn(3, 5))},
-        {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)},
-    )
-    groupby = data.groupby("x")
-    assert len(groupby) == 3
-    expected_groups = {"a": 0, "b": 1, "c": 2}
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        assert groupby.groups == expected_groups
-    expected_items = [
-        ("a", data.isel(x=0)),
-        ("b", data.isel(x=1)),
-        ("c", data.isel(x=2)),
-    ]
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        for actual1, expected1 in zip(groupby, expected_items):
-            assert actual1[0] == expected1[0]
-            assert_equal(actual1[1], expected1[1])
-
-    def identity(x):
-        return x
-
-    with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
-        for k in ["x", "c"]:
-            actual2 = data.groupby(k).map(identity)
-            assert_equal(data, actual2)
-
-
 def test_groupby_dataset_returns_new_type() -> None:
     data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))})

-    actual1 = data.groupby("x", squeeze=False).map(lambda ds: ds["z"])
+    actual1 = data.groupby("x").map(lambda ds: ds["z"])
     expected1 = data["z"]
     assert_identical(expected1, actual1)

-    actual2 = data["z"].groupby("x", squeeze=False).map(lambda x: x.to_dataset())
+    actual2 = data["z"].groupby("x").map(lambda x: x.to_dataset())
     expected2 = data
     assert_identical(expected2, actual2)

 def test_groupby_dataset_iter() -> None:
     data = create_test_data()
-    for n, (t, sub) in enumerate(list(data.groupby("dim1", squeeze=False))[:3]):
+    for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]):
         assert data["dim1"][n] == t
         assert_equal(data["var1"][[n]], sub["var1"])
         assert_equal(data["var2"][[n]], sub["var2"])
@@ -868,10 +803,17 @@ def test_groupby_dataset_errors() -> None:
     with pytest.raises(ValueError, match=r"length does not match"):
         data.groupby(data["dim1"][:3])
     with pytest.raises(TypeError, match=r"`group` must be"):
-        data.groupby(data.coords["dim1"].to_index())
+        data.groupby(data.coords["dim1"].to_index())  # type: ignore[arg-type]

-def test_groupby_dataset_reduce() -> None:
+@pytest.mark.parametrize(
+    "by_func",
+    [
+        pytest.param(lambda x: x, id="group-by-string"),
+        pytest.param(lambda x: {x: UniqueGrouper()}, id="group-by-unique-grouper"),
+    ],
+)
+def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
     data = Dataset(
         {
             "xy": (["x", "y"], np.random.randn(3, 4)),
@@ -883,10 +825,11 @@ def test_groupby_dataset_reduce() -> None:
     expected = data.mean("y")
     expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})

-    actual = data.groupby("x").mean(...)
+    gb = data.groupby(by_func("x"))
+    actual = gb.mean(...)
     assert_allclose(expected, actual)

-    actual = data.groupby("x").mean("y")
+    actual = gb.mean("y")
     assert_allclose(expected, actual)

     letters = data["letters"]
@@ -897,18 +840,18 @@ def test_groupby_dataset_reduce() -> None:
             "yonly": data["yonly"].groupby(letters).mean(),
         }
     )
-    actual = data.groupby("letters").mean(...)
+    gb = data.groupby(by_func("letters"))
+    actual = gb.mean(...)
     assert_allclose(expected, actual)

-@pytest.mark.parametrize("squeeze", [True, False])
-def test_groupby_dataset_math(squeeze: bool) -> None:
+def test_groupby_dataset_math() -> None:
     def reorder_dims(x):
         return x.transpose("dim1", "dim2", "dim3", "time")

     ds = create_test_data()
     ds["dim1"] = ds["dim1"]
-    grouped = ds.groupby("dim1", squeeze=squeeze)
+    grouped = ds.groupby("dim1")

     expected = reorder_dims(ds + ds.coords["dim1"])
     actual = grouped + ds.coords["dim1"]
@@ -1017,7 +960,7 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:

     with xr.set_options(use_flox=use_flox):
         actual = da.groupby_bins(
-            "x", bins=x_bins, include_lowest=True, right=False, squeeze=False
+            "x", bins=x_bins, include_lowest=True, right=False
         ).mean()
     expected = xr.DataArray(
         np.array([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]),
@@ -1028,15 +971,29 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
     )
     assert_identical(expected, actual)

+    with xr.set_options(use_flox=use_flox):
+        actual = da.groupby(
+            x=BinGrouper(bins=x_bins, include_lowest=True, right=False),
+        ).mean()
+    assert_identical(expected, actual)

 @pytest.mark.parametrize("indexed_coord", [True, False])
-def test_groupby_bins_math(indexed_coord) -> None:
+@pytest.mark.parametrize(
+    ["groupby_method", "args"],
+    (
+        ("groupby_bins", ("x", np.arange(0, 8, 3))),
+        ("groupby", ({"x": BinGrouper(bins=np.arange(0, 8, 3))},)),
+    ),
+)
+def test_groupby_bins_math(groupby_method, args, indexed_coord) -> None:
     N = 7
     da = DataArray(np.random.random((N, N)), dims=("x", "y"))
     if indexed_coord:
         da["x"] = np.arange(N)
         da["y"] = np.arange(N)
-    g = da.groupby_bins("x", np.arange(0, N + 1, 3))
+
+    g = getattr(da, groupby_method)(*args)
     mean = g.mean()
     expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1]))
     actual = g - mean
@@ -1243,11 +1200,10 @@ def test_stack_groupby_unsorted_coord(self) -> None:

     def test_groupby_iter(self) -> None:
         for (act_x, act_dv), (exp_x, exp_ds) in zip(
-            self.dv.groupby("y", squeeze=False), self.ds.groupby("y", squeeze=False)
+            self.dv.groupby("y"), self.ds.groupby("y")
         ):
             assert exp_x == act_x
             assert_identical(exp_ds["foo"], act_dv)
-        with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
             for (_, exp_dv), (_, act_dv) in zip(
                 self.dv.groupby("x"), self.dv.groupby("x")
             ):
@@ -1272,8 +1228,7 @@ def test_groupby_properties(self) -> None:
         "by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)]
     )
     @pytest.mark.parametrize("shortcut", [True, False])
-    @pytest.mark.parametrize("squeeze", [None, True, False])
-    def test_groupby_map_identity(self, by, use_da, shortcut, squeeze, recwarn) -> None:
+    def test_groupby_map_identity(self, by, use_da, shortcut) -> None:
         expected = self.da
         if use_da:
             by = expected.coords[by]
@@ -1281,14 +1236,10 @@ def test_groupby_map_identity(self, by, use_da, shortcut, squeeze, recwarn) -> N
         def identity(x):
             return x

-        grouped = expected.groupby(by, squeeze=squeeze)
+        grouped = expected.groupby(by)
         actual = grouped.map(identity, shortcut=shortcut)
         assert_identical(expected, actual)

-        # abc is not a dim coordinate so no warnings expected!
-        if (by.name if use_da else by) != "abc":
-            assert len(recwarn) == (1 if squeeze in [None, True] else 0)
-
     def test_groupby_sum(self) -> None:
         array = self.da
         grouped = array.groupby("abc")
@@ -1448,10 +1399,9 @@ def change_metadata(x):
             expected = change_metadata(expected)
             assert_equal(expected, actual)

-    @pytest.mark.parametrize("squeeze", [True, False])
-    def test_groupby_math_squeeze(self, squeeze: bool) -> None:
+    def test_groupby_math_squeeze(self) -> None:
         array = self.da
-        grouped = array.groupby("x", squeeze=squeeze)
+        grouped = array.groupby("x")

         expected = array + array.coords["x"]
         actual = grouped + array.coords["x"]
@@ -1521,7 +1471,7 @@ def test_groupby_restore_dim_order(self) -> None:
             ("a", ("a", "y")),
             ("b", ("x", "b")),
         ]:
-            result = array.groupby(by, squeeze=False).map(lambda x: x.squeeze())
+            result = array.groupby(by).map(lambda x: x.squeeze())
             assert result.dims == expected_dims

     def test_groupby_restore_coord_dims(self) -> None:
@@ -1541,7 +1491,7 @@ def test_groupby_restore_coord_dims(self) -> None:
             ("a", ("a", "y")),
             ("b", ("x", "b")),
         ]:
-            result = array.groupby(by, squeeze=False, restore_coord_dims=True).map(
+            result = array.groupby(by, restore_coord_dims=True).map(
                 lambda x: x.squeeze()
             )["c"]
             assert result.dims == expected_dims
@@ -1624,7 +1574,7 @@ def test_groupby_bins(
         bins = [0, 1.5, 5]

         df = array.to_dataframe()
-        df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs)
+        df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs)  # type: ignore[call-overload]

         expected_df = df.groupby("dim_0_bins", observed=True).sum()
         # TODO: can't convert df with IntervalIndex to Xarray
@@ -1690,7 +1640,7 @@ def test_groupby_bins_empty(self) -> None:
         array = DataArray(np.arange(4), [("x", range(4))])
         # one of these bins will be empty
         bins = [0, 4, 5]
-        bin_coords = pd.cut(array["x"], bins).categories
+        bin_coords = pd.cut(array["x"], bins).categories  # type: ignore[call-overload]
         actual = array.groupby_bins("x", bins).sum()
         expected = DataArray([6, np.nan], dims="x_bins", coords={"x_bins": bin_coords})
         assert_identical(expected, actual)
@@ -1701,7 +1651,7 @@ def test_groupby_bins_empty(self) -> None:
     def test_groupby_bins_multidim(self) -> None:
         array = self.make_groupby_multidim_example_array()
         bins = [0, 15, 20]
-        bin_coords = pd.cut(array["lat"].values.flat, bins).categories
+        bin_coords = pd.cut(array["lat"].values.flat, bins).categories  # type: ignore[call-overload]
         expected = DataArray([16, 40], dims="lat_bins", coords={"lat_bins": bin_coords})
         actual = array.groupby_bins("lat", bins).map(lambda x: x.sum())
         assert_identical(expected, actual)
@@ -1768,6 +1718,23 @@ def test_groupby_fillna(self) -> None:
         actual = a.groupby("b").fillna(DataArray([0, 2], dims="b"))
         assert_identical(expected, actual)

+    @pytest.mark.parametrize("use_flox", [True, False])
+    def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None:
+        # Fixes https://github.com/pydata/xarray/issues/6220
+        # Fixes https://github.com/pydata/xarray/issues/9279
+        index = [1, 2, 3, 4, 7, 9, 10]
+        array = DataArray(np.arange(len(index)), [("idx", index)])
+        array_rev = array.copy().assign_coords({"idx": index[::-1]})
+        fwd = array.groupby("idx", squeeze=False)
+        rev = array_rev.groupby("idx", squeeze=False)
+
+        for gb in [fwd, rev]:
+            assert all([isinstance(elem, slice) for elem in gb._group_indices])
+
+        with xr.set_options(use_flox=use_flox):
+            assert_identical(fwd.sum(), array)
+            assert_identical(rev.sum(), array_rev)
+

 class TestDataArrayResample:
     @pytest.mark.parametrize("use_cftime", [True, False])
@@ -1804,9 +1771,13 @@ def resample_as_pandas(array, *args, **kwargs):
         expected = resample_as_pandas(array, "24h", closed="right")
         assert_identical(expected, actual)

-        with pytest.raises(ValueError, match=r"index must be monotonic"):
+        with pytest.raises(ValueError, match=r"Index must be monotonic"):
             array[[2, 0, 1]].resample(time="1D")

+        reverse = array.isel(time=slice(-1, None, -1))
+        with pytest.raises(ValueError):
+            reverse.resample(time="1D").mean()
+
     @pytest.mark.parametrize("use_cftime", [True, False])
     def test_resample_doctest(self, use_cftime: bool) -> None:
         # run the doctest example here so we are not surprised
@@ -2180,19 +2151,6 @@ def test_upsample_interpolate_dask(self, chunked_time: bool) -> None:
         # done here due to floating point arithmetic
         assert_allclose(expected, actual, rtol=1e-16)

-    def test_resample_base(self) -> None:
-        times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10)
-        array = DataArray(np.arange(10), [("time", times)])
-
-        base = 11
-
-        with pytest.warns(FutureWarning, match="the `base` parameter to resample"):
-            actual = array.resample(time="24h", base=base).mean()
-        expected = DataArray(
-            array.to_series().resample("24h", offset=f"{base}h").mean()
-        )
-        assert_identical(expected, actual)
-
     def test_resample_offset(self) -> None:
         times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10)
         array = DataArray(np.arange(10), [("time", times)])
@@ -2211,38 +2169,6 @@ def test_resample_origin(self) -> None:
         expected = DataArray(array.to_series().resample("24h", origin=origin).mean())
         assert_identical(expected, actual)

-    @pytest.mark.parametrize(
-        "loffset",
-        [
-            "-12h",
-            datetime.timedelta(hours=-12),
-            pd.Timedelta(hours=-12),
-            pd.DateOffset(hours=-12),
-        ],
-    )
-    def test_resample_loffset(self, loffset) -> None:
-        times = pd.date_range("2000-01-01", freq="6h", periods=10)
-        array = DataArray(np.arange(10), [("time", times)])
-
-        with pytest.warns(FutureWarning, match="`loffset` parameter"):
-            actual = array.resample(time="24h", loffset=loffset).mean()
-        series = array.to_series().resample("24h").mean()
-        if not isinstance(loffset, pd.DateOffset):
-            loffset = pd.Timedelta(loffset)
-        series.index = series.index + loffset
-        expected = DataArray(series)
-        assert_identical(actual, expected)
-
-    def test_resample_invalid_loffset(self) -> None:
-        times = pd.date_range("2000-01-01", freq="6h", periods=10)
-        array = DataArray(np.arange(10), [("time", times)])
-
-        with pytest.warns(
-            FutureWarning, match="Following pandas, the `loffset` parameter"
-        ):
-            with pytest.raises(ValueError, match="`loffset` must be"):
-                array.resample(time="24h", loffset=1).mean()  # type: ignore
-

 class TestDatasetResample:
     def test_resample_and_first(self) -> None:
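The deleted `base` and `loffset` tests track upstream pandas deprecations; per the removed `test_resample_base` body, the replacement spelling uses the `offset` argument (a hedged sketch using the same data as the tests above):

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10)
array = xr.DataArray(np.arange(10), [("time", times)])

# What base=11 used to express on a 24h frequency is now an explicit offset:
resampled = array.resample(time="24h", offset="11h").mean()
```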
@@ -2313,17 +2239,6 @@ def test_resample_by_mean_with_keep_attrs(self) -> None:
         expected = ds.attrs
         assert expected == actual

-    def test_resample_loffset(self) -> None:
-        times = pd.date_range("2000-01-01", freq="6h", periods=10)
-        ds = Dataset(
-            {
-                "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)),
-                "bar": ("time", np.random.randn(10), {"meta": "data"}),
-                "time": times,
-            }
-        )
-        ds.attrs["dsmeta"] = "dsdata"
-
     def test_resample_by_mean_discarding_attrs(self) -> None:
         times = pd.date_range("2000-01-01", freq="6h", periods=10)
         ds = Dataset(
@@ -2383,25 +2298,6 @@ def test_resample_drop_nondim_coords(self) -> None:
         actual = ds.resample(time="1h").interpolate("linear")
         assert "tc" not in actual.coords

-    def test_resample_old_api(self) -> None:
-        times = pd.date_range("2000-01-01", freq="6h", periods=10)
-        ds = Dataset(
-            {
-                "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)),
-                "bar": ("time", np.random.randn(10), {"meta": "data"}),
-                "time": times,
-            }
-        )
-
-        with pytest.raises(TypeError, match=r"resample\(\) no longer supports"):
-            ds.resample("1D", "time")  # type: ignore[arg-type]
-
-        with pytest.raises(TypeError, match=r"resample\(\) no longer supports"):
-            ds.resample("1D", dim="time", how="mean")  # type: ignore[arg-type]
-
-        with pytest.raises(TypeError, match=r"resample\(\) no longer supports"):
-            ds.resample("1D", dim="time")  # type: ignore[arg-type]
-
     def test_resample_ds_da_are_the_same(self) -> None:
         time = pd.date_range("2000-01-01", freq="6h", periods=365 * 4)
         ds = xr.Dataset(
@@ -2586,8 +2482,8 @@ def test_groupby_dim_no_dim_equal(use_flox: bool) -> None:
         data=[1, 2, 3, 4], dims="lat", coords={"lat": np.linspace(0, 1.01, 4)}
     )
     with xr.set_options(use_flox=use_flox):
-        actual1 = da.drop_vars("lat").groupby("lat", squeeze=False).sum()
-        actual2 = da.groupby("lat", squeeze=False).sum()
+        actual1 = da.drop_vars("lat").groupby("lat").sum()
+        actual2 = da.groupby("lat").sum()
     assert_identical(actual1, actual2.drop_vars("lat"))

@@ -2606,3 +2502,62 @@ def test_default_flox_method() -> None:
         assert kwargs["method"] == "cohorts"
     else:
         assert "method" not in kwargs
+
+
+@requires_cftime
+@pytest.mark.filterwarnings("ignore")
+def test_cftime_resample_gh_9108():
+    import cftime
+
+    ds = Dataset(
+        {"pr": ("time", np.random.random((10,)))},
+        coords={"time": xr.date_range("0001-01-01", periods=10, freq="D")},
+    )
+    actual = ds.resample(time="ME").mean()
+    expected = ds.mean("time").expand_dims(
+        time=[cftime.DatetimeGregorian(1, 1, 31, 0, 0, 0, 0, has_year_zero=False)]
+    )
+    assert actual.time.data[0].has_year_zero == ds.time.data[0].has_year_zero
+    assert_equal(actual, expected)
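`test_custom_grouper` below exercises the new `Grouper` extension point; the built-in grouper objects from `xarray.groupers` are passed to `.groupby()` the same way. A minimal sketch mirroring the grouper usage in the tests above:

```python
import numpy as np
import xarray as xr
from xarray.groupers import BinGrouper, UniqueGrouper

da = xr.DataArray(np.random.random((7, 7)), dims=("x", "y"))
da["x"] = np.arange(7)

# Grouper objects are passed as dimension kwargs to .groupby(); these are
# equivalent to da.groupby("x") and da.groupby_bins("x", bins=...):
by_value = da.groupby(x=UniqueGrouper()).mean()
by_bin = da.groupby(x=BinGrouper(bins=np.arange(0, 8, 3))).mean()
```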
+ """ + + def factorize(self, group) -> EncodedGroups: + assert np.issubdtype(group.dtype, np.datetime64) + year = group.dt.year.data + codes_, uniques = pd.factorize(year) + codes = group.copy(data=codes_).rename("year") + return EncodedGroups(codes=codes, full_index=pd.Index(uniques)) + + da = xr.DataArray( + dims="time", + data=np.arange(20), + coords={"time": ("time", pd.date_range("2000-01-01", freq="3MS", periods=20))}, + name="foo", + ) + ds = da.to_dataset() + + expected = ds.groupby("time.year").mean() + actual = ds.groupby(time=YearGrouper()).mean() + assert_identical(expected, actual) + + actual = ds.groupby({"time": YearGrouper()}).mean() + assert_identical(expected, actual) + + expected = ds.foo.groupby("time.year").mean() + actual = ds.foo.groupby(time=YearGrouper()).mean() + assert_identical(expected, actual) + + actual = ds.foo.groupby({"time": YearGrouper()}).mean() + assert_identical(expected, actual) + + for obj in [ds, ds.foo]: + with pytest.raises(ValueError): + obj.groupby("time.year", time=YearGrouper()) + with pytest.raises(ValueError): + obj.groupby() diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 5ebdfd5da6e..48e254b037b 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -410,13 +410,15 @@ def test_stack(self) -> None: "y": xr.Variable("y", pd.Index([1, 3, 2])), } - index = PandasMultiIndex.stack(prod_vars, "z") + index_xr = PandasMultiIndex.stack(prod_vars, "z") - assert index.dim == "z" + assert index_xr.dim == "z" + index_pd = index_xr.index + assert isinstance(index_pd, pd.MultiIndex) # TODO: change to tuple when pandas 3 is minimum - assert list(index.index.names) == ["x", "y"] + assert list(index_pd.names) == ["x", "y"] np.testing.assert_array_equal( - index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] + index_pd.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] ) with pytest.raises( @@ -433,13 +435,15 @@ def test_stack_non_unique(self) -> None: "y": xr.Variable("y", pd.Index([1, 1, 2])), } - index = PandasMultiIndex.stack(prod_vars, "z") + index_xr = PandasMultiIndex.stack(prod_vars, "z") + index_pd = index_xr.index + assert isinstance(index_pd, pd.MultiIndex) np.testing.assert_array_equal( - index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] + index_pd.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] ) - np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) - np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + np.testing.assert_array_equal(index_pd.levels[0], ["b", "a"]) + np.testing.assert_array_equal(index_pd.levels[1], [1, 2]) def test_unstack(self) -> None: pd_midx = pd.MultiIndex.from_product( @@ -600,10 +604,7 @@ def indexes( _, variables = indexes_and_vars - if isinstance(x_idx, Index): - index_type = Index - else: - index_type = pd.Index + index_type = Index if isinstance(x_idx, Index) else pd.Index return Indexes(indexes, variables, index_type=index_type) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 3d3584448de..7687765e659 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -180,7 +180,7 @@ def test_init(self, expected: Any) -> None: # Fail: ( ("x",), - NamedArray("time", np.array([1, 2, 3])), + NamedArray("time", np.array([1, 2, 3], dtype=np.dtype(np.int64))), np.array([1, 2, 3]), True, ), diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index fa08e9975ab..a973f6b11f7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -6,7 +6,7 @@ from 
collections.abc import Generator, Hashable from copy import copy from datetime import date, timedelta -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, cast import numpy as np import pandas as pd @@ -158,9 +158,10 @@ def setup(self) -> Generator: plt.close("all") def pass_in_axis(self, plotmethod, subplot_kw=None) -> None: - fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw) - plotmethod(ax=axs[0]) - assert axs[0].has_data() + fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw, squeeze=False) + ax = axs[0, 0] + plotmethod(ax=ax) + assert ax.has_data() @pytest.mark.slow def imshow_called(self, plotmethod) -> bool: @@ -240,9 +241,9 @@ def test_1d_x_y_kw(self) -> None: xy: list[list[None | str]] = [[None, None], [None, "z"], ["z", None]] - f, ax = plt.subplots(3, 1) + f, axs = plt.subplots(3, 1, squeeze=False) for aa, (x, y) in enumerate(xy): - da.plot(x=x, y=y, ax=ax.flat[aa]) + da.plot(x=x, y=y, ax=axs.flat[aa]) with pytest.raises(ValueError, match=r"Cannot specify both"): da.plot(x="z", y="z") @@ -527,7 +528,7 @@ def test__infer_interval_breaks(self) -> None: [-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10]) ) assert_array_equal( - pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), + pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), # type: ignore[operator] _infer_interval_breaks(pd.date_range("20000101", periods=3)), ) @@ -1047,7 +1048,9 @@ def test_list_levels(self) -> None: assert cmap_params["cmap"].N == 5 assert cmap_params["norm"].N == 6 - for wrap_levels in [list, np.array, pd.Index, DataArray]: + for wrap_levels in cast( + list[Callable[[Any], dict[Any, Any]]], [list, np.array, pd.Index, DataArray] + ): cmap_params = _determine_cmap_params(data, levels=wrap_levels(orig_levels)) assert_array_equal(cmap_params["levels"], orig_levels) @@ -1566,7 +1569,9 @@ def test_colorbar_kwargs(self) -> None: assert "MyLabel" in alltxt assert "testvar" not in alltxt # change cbar ax - fig, (ax, cax) = plt.subplots(1, 2) + fig, axs = plt.subplots(1, 2, squeeze=False) + ax = axs[0, 0] + cax = axs[0, 1] self.plotmethod( ax=ax, cbar_ax=cax, add_colorbar=True, cbar_kwargs={"label": "MyBar"} ) @@ -1576,7 +1581,9 @@ def test_colorbar_kwargs(self) -> None: assert "MyBar" in alltxt assert "testvar" not in alltxt # note that there are two ways to achieve this - fig, (ax, cax) = plt.subplots(1, 2) + fig, axs = plt.subplots(1, 2, squeeze=False) + ax = axs[0, 0] + cax = axs[0, 1] self.plotmethod( ax=ax, add_colorbar=True, cbar_kwargs={"label": "MyBar", "cax": cax} ) @@ -3371,16 +3378,16 @@ def test_plot1d_default_rcparams() -> None: # see overlapping markers: fig, ax = plt.subplots(1, 1) ds.plot.scatter(x="A", y="B", marker="o", ax=ax) - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") - ) + actual: np.ndarray = mpl.colors.to_rgba_array("w") + expected: np.ndarray = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + np.testing.assert_allclose(actual, expected) # Facetgrids should have the default value as well: fg = ds.plot.scatter(x="A", y="B", col="x", marker="o") ax = fg.axs.ravel()[0] - np.testing.assert_allclose( - ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") - ) + actual = mpl.colors.to_rgba_array("w") + expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? 
+        np.testing.assert_allclose(actual, expected)

         # Facetgrids should have the default value as well:
         fg = ds.plot.scatter(x="A", y="B", col="x", marker="o")
         ax = fg.axs.ravel()[0]
-        np.testing.assert_allclose(
-            ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w")
-        )
+        actual = mpl.colors.to_rgba_array("w")
+        expected = ax.collections[0].get_edgecolor()  # type: ignore[assignment] # mpl error?
+        np.testing.assert_allclose(actual, expected)

         # scatter should not emit any warnings when using unfilled markers:
         with assert_no_warnings():
@@ -3390,9 +3397,9 @@ def test_plot1d_default_rcparams() -> None:
         # Prioritize edgecolor argument over default plot1d values:
         fig, ax = plt.subplots(1, 1)
         ds.plot.scatter(x="A", y="B", marker="o", ax=ax, edgecolor="k")
-        np.testing.assert_allclose(
-            ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("k")
-        )
+        actual = mpl.colors.to_rgba_array("k")
+        expected = ax.collections[0].get_edgecolor()  # type: ignore[assignment] # mpl error?
+        np.testing.assert_allclose(actual, expected)

 @requires_matplotlib
diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py
index 89f6ebba2c3..79869e63ae7 100644
--- a/xarray/tests/test_rolling.py
+++ b/xarray/tests/test_rolling.py
@@ -206,9 +206,9 @@ def test_rolling_pandas_compat(
             index=window, center=center, min_periods=min_periods
         ).reduce(np.nanmean)

-        np.testing.assert_allclose(s_rolling.values, da_rolling.values)
+        np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling.values)
         np.testing.assert_allclose(s_rolling.index, da_rolling["index"])
-        np.testing.assert_allclose(s_rolling.values, da_rolling_np.values)
+        np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling_np.values)
         np.testing.assert_allclose(s_rolling.index, da_rolling_np["index"])

     @pytest.mark.parametrize("center", (True, False))
@@ -221,12 +221,14 @@ def test_rolling_construct(self, center: bool, window: int) -> None:
         da_rolling = da.rolling(index=window, center=center, min_periods=1)

         da_rolling_mean = da_rolling.construct("window").mean("window")
-        np.testing.assert_allclose(s_rolling.values, da_rolling_mean.values)
+        np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling_mean.values)
         np.testing.assert_allclose(s_rolling.index, da_rolling_mean["index"])

         # with stride
         da_rolling_mean = da_rolling.construct("window", stride=2).mean("window")
-        np.testing.assert_allclose(s_rolling.values[::2], da_rolling_mean.values)
+        np.testing.assert_allclose(
+            np.asarray(s_rolling.values[::2]), da_rolling_mean.values
+        )
         np.testing.assert_allclose(s_rolling.index[::2], da_rolling_mean["index"])

         # with fill_value
@@ -649,7 +651,9 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
             index=window, center=center, min_periods=min_periods
         ).mean()

-        np.testing.assert_allclose(df_rolling["x"].values, ds_rolling["x"].values)
+        np.testing.assert_allclose(
+            np.asarray(df_rolling["x"].values), ds_rolling["x"].values
+        )
         np.testing.assert_allclose(df_rolling.index, ds_rolling["index"])

     @pytest.mark.parametrize("center", (True, False))
@@ -668,7 +672,9 @@ def test_rolling_construct(self, center: bool, window: int) -> None:
         ds_rolling = ds.rolling(index=window, center=center)

         ds_rolling_mean = ds_rolling.construct("window").mean("window")
-        np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values)
+        np.testing.assert_allclose(
+            np.asarray(df_rolling["x"].values), ds_rolling_mean["x"].values
+        )
         np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"])

         # with fill_value
@@ -695,7 +701,7 @@ def test_rolling_construct_stride(self, center: bool, window: int) -> None:
         ds_rolling = ds.rolling(index=window, center=center)
         ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w")
         np.testing.assert_allclose(
-            df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values
+            np.asarray(df_rolling_mean["x"][::2].values), ds_rolling_mean["x"].values
         )
         np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"])

@@ -704,7 +710,7 @@ def test_rolling_construct_stride(self, center: bool, window: int) -> None:
         ds2_rolling = ds2.rolling(index=window, center=center)
         ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w")
         np.testing.assert_allclose(
-            df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values
+            np.asarray(df_rolling_mean["x"][::2].values), ds2_rolling_mean["x"].values
         )

         # Mixed coordinates, indexes and 2D coordinates
diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py
index 47f54382c03..79ae4769005 100644
--- a/xarray/tests/test_strategies.py
+++ b/xarray/tests/test_strategies.py
@@ -217,7 +217,7 @@ def test_make_strategies_namespace(self, data):
             warnings.filterwarnings(
                 "ignore", category=UserWarning, message=".+See NEP 47."
             )
-            from numpy import (  # type: ignore[no-redef,unused-ignore]
+            from numpy import (  # type: ignore[attr-defined,no-redef,unused-ignore]
                 array_api as nxp,
             )
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 081bf09484a..3f3f1756e45 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -178,8 +178,8 @@ def _assertIndexedLikeNDArray(self, variable, expected_value0, expected_dtype=No
         # check type or dtype is consistent for both ndarray and Variable
         if expected_dtype is None:
             # check output type instead of array dtype
-            assert type(variable.values[0]) == type(expected_value0)
-            assert type(variable[0].values) == type(expected_value0)
+            assert type(variable.values[0]) is type(expected_value0)
+            assert type(variable[0].values) is type(expected_value0)
         elif expected_dtype is not False:
             assert variable.values[0].dtype == expected_dtype
             assert variable[0].values.dtype == expected_dtype
@@ -2649,7 +2649,7 @@ def test_tz_datetime(self) -> None:
         tz = pytz.timezone("America/New_York")
         times_ns = pd.date_range("2000", periods=1, tz=tz)

-        times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz))
+        times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz))  # type: ignore[arg-type]
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             actual: T_DuckArray = as_compatible_data(times_s)
@@ -2661,7 +2661,7 @@ def test_tz_datetime(self) -> None:
             warnings.simplefilter("ignore")
             actual2: T_DuckArray = as_compatible_data(series)

-        np.testing.assert_array_equal(actual2, series.values)
+        np.testing.assert_array_equal(actual2, np.asarray(series.values))
         assert actual2.dtype == np.dtype("datetime64[ns]")

     def test_full_like(self) -> None:
@@ -2978,26 +2978,35 @@ def test_datetime_conversion_warning(values, warns) -> None:
     )

-def test_pandas_two_only_datetime_conversion_warnings() -> None:
-    # Note these tests rely on pandas features that are only present in pandas
-    # 2.0.0 and above, and so for now cannot be parametrized.
@@ -2978,26 +2978,35 @@ def test_datetime_conversion_warning(values, warns) -> None:
     )


-def test_pandas_two_only_datetime_conversion_warnings() -> None:
-    # Note these tests rely on pandas features that are only present in pandas
-    # 2.0.0 and above, and so for now cannot be parametrized.
-    cases = [
-        (pd.date_range("2000", periods=1), "datetime64[s]"),
-        (pd.Series(pd.date_range("2000", periods=1)), "datetime64[s]"),
-        (
-            pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")),
-            pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")),
+tz_ny = pytz.timezone("America/New_York")
+
+
+@pytest.mark.parametrize(
+    ["data", "dtype"],
+    [
+        pytest.param(pd.date_range("2000", periods=1), "datetime64[s]", id="index-sec"),
+        pytest.param(
+            pd.Series(pd.date_range("2000", periods=1)),
+            "datetime64[s]",
+            id="series-sec",
         ),
-        (
-            pd.Series(
-                pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York"))
-            ),
-            pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")),
+        pytest.param(
+            pd.date_range("2000", periods=1, tz=tz_ny),
+            pd.DatetimeTZDtype("s", tz_ny),  # type: ignore[arg-type]
+            id="index-timezone",
         ),
-    ]
-    for data, dtype in cases:
-        with pytest.warns(UserWarning, match="non-nanosecond precision datetime"):
-            var = Variable(["time"], data.astype(dtype))
+        pytest.param(
+            pd.Series(pd.date_range("2000", periods=1, tz=tz_ny)),
+            pd.DatetimeTZDtype("s", tz_ny),  # type: ignore[arg-type]
+            id="series-timezone",
+        ),
+    ],
+)
+def test_pandas_two_only_datetime_conversion_warnings(
+    data: pd.DatetimeIndex | pd.Series, dtype: str | pd.DatetimeTZDtype
+) -> None:
+    with pytest.warns(UserWarning, match="non-nanosecond precision datetime"):
+        var = Variable(["time"], data.astype(dtype))  # type: ignore[arg-type]

     if var.dtype.kind == "M":
         assert var.dtype == np.dtype("datetime64[ns]")
@@ -3006,9 +3015,7 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None:
         # the case that the variable is backed by a timezone-aware
         # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class.
         assert isinstance(var._data, PandasIndexingAdapter)
-        assert var._data.array.dtype == pd.DatetimeTZDtype(
-            "ns", pytz.timezone("America/New_York")
-        )
+        assert var._data.array.dtype == pd.DatetimeTZDtype("ns", tz_ny)


 @pytest.mark.parametrize(
diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py
index b59dc36c108..d93c94b1a76 100644
--- a/xarray/util/generate_aggregations.py
+++ b/xarray/util/generate_aggregations.py
@@ -91,19 +91,6 @@ def reduce(
 class {obj}{cls}Aggregations:
     _obj: {obj}

-    def _reduce_without_squeeze_warn(
-        self,
-        func: Callable[..., Any],
-        dim: Dims = None,
-        *,
-        axis: int | Sequence[int] | None = None,
-        keep_attrs: bool | None = None,
-        keepdims: bool = False,
-        shortcut: bool = True,
-        **kwargs: Any,
-    ) -> {obj}:
-        raise NotImplementedError()
-
     def reduce(
         self,
         func: Callable[..., Any],
@@ -128,19 +115,6 @@ def _flox_reduce(
 class {obj}{cls}Aggregations:
     _obj: {obj}

-    def _reduce_without_squeeze_warn(
-        self,
-        func: Callable[..., Any],
-        dim: Dims = None,
-        *,
-        axis: int | Sequence[int] | None = None,
-        keep_attrs: bool | None = None,
-        keepdims: bool = False,
-        shortcut: bool = True,
-        **kwargs: Any,
-    ) -> {obj}:
-        raise NotImplementedError()
-
     def reduce(
         self,
         func: Callable[..., Any],
@@ -457,7 +431,7 @@ def generate_code(self, method, has_keep_attrs):

         if method_is_not_flox_supported:
             return f"""\
-        return self._reduce_without_squeeze_warn(
+        return self.reduce(
             duck_array_ops.{method.array_method},
             dim=dim,{extra_kwargs}
             keep_attrs=keep_attrs,
@@ -484,7 +458,7 @@ def generate_code(self, method, has_keep_attrs):
                 **kwargs,
             )
         else:
-            return self._reduce_without_squeeze_warn(
+            return self.reduce(
                 duck_array_ops.{method.array_method},
                 dim=dim,{extra_kwargs}
                 keep_attrs=keep_attrs,
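Editor's note: in `generate_aggregations.py` above, the generator stops emitting a `_reduce_without_squeeze_warn` stub and points generated methods straight at `reduce` — presumably because the groupby `squeeze` deprecation shim that method existed for is no longer needed. Roughly the shape of what the template now produces for a method without flox support (illustrative only; the real output is rendered from the f-string templates above, and the method shown is a simplified stand-in):

```python
from xarray.core import duck_array_ops


class DataArrayAggregations:
    def reduce(self, func, dim=None, *, keep_attrs=None, **kwargs):
        raise NotImplementedError()

    # Generated methods now call `reduce` directly instead of going through
    # the removed `_reduce_without_squeeze_warn` shim:
    def median(self, dim=None, *, skipna=None, keep_attrs=None, **kwargs):
        return self.reduce(
            duck_array_ops.median,
            dim=dim,
            skipna=skipna,
            keep_attrs=keep_attrs,
            **kwargs,
        )
```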
diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py
index ee4dd68b3ba..a9f66cdc614 100644
--- a/xarray/util/generate_ops.py
+++ b/xarray/util/generate_ops.py
@@ -88,12 +88,10 @@ def {method}(self, other: {other_type}) -> {return_type}:{type_ignore}
         return self._binary_op(other, {func})"""
 template_binop_overload = """
     @overload{overload_type_ignore}
-    def {method}(self, other: {overload_type}) -> {overload_type}:
-        ...
+    def {method}(self, other: {overload_type}) -> {overload_type}: ...

     @overload
-    def {method}(self, other: {other_type}) -> {return_type}:
-        ...
+    def {method}(self, other: {other_type}) -> {return_type}: ...

     def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore}
         return self._binary_op(other, {func})"""
@@ -129,7 +127,7 @@ def {method}(self, *args: Any, **kwargs: Any) -> Self:
 # The type ignores might not be necessary anymore at some point.
 #
 # We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray
-# In reality this returns NotImplementes, but this is not a valid type in python 3.9.
+# In reality this returns NotImplemented, but this is not a valid type in python 3.9.
 # Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable)
 # TODO: change once python 3.10 is the minimum.
 #
@@ -216,6 +214,10 @@ def unops() -> list[OpsType]:
     ]


+# We use the short names T_DA and T_DS to keep lines below 88 characters so
+# ruff does not reformat everything. When reformatting, the
+# type-ignores end up on the wrong line :/
+
 ops_info = {}
 ops_info["DatasetOpsMixin"] = (
     binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops()
@@ -224,12 +226,12 @@ def unops() -> list[OpsType]:
     binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops()
 )
 ops_info["VariableOpsMixin"] = (
-    binops_overload(other_type="VarCompatible", overload_type="T_DataArray")
+    binops_overload(other_type="VarCompatible", overload_type="T_DA")
     + inplace(other_type="VarCompatible", type_ignore="misc")
     + unops()
 )
 ops_info["DatasetGroupByOpsMixin"] = binops(
-    other_type="GroupByCompatible", return_type="Dataset"
+    other_type="Dataset | DataArray", return_type="Dataset"
 )
 ops_info["DataArrayGroupByOpsMixin"] = binops(
     other_type="T_Xarray", return_type="T_Xarray"
@@ -237,6 +239,7 @@ def unops() -> list[OpsType]:

 MODULE_PREAMBLE = '''\
 """Mixin classes with arithmetic operators."""
+
 # This file was generated using xarray.util.generate_ops. Do not edit manually.

 from __future__ import annotations
@@ -248,15 +251,15 @@ def unops() -> list[OpsType]:
 from xarray.core.types import (
     DaCompatible,
     DsCompatible,
-    GroupByCompatible,
     Self,
-    T_DataArray,
     T_Xarray,
     VarCompatible,
 )

 if TYPE_CHECKING:
-    from xarray.core.dataset import Dataset'''
+    from xarray.core.dataarray import DataArray
+    from xarray.core.dataset import Dataset
+    from xarray.core.types import T_DataArray as T_DA'''

 CLASS_PREAMBLE = """{newline}