From 261fb7880523db79b5bc3715477374d089e9a574 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 14 May 2021 18:56:15 +0200 Subject: [PATCH 01/11] xindexes also returns multi-index levels as keys --- xarray/core/dataarray.py | 18 ++++++++++++++---- xarray/core/dataset.py | 20 +++++++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 21daed1cec1..3cf4d9cfdb7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -772,6 +772,11 @@ def encoding(self) -> Dict[Hashable, Any]: def encoding(self, value: Mapping[Hashable, Any]) -> None: self.variable.encoding = value + def _get_indexes(self): + if self._indexes is None: + self._indexes = default_indexes(self._coords, self.dims) + return self._indexes + @property def indexes(self) -> Indexes: """Mapping of pandas.Index objects used for label based indexing. @@ -784,14 +789,19 @@ def indexes(self) -> Indexes: DataArray.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return Indexes( + {k: idx.to_pandas_index() for k, idx in self._get_indexes().items()} + ) @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._coords, self.dims) - return Indexes(self._indexes) + indexes = self._get_indexes().copy() + + for level, dim in self._level_coords.items(): + indexes[level] = indexes[dim] + + return Indexes(indexes) @property def coords(self) -> DataArrayCoordinates: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 706ccbde8c4..7d5e23d61ce 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1318,7 +1318,7 @@ def _level_coords(self) -> Dict[str, Hashable]: coordinate name. 
""" level_coords: Dict[str, Hashable] = {} - for name, index in self.xindexes.items(): + for name, index in self._get_indexes().items(): # TODO: benbovy - flexible indexes: update when MultIndex has its own xarray class. pd_index = index.to_pandas_index() if isinstance(pd_index, pd.MultiIndex): @@ -1607,6 +1607,11 @@ def identical(self, other: "Dataset") -> bool: except (TypeError, AttributeError): return False + def _get_indexes(self): + if self._indexes is None: + self._indexes = default_indexes(self._variables, self._dims) + return self._indexes + @property def indexes(self) -> Indexes: """Mapping of pandas.Index objects used for label based indexing. @@ -1619,14 +1624,19 @@ def indexes(self) -> Indexes: Dataset.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return Indexes( + {k: idx.to_pandas_index() for k, idx in self._get_indexes().items()} + ) @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._variables, self._dims) - return Indexes(self._indexes) + indexes = self._get_indexes().copy() + + for level, dim in self._level_coords.items(): + indexes[level] = indexes[dim] + + return Indexes(indexes) @property def coords(self) -> DatasetCoordinates: From 6f44bdd4e20013df7cc9bf5e25d1ae4222d08447 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 14 May 2021 18:57:34 +0200 Subject: [PATCH 02/11] wip: move label selection into PandasIndex add Index.query() method split pd.Index vs. 
pd.MultiIndex logic --- xarray/core/indexes.py | 195 +++++++++++++++++++++++++++++++++++++++- xarray/core/indexing.py | 71 +++++++++++---- 2 files changed, 246 insertions(+), 20 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index be362e1c942..074f8268bd8 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -9,6 +9,7 @@ Iterable, Mapping, Optional, + Sequence, Tuple, Union, ) @@ -19,7 +20,7 @@ from . import formatting, utils from .indexing import ExplicitlyIndexedNDArrayMixin, NumpyIndexingAdapter from .npcompat import DTypeLike -from .utils import is_scalar +from .utils import is_dict_like, is_scalar if TYPE_CHECKING: from .variable import Variable @@ -51,6 +52,9 @@ def to_pandas_index(self) -> pd.Index: """ raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") + def query(self, labels: Dict[Hashable, Any]): # pragma: no cover + raise NotImplementedError + def equals(self, other): # pragma: no cover raise NotImplementedError() @@ -61,6 +65,57 @@ def intersection(self, other): # pragma: no cover raise NotImplementedError() +def _sanitize_slice_element(x): + from .dataarray import DataArray + from .variable import Variable + + if isinstance(x, (Variable, DataArray)): + x = x.values + + if isinstance(x, np.ndarray): + if x.ndim != 0: + raise ValueError( + f"cannot use non-scalar arrays in a slice for xarray indexing: {x}" + ) + x = x[()] + + return x + + +def _asarray_tuplesafe(values): + """ + Convert values into a numpy array of at most 1-dimension, while preserving + tuples. 
+ + Adapted from pandas.core.common._asarray_tuplesafe + """ + if isinstance(values, tuple): + result = utils.to_0d_object_array(values) + else: + result = np.asarray(values) + if result.ndim == 2: + result = np.empty(len(values), dtype=object) + result[:] = values + + return result + + +def _is_nested_tuple(possible_tuple): + return isinstance(possible_tuple, tuple) and any( + isinstance(value, (tuple, list, slice)) for value in possible_tuple + ) + + +def get_indexer_nd(index, labels, method=None, tolerance=None): + """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional + labels + """ + flat_labels = np.ravel(labels) + flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) + indexer = flat_indexer.reshape(labels.shape) + return indexer + + class PandasIndex(Index, ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" @@ -118,6 +173,144 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: def shape(self) -> Tuple[int]: return (len(self.array),) + def _query_multiindex(self, labels): + index = self.array + new_index = None + + # label(s) given for multi-index level(s) + if all([lbl in index.names for lbl in labels]): + is_nested_vals = _is_nested_tuple(tuple(labels.values())) + if len(labels) == index.nlevels and not is_nested_vals: + indexer = index.get_loc(tuple(labels[k] for k in index.names)) + else: + for k, v in labels.items(): + # index should be an item (i.e. Hashable) not an array-like + if isinstance(v, Sequence) and not isinstance(v, str): + raise ValueError( + "Vectorized selection is not " + f"available along coordinate {k!r} (multi-index level)" + ) + indexer, new_index = index.get_loc_level( + tuple(labels.values()), level=tuple(labels.keys()) + ) + # GH2619. 
Raise a KeyError if nothing is chosen + if indexer.dtype.kind == "b" and indexer.sum() == 0: + raise KeyError(f"{labels} not found") + + # assume one label value given for the multi-index "array" (dimension) + else: + if len(labels) > 1: + coord_name = next(iter(set(labels) - set(index.names))) + raise ValueError( + f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) " + f"and one or more coordinates among {index.names!r} (multi-index levels)" + ) + + coord_name, label = next(iter(labels.items())) + + if is_dict_like(label): + indexer, new_index = self._query_multiindex(label) + + elif isinstance(label, tuple): + if _is_nested_tuple(label): + indexer = index.get_locs(label) + elif len(label) == index.nlevels: + indexer = index.get_loc(label) + else: + indexer, new_index = index.get_loc_level( + label, level=list(range(len(label))) + ) + + else: + label = ( + label + if getattr(label, "ndim", 1) > 1 # vectorized-indexing + else _asarray_tuplesafe(label) + ) + if label.ndim == 0: + indexer, new_index = index.get_loc_level(label, level=0) + elif label.dtype.kind == "b": + indexer = label + else: + if label.ndim > 1: + raise ValueError( + "Vectorized selection is not available along " + f"coordinate {coord_name!r} with a multi-index" + ) + indexer = get_indexer_nd(index, label) + if np.any(indexer < 0): + raise KeyError(f"not all values found in index {coord_name!r}") + + if new_index is not None: + new_index = PandasIndex(new_index) + + return indexer, new_index + + def query( + self, labels, method=None, tolerance=None + ) -> Tuple[Any, Union["PandasIndex", None]]: + if isinstance(self.array, pd.MultiIndex): + return self._query_multiindex(labels) + + assert len(labels) == 1 + coord_name, label = next(iter(labels.items())) + index = self.array + + if isinstance(label, slice): + if method is not None or tolerance is not None: + raise NotImplementedError( + "cannot use ``method`` argument if any indexers are slice objects" + ) + indexer = 
index.slice_indexer( + _sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step), + ) + if not isinstance(indexer, slice): + # unlike pandas, in xarray we never want to silently convert a + # slice indexer into an array indexer + raise KeyError( + "cannot represent labeled-based slice indexer for coordinate " + f"{coord_name!r} with a slice over integer positions; the index is " + "unsorted or non-unique" + ) + elif is_dict_like(label): + raise ValueError( + "cannot use a dict-like object for selection on " + "a dimension that does not have a MultiIndex" + ) + else: + label = ( + label + if getattr(label, "ndim", 1) > 1 # vectorized-indexing + else _asarray_tuplesafe(label) + ) + if label.ndim == 0: + # see https://github.com/pydata/xarray/pull/4292 for details + label_value = label[()] if label.dtype.kind in "mM" else label.item() + if isinstance(index, pd.CategoricalIndex): + if method is not None: + raise ValueError( + "'method' is not a valid kwarg when indexing using a CategoricalIndex." + ) + if tolerance is not None: + raise ValueError( + "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." 
+ ) + indexer = index.get_loc(label_value) + else: + indexer = index.get_loc( + label_value, method=method, tolerance=tolerance + ) + elif label.dtype.kind == "b": + indexer = label + else: + indexer = get_indexer_nd(index, label, method, tolerance) + if np.any(indexer < 0): + raise KeyError(f"not all values found in index {coord_name!r}") + + return indexer, None + def equals(self, other): if isinstance(other, pd.Index): other = PandasIndex(other) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 76a0c6888b2..6cb8804a841 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -251,6 +251,40 @@ def get_dim_indexers(data_obj, indexers): return dim_indexers +def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): + # TODO: benbovy - flexible indexes: indexers are still grouped by dimension + # - Make xarray.Index hashable so that it can be used as key in a mapping? + indexes = {} + grouped_indexers = defaultdict(dict) + + for key, label in indexers.items(): + try: + index = data_obj.xindexes[key] + coord = data_obj.coords[key] + dim = coord.dims[0] + if dim not in indexes: + indexes[dim] = index + + label = maybe_cast_to_coords_dtype(label, coord.dtype) + grouped_indexers[dim][key] = label + + except KeyError: + if key in data_obj.coords: + raise KeyError(f"no index found for coordinate {key}") + elif key not in data_obj.dims: + raise KeyError(f"{key} is not a valid dimension or coordinate") + # key is a dimension without coordinate: we'll reuse the provided labels + elif method is not None or tolerance is not None: + raise ValueError( + "cannot supply ``method`` or ``tolerance`` " + "when the indexed dimension does not have " + "an associated coordinate." 
+ ) + grouped_indexers[None][key] = label + + return indexes, grouped_indexers + + def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): """Given an xarray data object and label based indexers, return a mapping of equivalent location based indexers. Also return a mapping of updated @@ -262,26 +296,25 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): pos_indexers = {} new_indexes = {} - dim_indexers = get_dim_indexers(data_obj, indexers) - for dim, label in dim_indexers.items(): - try: - index = data_obj.xindexes[dim].to_pandas_index() - except KeyError: - # no index for this dimension: reuse the provided labels - if method is not None or tolerance is not None: - raise ValueError( - "cannot supply ``method`` or ``tolerance`` " - "when the indexed dimension does not have " - "an associated coordinate." - ) + indexes, grouped_indexers = group_indexers_by_index( + data_obj, indexers, method, tolerance + ) + + forward_pos_indexers = grouped_indexers.pop(None, None) + if forward_pos_indexers is not None: + for dim, label in forward_pos_indexers.items(): pos_indexers[dim] = label - else: - coords_dtype = data_obj.coords[dim].dtype - label = maybe_cast_to_coords_dtype(label, coords_dtype) - idxr, new_idx = convert_label_indexer(index, label, dim, method, tolerance) - pos_indexers[dim] = idxr - if new_idx is not None: - new_indexes[dim] = new_idx + + for dim, index in indexes.items(): + labels = grouped_indexers[dim] + idxr, new_idx = index.query(labels, method=method, tolerance=tolerance) + pos_indexers[dim] = idxr + if new_idx is not None: + new_indexes[dim] = new_idx + + # TODO: benbovy - flexible indexes: support the following cases: + # - an index query returns positional indexers over multiple dimensions + # - check/combine positional indexers returned by multiple indexes over the same dimension return pos_indexers, new_indexes From eeb47557194e0abcebc3cd6da7d748e43f0bccbc Mon Sep 17 00:00:00 2001 From: Benoit Bovy 
Date: Mon, 17 May 2021 12:12:20 +0200 Subject: [PATCH 03/11] Revert "xindexes also returns multi-index levels as keys" This reverts commit 261fb7880523db79b5bc3715477374d089e9a574. Let's keep this for later. There are too many places in Xarray that assume that xindexes keys are dimension names. --- xarray/core/dataarray.py | 18 ++++-------------- xarray/core/dataset.py | 20 +++++--------------- xarray/core/indexing.py | 8 +++++++- 3 files changed, 16 insertions(+), 30 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 3cf4d9cfdb7..21daed1cec1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -772,11 +772,6 @@ def encoding(self) -> Dict[Hashable, Any]: def encoding(self, value: Mapping[Hashable, Any]) -> None: self.variable.encoding = value - def _get_indexes(self): - if self._indexes is None: - self._indexes = default_indexes(self._coords, self.dims) - return self._indexes - @property def indexes(self) -> Indexes: """Mapping of pandas.Index objects used for label based indexing. @@ -789,19 +784,14 @@ def indexes(self) -> Indexes: DataArray.xindexes """ - return Indexes( - {k: idx.to_pandas_index() for k, idx in self._get_indexes().items()} - ) + return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" - indexes = self._get_indexes().copy() - - for level, dim in self._level_coords.items(): - indexes[level] = indexes[dim] - - return Indexes(indexes) + if self._indexes is None: + self._indexes = default_indexes(self._coords, self.dims) + return Indexes(self._indexes) @property def coords(self) -> DataArrayCoordinates: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7d5e23d61ce..706ccbde8c4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1318,7 +1318,7 @@ def _level_coords(self) -> Dict[str, Hashable]: coordinate name. 
""" level_coords: Dict[str, Hashable] = {} - for name, index in self._get_indexes().items(): + for name, index in self.xindexes.items(): # TODO: benbovy - flexible indexes: update when MultIndex has its own xarray class. pd_index = index.to_pandas_index() if isinstance(pd_index, pd.MultiIndex): @@ -1607,11 +1607,6 @@ def identical(self, other: "Dataset") -> bool: except (TypeError, AttributeError): return False - def _get_indexes(self): - if self._indexes is None: - self._indexes = default_indexes(self._variables, self._dims) - return self._indexes - @property def indexes(self) -> Indexes: """Mapping of pandas.Index objects used for label based indexing. @@ -1624,19 +1619,14 @@ def indexes(self) -> Indexes: Dataset.xindexes """ - return Indexes( - {k: idx.to_pandas_index() for k, idx in self._get_indexes().items()} - ) + return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" - indexes = self._get_indexes().copy() - - for level, dim in self._level_coords.items(): - indexes[level] = indexes[dim] - - return Indexes(indexes) + if self._indexes is None: + self._indexes = default_indexes(self._variables, self._dims) + return Indexes(self._indexes) @property def coords(self) -> DatasetCoordinates: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 6cb8804a841..6f9cfaa4db4 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -257,9 +257,15 @@ def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): indexes = {} grouped_indexers = defaultdict(dict) + # TODO: data_obj.xindexes should eventually return the PandasIndex instance + # for each multi-index levels + xindexes = dict(data_obj.xindexes) + for level, dim in data_obj._level_coords.items(): + xindexes[level] = xindexes[dim] + for key, label in indexers.items(): try: - index = data_obj.xindexes[key] + index = xindexes[key] coord = 
data_obj.coords[key] dim = coord.dims[0] if dim not in indexes: From 0b49c3eb0fb2f6c8ef730c7ded6b358a2bf3b64d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 May 2021 12:15:12 +0200 Subject: [PATCH 04/11] fix broken tests --- xarray/core/indexes.py | 19 +++++++++++++++++-- xarray/tests/test_dataset.py | 2 +- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 074f8268bd8..d12b0079ded 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -209,7 +209,22 @@ def _query_multiindex(self, labels): coord_name, label = next(iter(labels.items())) if is_dict_like(label): - indexer, new_index = self._query_multiindex(label) + return self._query_multiindex(label) + + elif isinstance(label, slice): + indexer = index.slice_indexer( + _sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step), + ) + if not isinstance(indexer, slice): + # unlike pandas, in xarray we never want to silently convert a + # slice indexer into an array indexer + raise KeyError( + "cannot represent labeled-based slice indexer for dimension " + f"{coord_name!r} with a slice over integer positions; the index is " + "unsorted or non-unique" + ) elif isinstance(label, tuple): if _is_nested_tuple(label): @@ -228,7 +243,7 @@ def _query_multiindex(self, labels): else _asarray_tuplesafe(label) ) if label.ndim == 0: - indexer, new_index = index.get_loc_level(label, level=0) + indexer, new_index = index.get_loc_level(label.item(), level=0) elif label.dtype.kind == "b": indexer = label else: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b8e1cd4b03b..59787609c3a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1427,7 +1427,7 @@ def test_sel_dataarray_mindex(self): with pytest.raises( ValueError, - match=r"Vectorized selection is not available along MultiIndex variable: x", + match=r"Vectorized selection is not 
available along coordinate 'x' with a multi-index", ): mds.sel( x=xr.DataArray( From 1688974820233adf25c5df677778d271b5fe39e1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 May 2021 13:12:02 +0200 Subject: [PATCH 05/11] remove old code + move/update tests --- xarray/core/indexes.py | 5 + xarray/core/indexing.py | 186 +--------------------------------- xarray/tests/test_indexes.py | 60 +++++++++++ xarray/tests/test_indexing.py | 82 ++++----------- 4 files changed, 88 insertions(+), 245 deletions(-) create mode 100644 xarray/tests/test_indexes.py diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index d12b0079ded..6e425a521b5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -209,6 +209,11 @@ def _query_multiindex(self, labels): coord_name, label = next(iter(labels.items())) if is_dict_like(label): + invalid_levels = [name for name in label if name not in index.names] + if invalid_levels: + raise ValueError( + f"invalid multi-index level names {invalid_levels}" + ) return self._query_multiindex(label) elif isinstance(label, slice): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 6f9cfaa4db4..86f13b787de 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -3,7 +3,7 @@ import operator from collections import defaultdict from distutils.version import LooseVersion -from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, List, Tuple, Union import numpy as np import pandas as pd @@ -22,7 +22,7 @@ is_duck_dask_array, sparse_array_type, ) -from .utils import is_dict_like, maybe_cast_to_coords_dtype +from .utils import maybe_cast_to_coords_dtype def expanded_indexer(key, ndim): @@ -59,47 +59,6 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def _sanitize_slice_element(x): - from .dataarray import DataArray - from .variable import Variable - - if isinstance(x, (Variable, DataArray)): - x = x.values - - if 
isinstance(x, np.ndarray): - if x.ndim != 0: - raise ValueError( - f"cannot use non-scalar arrays in a slice for xarray indexing: {x}" - ) - x = x[()] - - return x - - -def _asarray_tuplesafe(values): - """ - Convert values into a numpy array of at most 1-dimension, while preserving - tuples. - - Adapted from pandas.core.common._asarray_tuplesafe - """ - if isinstance(values, tuple): - result = utils.to_0d_object_array(values) - else: - result = np.asarray(values) - if result.ndim == 2: - result = np.empty(len(values), dtype=object) - result[:] = values - - return result - - -def _is_nested_tuple(possible_tuple): - return isinstance(possible_tuple, tuple) and any( - isinstance(value, (tuple, list, slice)) for value in possible_tuple - ) - - def get_indexer_nd(index, labels, method=None, tolerance=None): """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels @@ -110,147 +69,6 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): return indexer -def convert_label_indexer(index, label, index_name="", method=None, tolerance=None): - """Given a pandas.Index and labels (e.g., from __getitem__) for one - dimension, return an indexer suitable for indexing an ndarray along that - dimension. If `index` is a pandas.MultiIndex and depending on `label`, - return a new pandas.Index or pandas.MultiIndex (otherwise return None). 
- """ - from .indexes import PandasIndex - - new_index = None - - if isinstance(label, slice): - if method is not None or tolerance is not None: - raise NotImplementedError( - "cannot use ``method`` argument if any indexers are slice objects" - ) - indexer = index.slice_indexer( - _sanitize_slice_element(label.start), - _sanitize_slice_element(label.stop), - _sanitize_slice_element(label.step), - ) - if not isinstance(indexer, slice): - # unlike pandas, in xarray we never want to silently convert a - # slice indexer into an array indexer - raise KeyError( - "cannot represent labeled-based slice indexer for dimension " - f"{index_name!r} with a slice over integer positions; the index is " - "unsorted or non-unique" - ) - - elif is_dict_like(label): - is_nested_vals = _is_nested_tuple(tuple(label.values())) - if not isinstance(index, pd.MultiIndex): - raise ValueError( - "cannot use a dict-like object for selection on " - "a dimension that does not have a MultiIndex" - ) - elif len(label) == index.nlevels and not is_nested_vals: - indexer = index.get_loc(tuple(label[k] for k in index.names)) - else: - for k, v in label.items(): - # index should be an item (i.e. Hashable) not an array-like - if isinstance(v, Sequence) and not isinstance(v, str): - raise ValueError( - "Vectorized selection is not " - "available along level variable: " + k - ) - indexer, new_index = index.get_loc_level( - tuple(label.values()), level=tuple(label.keys()) - ) - - # GH2619. 
Raise a KeyError if nothing is chosen - if indexer.dtype.kind == "b" and indexer.sum() == 0: - raise KeyError(f"{label} not found") - - elif isinstance(label, tuple) and isinstance(index, pd.MultiIndex): - if _is_nested_tuple(label): - indexer = index.get_locs(label) - elif len(label) == index.nlevels: - indexer = index.get_loc(label) - else: - indexer, new_index = index.get_loc_level( - label, level=list(range(len(label))) - ) - else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - # see https://github.com/pydata/xarray/pull/4292 for details - label_value = label[()] if label.dtype.kind in "mM" else label.item() - if isinstance(index, pd.MultiIndex): - indexer, new_index = index.get_loc_level(label_value, level=0) - elif isinstance(index, pd.CategoricalIndex): - if method is not None: - raise ValueError( - "'method' is not a valid kwarg when indexing using a CategoricalIndex." - ) - if tolerance is not None: - raise ValueError( - "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." - ) - indexer = index.get_loc(label_value) - else: - indexer = index.get_loc(label_value, method=method, tolerance=tolerance) - elif label.dtype.kind == "b": - indexer = label - else: - if isinstance(index, pd.MultiIndex) and label.ndim > 1: - raise ValueError( - "Vectorized selection is not available along " - "MultiIndex variable: " + index_name - ) - indexer = get_indexer_nd(index, label, method, tolerance) - if np.any(indexer < 0): - raise KeyError(f"not all values found in index {index_name!r}") - - if new_index is not None: - new_index = PandasIndex(new_index) - - return indexer, new_index - - -def get_dim_indexers(data_obj, indexers): - """Given a xarray data object and label based indexers, return a mapping - of label indexers with only dimension names as keys. 
- - It groups multiple level indexers given on a multi-index dimension - into a single, dictionary indexer for that dimension (Raise a ValueError - if it is not possible). - """ - invalid = [ - k - for k in indexers - if k not in data_obj.dims and k not in data_obj._level_coords - ] - if invalid: - raise ValueError(f"dimensions or multi-index levels {invalid!r} do not exist") - - level_indexers = defaultdict(dict) - dim_indexers = {} - for key, label in indexers.items(): - (dim,) = data_obj[key].dims - if key != dim: - # assume here multi-index level indexer - level_indexers[dim][key] = label - else: - dim_indexers[key] = label - - for dim, level_labels in level_indexers.items(): - if dim_indexers.get(dim, False): - raise ValueError( - "cannot combine multi-index level indexers with an indexer for " - f"dimension {dim}" - ) - dim_indexers[dim] = level_labels - - return dim_indexers - - def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): # TODO: benbovy - flexible indexes: indexers are still grouped by dimension # - Make xarray.Index hashable so that it can be used as key in a mapping? 
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py new file mode 100644 index 00000000000..e8b398cfa43 --- /dev/null +++ b/xarray/tests/test_indexes.py @@ -0,0 +1,60 @@ +import numpy as np +import pandas as pd +import pytest + +from xarray.core.indexes import PandasIndex, _asarray_tuplesafe + + +def test_asarray_tuplesafe(): + res = _asarray_tuplesafe(("a", 1)) + assert isinstance(res, np.ndarray) + assert res.ndim == 0 + assert res.item() == ("a", 1) + + res = _asarray_tuplesafe([(0,), (1,)]) + assert res.shape == (2,) + assert res[0] == (0,) + assert res[1] == (1,) + + +class TestPandasIndex: + def test_query(self): + # TODO: add tests that aren't just for edge cases + index = PandasIndex(pd.Index([1, 2, 3])) + with pytest.raises(KeyError, match=r"not all values found"): + index.query({"x": [0]}) + with pytest.raises(KeyError): + index.query({"x": 0}) + with pytest.raises(ValueError, match=r"does not have a MultiIndex"): + index.query({"x": {"one": 0}}) + + index = PandasIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + ) + with pytest.raises(KeyError, match=r"not all values found"): + index.query({"x": [0]}) + with pytest.raises(KeyError): + index.query({"x": 0}) + with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): + index.query({"one": 0, "x": "a"}) + with pytest.raises(ValueError, match=r"invalid multi-index level names"): + index.query({"x": {"three": 0}}) + with pytest.raises(IndexError): + index.query({"x": (slice(None), 1, "no_level")}) + + def test_query_datetime(self): + index = PandasIndex(pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"])) + actual = index.query({"x": "2001-01-01"}) + expected = (1, None) + assert actual == expected + + actual = index.query({"x": index.to_pandas_index().to_numpy()[1]}) + assert actual == expected + + def test_query_unsorted_datetime_index_raises(self): + index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"])) + with 
pytest.raises(KeyError): + # pandas will try to convert this into an array indexer. We should + # raise instead, so we can be sure the result of indexing with a + # slice is always a view. + index.query({"x": slice("2001", "2002")}) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 96ad7c923e3..b23643717df 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -40,17 +40,6 @@ def test_expanded_indexer(self): with pytest.raises(IndexError, match=r"too many indices"): indexing.expanded_indexer(arr[1, 2, 3], 2) - def test_asarray_tuplesafe(self): - res = indexing._asarray_tuplesafe(("a", 1)) - assert isinstance(res, np.ndarray) - assert res.ndim == 0 - assert res.item() == ("a", 1) - - res = indexing._asarray_tuplesafe([(0,), (1,)]) - assert res.shape == (2,) - assert res[0] == (0,) - assert res[1] == (1,) - def test_stacked_multiindex_min_max(self): data = np.random.randn(3, 23, 4) da = DataArray( @@ -66,58 +55,29 @@ def test_stacked_multiindex_min_max(self): assert_array_equal(da2.loc["a", s.max()], data[2, 22, 0]) assert_array_equal(da2.loc["b", s.min()], data[0, 0, 1]) - def test_convert_label_indexer(self): - # TODO: add tests that aren't just for edge cases - index = pd.Index([1, 2, 3]) - with pytest.raises(KeyError, match=r"not all values found"): - indexing.convert_label_indexer(index, [0]) - with pytest.raises(KeyError): - indexing.convert_label_indexer(index, 0) - with pytest.raises(ValueError, match=r"does not have a MultiIndex"): - indexing.convert_label_indexer(index, {"one": 0}) - + def test_group_indexers_by_index(self): mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) - with pytest.raises(KeyError, match=r"not all values found"): - indexing.convert_label_indexer(mindex, [0]) - with pytest.raises(KeyError): - indexing.convert_label_indexer(mindex, 0) - with pytest.raises(ValueError): - indexing.convert_label_indexer(index, {"three": 0}) - with 
pytest.raises(IndexError): - indexing.convert_label_indexer(mindex, (slice(None), 1, "no_level")) - - def test_convert_label_indexer_datetime(self): - index = pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]) - actual = indexing.convert_label_indexer(index, "2001-01-01") - expected = (1, None) - assert actual == expected - - actual = indexing.convert_label_indexer(index, index.to_numpy()[1]) - assert actual == expected - - def test_convert_unsorted_datetime_index_raises(self): - index = pd.to_datetime(["2001", "2000", "2002"]) - with pytest.raises(KeyError): - # pandas will try to convert this into an array indexer. We should - # raise instead, so we can be sure the result of indexing with a - # slice is always a view. - indexing.convert_label_indexer(index, slice("2001", "2002")) - - def test_get_dim_indexers(self): - mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) - mdata = DataArray(range(4), [("x", mindex)]) - - dim_indexers = indexing.get_dim_indexers(mdata, {"one": "a", "two": 1}) - assert dim_indexers == {"x": {"one": "a", "two": 1}} - - with pytest.raises(ValueError, match=r"cannot combine"): - indexing.get_dim_indexers(mdata, {"x": "a", "two": 1}) - - with pytest.raises(ValueError, match=r"do not exist"): - indexing.get_dim_indexers(mdata, {"y": "a"}) + data = DataArray( + np.zeros((4, 2, 2)), coords={"x": mindex, "y": [1, 2]}, dims=("x", "y", "z") + ) + data.coords["y2"] = ("y", [2.0, 3.0]) - with pytest.raises(ValueError, match=r"do not exist"): - indexing.get_dim_indexers(mdata, {"four": 1}) + indexes, grouped_indexers = indexing.group_indexers_by_index( + data, {"z": 0, "one": "a", "two": 1, "y": 0} + ) + assert indexes == {"x": data.xindexes["x"], "y": data.xindexes["y"]} + assert grouped_indexers == { + "x": {"one": "a", "two": 1}, + "y": {"y": 0}, + None: {"z": 0}, + } + + with pytest.raises(KeyError, match=r"no index found for coordinate y2"): + indexing.group_indexers_by_index(data, {"y2": 2.0}) + with 
pytest.raises(KeyError, match=r"w is not a valid dimension or coordinate"): + indexing.group_indexers_by_index(data, {"w": "a"}) + with pytest.raises(ValueError, match=r"cannot supply.*"): + indexing.group_indexers_by_index(data, {"z": 1}, method="nearest") def test_remap_label_indexers(self): def test_indexer(data, x, expected_pos, expected_idx=None): From f7c8385f3fe916c9d29162fbaa5da1ac12cefc72 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 May 2021 13:18:50 +0200 Subject: [PATCH 06/11] remove duplicate function --- xarray/core/alignment.py | 3 +-- xarray/core/indexing.py | 10 ---------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index f6e026c0109..2ce09cb65e9 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -18,8 +18,7 @@ import pandas as pd from . import dtypes -from .indexes import Index, PandasIndex -from .indexing import get_indexer_nd +from .indexes import Index, PandasIndex, get_indexer_nd from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index from .variable import IndexVariable, Variable diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 86f13b787de..13abede3690 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -59,16 +59,6 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def get_indexer_nd(index, labels, method=None, tolerance=None): - """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional - labels - """ - flat_labels = np.ravel(labels) - flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) - indexer = flat_indexer.reshape(labels.shape) - return indexer - - def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): # TODO: benbovy - flexible indexes: indexers are still grouped by dimension # - Make xarray.Index hashable so that it can be used as key in a mapping? 
From b9bfbde7cc459560167a1326625e3543668e0af4 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 May 2021 15:51:14 +0200 Subject: [PATCH 07/11] add PandasMultiIndex class + refactor query impl --- xarray/core/alignment.py | 4 +- xarray/core/dataarray.py | 10 +- xarray/core/dataset.py | 13 +- xarray/core/indexes.py | 248 ++++++++++++++++++----------------- xarray/core/variable.py | 6 +- xarray/tests/test_indexes.py | 33 ++--- 6 files changed, 165 insertions(+), 149 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 2ce09cb65e9..0d291718d3c 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -18,7 +18,7 @@ import pandas as pd from . import dtypes -from .indexes import Index, PandasIndex, get_indexer_nd +from .indexes import Index, PandasIndex, get_indexer_nd, wrap_pandas_index from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index from .variable import IndexVariable, Variable @@ -566,7 +566,7 @@ def reindex_variables( "from that to be indexed along {:s}".format(str(indexer.dims), dim) ) - target = new_indexes[dim] = PandasIndex(safe_cast_to_index(indexers[dim])) + target = new_indexes[dim] = wrap_pandas_index(safe_cast_to_index(indexers[dim])) if dim in indexes: # TODO (benbovy - flexible indexes): support other indexes than pd.Index? 
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 21daed1cec1..fe405853b10 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -51,7 +51,13 @@ ) from .dataset import Dataset, split_indexes from .formatting import format_item -from .indexes import Index, Indexes, PandasIndex, default_indexes, propagate_indexes +from .indexes import ( + Index, + Indexes, + default_indexes, + propagate_indexes, + wrap_pandas_index, +) from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS, _get_keep_attrs @@ -1009,7 +1015,7 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray": # TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index) # xarray Index needs a copy method. indexes = { - k: PandasIndex(v.to_pandas_index().copy(deep=deep)) + k: wrap_pandas_index(v.to_pandas_index().copy(deep=deep)) for k, v in self._indexes.items() } return self._replace(variable, coords, indexes=indexes) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 706ccbde8c4..fbc9772b80f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -64,11 +64,13 @@ Index, Indexes, PandasIndex, + PandasMultiIndex, default_indexes, isel_variable_and_index, propagate_indexes, remove_unused_levels_categories, roll_index, + wrap_pandas_index, ) from .indexing import is_fancy_indexer from .merge import ( @@ -3165,10 +3167,9 @@ def _rename_indexes(self, name_dict, dims_set): continue if isinstance(index, pd.MultiIndex): new_names = [name_dict.get(k, k) for k in index.names] - new_index = index.rename(names=new_names) + indexes[new_name] = PandasMultiIndex(index.rename(names=new_names)) else: - new_index = index.rename(new_name) - indexes[new_name] = PandasIndex(new_index) + indexes[new_name] = PandasIndex(index.rename(new_name)) return indexes def _rename_all(self, name_dict, dims_dict): @@ -3397,7 +3398,7 @@ def swap_dims( if 
new_index.nlevels == 1: # make sure index name matches dimension name new_index = new_index.rename(k) - indexes[k] = PandasIndex(new_index) + indexes[k] = wrap_pandas_index(new_index) else: var = v.to_base_variable() var.dims = dims @@ -3670,7 +3671,7 @@ def reorder_levels( raise ValueError(f"coordinate {dim} has no MultiIndex") new_index = index.reorder_levels(order) variables[dim] = IndexVariable(coord.dims, new_index) - indexes[dim] = PandasIndex(new_index) + indexes[dim] = PandasMultiIndex(new_index) return self._replace(variables, indexes=indexes) @@ -3698,7 +3699,7 @@ def _stack_once(self, dims, new_dim): coord_names = set(self._coord_names) - set(dims) | {new_dim} indexes = {k: v for k, v in self.xindexes.items() if k not in dims} - indexes[new_dim] = PandasIndex(idx) + indexes[new_dim] = wrap_pandas_index(idx) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 6e425a521b5..c1c70583e5a 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -82,6 +82,27 @@ def _sanitize_slice_element(x): return x +def _query_slice(index, label, coord_name="", method=None, tolerance=None): + if method is not None or tolerance is not None: + raise NotImplementedError( + "cannot use ``method`` argument if any indexers are slice objects" + ) + indexer = index.slice_indexer( + _sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step), + ) + if not isinstance(indexer, slice): + # unlike pandas, in xarray we never want to silently convert a + # slice indexer into an array indexer + raise KeyError( + "cannot represent labeled-based slice indexer for coordinate " + f"{coord_name!r} with a slice over integer positions; the index is " + "unsorted or non-unique" + ) + return indexer + + def _asarray_tuplesafe(values): """ Convert values into a numpy array of at most 1-dimension, while preserving @@ -173,127 +194,15 
@@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: def shape(self) -> Tuple[int]: return (len(self.array),) - def _query_multiindex(self, labels): - index = self.array - new_index = None - - # label(s) given for multi-index level(s) - if all([lbl in index.names for lbl in labels]): - is_nested_vals = _is_nested_tuple(tuple(labels.values())) - if len(labels) == index.nlevels and not is_nested_vals: - indexer = index.get_loc(tuple(labels[k] for k in index.names)) - else: - for k, v in labels.items(): - # index should be an item (i.e. Hashable) not an array-like - if isinstance(v, Sequence) and not isinstance(v, str): - raise ValueError( - "Vectorized selection is not " - f"available along coordinate {k!r} (multi-index level)" - ) - indexer, new_index = index.get_loc_level( - tuple(labels.values()), level=tuple(labels.keys()) - ) - # GH2619. Raise a KeyError if nothing is chosen - if indexer.dtype.kind == "b" and indexer.sum() == 0: - raise KeyError(f"{labels} not found") - - # assume one label value given for the multi-index "array" (dimension) - else: - if len(labels) > 1: - coord_name = next(iter(set(labels) - set(index.names))) - raise ValueError( - f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) " - f"and one or more coordinates among {index.names!r} (multi-index levels)" - ) - - coord_name, label = next(iter(labels.items())) - - if is_dict_like(label): - invalid_levels = [name for name in label if name not in index.names] - if invalid_levels: - raise ValueError( - f"invalid multi-index level names {invalid_levels}" - ) - return self._query_multiindex(label) - - elif isinstance(label, slice): - indexer = index.slice_indexer( - _sanitize_slice_element(label.start), - _sanitize_slice_element(label.stop), - _sanitize_slice_element(label.step), - ) - if not isinstance(indexer, slice): - # unlike pandas, in xarray we never want to silently convert a - # slice indexer into an array indexer - raise KeyError( - "cannot represent 
labeled-based slice indexer for dimension " - f"{coord_name!r} with a slice over integer positions; the index is " - "unsorted or non-unique" - ) - - elif isinstance(label, tuple): - if _is_nested_tuple(label): - indexer = index.get_locs(label) - elif len(label) == index.nlevels: - indexer = index.get_loc(label) - else: - indexer, new_index = index.get_loc_level( - label, level=list(range(len(label))) - ) - - else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - indexer, new_index = index.get_loc_level(label.item(), level=0) - elif label.dtype.kind == "b": - indexer = label - else: - if label.ndim > 1: - raise ValueError( - "Vectorized selection is not available along " - f"coordinate {coord_name!r} with a multi-index" - ) - indexer = get_indexer_nd(index, label) - if np.any(indexer < 0): - raise KeyError(f"not all values found in index {coord_name!r}") - - if new_index is not None: - new_index = PandasIndex(new_index) - - return indexer, new_index - def query( self, labels, method=None, tolerance=None ) -> Tuple[Any, Union["PandasIndex", None]]: - if isinstance(self.array, pd.MultiIndex): - return self._query_multiindex(labels) - assert len(labels) == 1 coord_name, label = next(iter(labels.items())) index = self.array if isinstance(label, slice): - if method is not None or tolerance is not None: - raise NotImplementedError( - "cannot use ``method`` argument if any indexers are slice objects" - ) - indexer = index.slice_indexer( - _sanitize_slice_element(label.start), - _sanitize_slice_element(label.stop), - _sanitize_slice_element(label.step), - ) - if not isinstance(indexer, slice): - # unlike pandas, in xarray we never want to silently convert a - # slice indexer into an array indexer - raise KeyError( - "cannot represent labeled-based slice indexer for coordinate " - f"{coord_name!r} with a slice over integer positions; the index is " - "unsorted or non-unique" - ) + 
indexer = _query_slice(index, label, coord_name, method, tolerance) elif is_dict_like(label): raise ValueError( "cannot use a dict-like object for selection on " @@ -333,18 +242,18 @@ def query( def equals(self, other): if isinstance(other, pd.Index): - other = PandasIndex(other) - return isinstance(other, PandasIndex) and self.array.equals(other.array) + other = type(self)(other) + return isinstance(other, type(self)) and self.array.equals(other.array) def union(self, other): if isinstance(other, pd.Index): - other = PandasIndex(other) - return PandasIndex(self.array.union(other.array)) + other = type(self)(other) + return type(self)(self.array.union(other.array)) def intersection(self, other): if isinstance(other, pd.Index): other = PandasIndex(other) - return PandasIndex(self.array.intersection(other.array)) + return type(self)(self.array.intersection(other.array)) def __getitem__( self, indexer @@ -367,7 +276,7 @@ def __getitem__( result = self.array[key] if isinstance(result, pd.Index): - result = PandasIndex(result, dtype=self.dtype) + result = type(self)(result, dtype=self.dtype) else: # result is a scalar if result is pd.NaT: @@ -406,7 +315,104 @@ def copy(self, deep: bool = True) -> "PandasIndex": # >>> len(pickle.dumps((self.array, self.array.copy(deep=False)))) # 8000341 array = self.array.copy(deep=True) if deep else self.array - return PandasIndex(array, self._dtype) + return type(self)(array, self._dtype) + + +class PandasMultiIndex(PandasIndex): + def query( + self, labels, method=None, tolerance=None + ) -> Tuple[Any, Union["PandasIndex", None]]: + if method is not None or tolerance is not None: + raise ValueError( + "multi-index does not support ``method`` and ``tolerance``" + ) + + index = self.array + new_index = None + + # label(s) given for multi-index level(s) + if all([lbl in index.names for lbl in labels]): + is_nested_vals = _is_nested_tuple(tuple(labels.values())) + if len(labels) == index.nlevels and not is_nested_vals: + indexer = 
index.get_loc(tuple(labels[k] for k in index.names)) + else: + for k, v in labels.items(): + # index should be an item (i.e. Hashable) not an array-like + if isinstance(v, Sequence) and not isinstance(v, str): + raise ValueError( + "Vectorized selection is not " + f"available along coordinate {k!r} (multi-index level)" + ) + indexer, new_index = index.get_loc_level( + tuple(labels.values()), level=tuple(labels.keys()) + ) + # GH2619. Raise a KeyError if nothing is chosen + if indexer.dtype.kind == "b" and indexer.sum() == 0: + raise KeyError(f"{labels} not found") + + # assume one label value given for the multi-index "array" (dimension) + else: + if len(labels) > 1: + coord_name = next(iter(set(labels) - set(index.names))) + raise ValueError( + f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) " + f"and one or more coordinates among {index.names!r} (multi-index levels)" + ) + + coord_name, label = next(iter(labels.items())) + + if is_dict_like(label): + invalid_levels = [name for name in label if name not in index.names] + if invalid_levels: + raise ValueError( + f"invalid multi-index level names {invalid_levels}" + ) + return self.query(label) + + elif isinstance(label, slice): + indexer = _query_slice(index, label, coord_name) + + elif isinstance(label, tuple): + if _is_nested_tuple(label): + indexer = index.get_locs(label) + elif len(label) == index.nlevels: + indexer = index.get_loc(label) + else: + indexer, new_index = index.get_loc_level( + label, level=list(range(len(label))) + ) + + else: + label = ( + label + if getattr(label, "ndim", 1) > 1 # vectorized-indexing + else _asarray_tuplesafe(label) + ) + if label.ndim == 0: + indexer, new_index = index.get_loc_level(label.item(), level=0) + elif label.dtype.kind == "b": + indexer = label + else: + if label.ndim > 1: + raise ValueError( + "Vectorized selection is not available along " + f"coordinate {coord_name!r} with a multi-index" + ) + indexer = get_indexer_nd(index, label) + 
if np.any(indexer < 0): + raise KeyError(f"not all values found in index {coord_name!r}") + + if new_index is not None: + new_index = PandasIndex(new_index) + + return indexer, new_index + + +def wrap_pandas_index(index): + if isinstance(index, pd.MultiIndex): + return PandasMultiIndex(index) + else: + return PandasIndex(index) def remove_unused_levels_categories(index: pd.Index) -> pd.Index: @@ -517,7 +523,7 @@ def isel_variable_and_index( if isinstance(indexer, Variable): indexer = indexer.data pd_index = index.to_pandas_index() - new_index = PandasIndex(pd_index[indexer]) + new_index = wrap_pandas_index(pd_index[indexer]) return new_variable, new_index diff --git a/xarray/core/variable.py b/xarray/core/variable.py index cffaf2c3146..b3868fde66e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -26,7 +26,7 @@ from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from .arithmetic import VariableArithmetic from .common import AbstractArray -from .indexes import PandasIndex +from .indexes import PandasIndex, wrap_pandas_index from .indexing import BasicIndexer, OuterIndexer, VectorizedIndexer, as_indexable from .options import _get_keep_attrs from .pycompat import ( @@ -179,7 +179,7 @@ def _maybe_wrap_data(data): all pass through unmodified. 
""" if isinstance(data, pd.Index): - return PandasIndex(data) + return wrap_pandas_index(data) return data @@ -554,7 +554,7 @@ def to_index_variable(self): def _to_xindex(self): # temporary function used internally as a replacement of to_index() # returns an xarray Index instance instead of a pd.Index instance - return PandasIndex(self.to_index()) + return wrap_pandas_index(self.to_index()) def to_index(self): """Convert this variable to a pandas.Index""" diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index e8b398cfa43..f91da2796ab 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from xarray.core.indexes import PandasIndex, _asarray_tuplesafe +from xarray.core.indexes import PandasIndex, PandasMultiIndex, _asarray_tuplesafe def test_asarray_tuplesafe(): @@ -28,20 +28,6 @@ def test_query(self): with pytest.raises(ValueError, match=r"does not have a MultiIndex"): index.query({"x": {"one": 0}}) - index = PandasIndex( - pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) - ) - with pytest.raises(KeyError, match=r"not all values found"): - index.query({"x": [0]}) - with pytest.raises(KeyError): - index.query({"x": 0}) - with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): - index.query({"one": 0, "x": "a"}) - with pytest.raises(ValueError, match=r"invalid multi-index level names"): - index.query({"x": {"three": 0}}) - with pytest.raises(IndexError): - index.query({"x": (slice(None), 1, "no_level")}) - def test_query_datetime(self): index = PandasIndex(pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"])) actual = index.query({"x": "2001-01-01"}) @@ -58,3 +44,20 @@ def test_query_unsorted_datetime_index_raises(self): # raise instead, so we can be sure the result of indexing with a # slice is always a view. 
index.query({"x": slice("2001", "2002")}) + + +class TestPandasMultiIndex: + def test_query(self): + index = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + ) + with pytest.raises(KeyError, match=r"not all values found"): + index.query({"x": [0]}) + with pytest.raises(KeyError): + index.query({"x": 0}) + with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): + index.query({"one": 0, "x": "a"}) + with pytest.raises(ValueError, match=r"invalid multi-index level names"): + index.query({"x": {"three": 0}}) + with pytest.raises(IndexError): + index.query({"x": (slice(None), 1, "no_level")}) From 16fb7e112530430051ee36540c7c2fecaa510a06 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 May 2021 15:53:24 +0200 Subject: [PATCH 08/11] remove PandasIndex.from_variables for now Add it later in the refactoring when it will be needed elsewhere (e.g., in ``set_index``). --- xarray/core/indexes.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c1c70583e5a..b42e6d5fcdf 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -165,14 +165,6 @@ def __init__( dtype_ = np.dtype(dtype) self._dtype = dtype_ - @classmethod - def from_variables(cls, variables: Dict[Hashable, "Variable"], **kwargs): - if len(variables) > 1: - raise ValueError("Cannot set a pandas.Index from more than one variable") - - varname, var = list(variables.items())[0] - return cls(var.data, dtype=var.dtype, coord_name=varname) - def to_pandas_index(self) -> pd.Index: return self.array From f7c549ca36a7f9f8a93e426ca85f7e20333167fe Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Mon, 17 May 2021 18:56:12 +0200 Subject: [PATCH 09/11] fix broken tests Is this what we want? 
--- xarray/core/indexes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b42e6d5fcdf..c8bc891fbf6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -235,7 +235,7 @@ def query( def equals(self, other): if isinstance(other, pd.Index): other = type(self)(other) - return isinstance(other, type(self)) and self.array.equals(other.array) + return self.array.equals(other.array) def union(self, other): if isinstance(other, pd.Index): From fdcc540b585d7ee80eeb77dae84b2b705ae26eac Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 28 May 2021 17:38:16 +0200 Subject: [PATCH 10/11] prevent loading values for xarray objs in slice --- xarray/core/indexes.py | 9 +++++---- xarray/tests/test_indexes.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c8bc891fbf6..543c04cf24c 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -69,14 +69,15 @@ def _sanitize_slice_element(x): from .dataarray import DataArray from .variable import Variable + if not isinstance(x, tuple) and len(np.shape(x)) != 0: + raise ValueError( + f"cannot use non-scalar arrays in a slice for xarray indexing: {x}" + ) + if isinstance(x, (Variable, DataArray)): x = x.values if isinstance(x, np.ndarray): - if x.ndim != 0: - raise ValueError( - f"cannot use non-scalar arrays in a slice for xarray indexing: {x}" - ) x = x[()] return x diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index f91da2796ab..defc6212228 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -51,6 +51,9 @@ def test_query(self): index = PandasMultiIndex( pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) ) + # test tuples inside slice are considered as scalar indexer values + assert index.query({"x": slice(("a", 1), ("b", 2))}) == (slice(0, 4), None) + with pytest.raises(KeyError, match=r"not 
all values found"): index.query({"x": [0]}) with pytest.raises(KeyError): From fda484988c074bfd371ed490641a383c9429c43a Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 28 May 2021 17:53:53 +0200 Subject: [PATCH 11/11] update what's new --- doc/whats-new.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fb96ed6293c..e6a0c7982c8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -56,6 +56,13 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Explicit indexes refactor: add a ``xarray.Index.query()`` method in which + one may eventually provide a custom implementation of label-based data + selection (not ready yet for public use). Also refactor the internal, + pandas-specific implementation into ``PandasIndex.query()`` and + ``PandasMultiIndex.query()`` (:pull:`5322`). + By `Benoit Bovy <https://github.com/benbovy>`_. + .. _whats-new.0.18.2: v0.18.2 (19 May 2021)