diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index ea8ae44bb4d..407fda610fc 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -326,18 +326,23 @@ def as_integer_slice(value): class IndexCallable: - """Provide getitem syntax for a callable object.""" + """Provide getitem and setitem syntax for callable objects.""" - __slots__ = ("func",) + __slots__ = ("getter", "setter") - def __init__(self, func): - self.func = func + def __init__(self, getter, setter=None): + self.getter = getter + self.setter = setter def __getitem__(self, key): - return self.func(key) + return self.getter(key) def __setitem__(self, key, value): - raise NotImplementedError + if self.setter is None: + raise NotImplementedError( + "Setting values is not supported for this indexer." + ) + self.setter(key, value) class BasicIndexer(ExplicitIndexer): @@ -486,10 +491,24 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: return np.asarray(self.get_duck_array(), dtype=dtype) def _oindex_get(self, key): - raise NotImplementedError("This method should be overridden") + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_get method should be overridden" + ) def _vindex_get(self, key): - raise NotImplementedError("This method should be overridden") + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_get method should be overridden" + ) + + def _oindex_set(self, key, value): + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_set method should be overridden" + ) + + def _vindex_set(self, key, value): + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_set method should be overridden" + ) def _check_and_raise_if_non_basic_indexer(self, key): if isinstance(key, (VectorizedIndexer, OuterIndexer)): @@ -500,11 +519,11 @@ def _check_and_raise_if_non_basic_indexer(self, key): @property def oindex(self): - return IndexCallable(self._oindex_get) + return IndexCallable(self._oindex_get, self._oindex_set) @property def vindex(self): - return IndexCallable(self._vindex_get) + return IndexCallable(self._vindex_get, self._vindex_set) class ImplicitToExplicitIndexingAdapter(NDArrayMixin): @@ -616,12 +635,18 @@ def __getitem__(self, indexer): self._check_and_raise_if_non_basic_indexer(indexer) return type(self)(self.array, self._updated_key(indexer)) + def _vindex_set(self, key, value): + raise NotImplementedError( + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." + ) + + def _oindex_set(self, key, value): + full_key = self._updated_key(key) + self.array.oindex[full_key] = value + def __setitem__(self, key, value): - if isinstance(key, VectorizedIndexer): - raise NotImplementedError( - "Lazy item assignment with the vectorized indexer is not yet " - "implemented. Load your data first by .load() or compute()." - ) + self._check_and_raise_if_non_basic_indexer(key) full_key = self._updated_key(key) self.array[full_key] = value @@ -657,7 +682,6 @@ def shape(self) -> tuple[int, ...]: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): array = apply_indexer(self.array, self.key) else: @@ -739,8 +763,18 @@ def __getitem__(self, key): def transpose(self, order): return self.array.transpose(order) + def _vindex_set(self, key, value): + self._ensure_copied() + self.array.vindex[key] = value + + def _oindex_set(self, key, value): + self._ensure_copied() + self.array.oindex[key] = value + def __setitem__(self, key, value): + self._check_and_raise_if_non_basic_indexer(key) self._ensure_copied() + self.array[key] = value def __deepcopy__(self, memo): @@ -779,7 +813,14 @@ def __getitem__(self, key): def transpose(self, order): return self.array.transpose(order) + def _vindex_set(self, key, value): + self.array.vindex[key] = value + + def _oindex_set(self, key, value): + self.array.oindex[key] = value + def __setitem__(self, key, value): + self._check_and_raise_if_non_basic_indexer(key) self.array[key] = value @@ -950,6 +991,16 @@ def apply_indexer(indexable, indexer): return indexable[indexer] +def set_with_indexer(indexable, indexer, value): + """Set values in an indexable object using an indexer.""" + if isinstance(indexer, VectorizedIndexer): + indexable.vindex[indexer] = value + elif isinstance(indexer, OuterIndexer): + indexable.oindex[indexer] = value + else: + indexable[indexer] = value + + def decompose_indexer( indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport ) -> tuple[ExplicitIndexer, ExplicitIndexer]: @@ -1399,24 +1450,6 @@ def __init__(self, array): ) self.array = array - def _indexing_array_and_key(self, key): - if isinstance(key, OuterIndexer): - array = self.array - key = _outer_to_numpy_indexer(key, self.array.shape) - elif isinstance(key, VectorizedIndexer): - array = NumpyVIndexAdapter(self.array) - key = key.tuple - elif isinstance(key, BasicIndexer): - array = self.array - # We want 0d slices rather than scalars. This is achieved by - # appending an ellipsis (see - # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). - key = key.tuple + (Ellipsis,) - else: - raise TypeError(f"unexpected key type: {type(key)}") - - return array, key - def transpose(self, order): return self.array.transpose(order) @@ -1430,14 +1463,18 @@ def _vindex_get(self, key): def __getitem__(self, key): self._check_and_raise_if_non_basic_indexer(key) - array, key = self._indexing_array_and_key(key) + + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = key.tuple + (Ellipsis,) return array[key] - def __setitem__(self, key, value): - array, key = self._indexing_array_and_key(key) + def _safe_setitem(self, array, key, value): try: array[key] = value - except ValueError: + except ValueError as exc: # More informative exception if read-only view if not array.flags.writeable and not array.flags.owndata: raise ValueError( @@ -1445,7 +1482,24 @@ def __setitem__(self, key, value): "Do you want to .copy() array first?" ) else: - raise + raise exc + + def _oindex_set(self, key, value): + key = _outer_to_numpy_indexer(key, self.array.shape) + self._safe_setitem(self.array, key, value) + + def _vindex_set(self, key, value): + array = NumpyVIndexAdapter(self.array) + self._safe_setitem(array, key.tuple, value) + + def __setitem__(self, key, value): + self._check_and_raise_if_non_basic_indexer(key) + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = key.tuple + (Ellipsis,) + self._safe_setitem(array, key, value) class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter): @@ -1488,13 +1542,15 @@ def __getitem__(self, key): self._check_and_raise_if_non_basic_indexer(key) return self.array[key.tuple] + def _oindex_set(self, key, value): + self.array[key.tuple] = value + + def _vindex_set(self, key, value): + raise TypeError("Vectorized indexing is not supported") + def __setitem__(self, key, value): - if isinstance(key, (BasicIndexer, OuterIndexer)): - self.array[key.tuple] = value - elif isinstance(key, VectorizedIndexer): - raise TypeError("Vectorized indexing is not supported") - else: - raise TypeError(f"Unrecognized indexer: {key}") + self._check_and_raise_if_non_basic_indexer(key) + self.array[key.tuple] = value def transpose(self, order): xp = self.array.__array_namespace__() @@ -1530,19 +1586,20 @@ def __getitem__(self, key): self._check_and_raise_if_non_basic_indexer(key) return self.array[key.tuple] + def _oindex_set(self, key, value): + num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple) + if num_non_slices > 1: + raise NotImplementedError( + "xarray can't set arrays with multiple " "array indices to dask yet." + ) + self.array[key.tuple] = value + + def _vindex_set(self, key, value): + self.array.vindex[key.tuple] = value + def __setitem__(self, key, value): - if isinstance(key, BasicIndexer): - self.array[key.tuple] = value - elif isinstance(key, VectorizedIndexer): - self.array.vindex[key.tuple] = value - elif isinstance(key, OuterIndexer): - num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple) - if num_non_slices > 1: - raise NotImplementedError( - "xarray can't set arrays with multiple " - "array indices to dask yet." - ) - self.array[key.tuple] = value + self._check_and_raise_if_non_basic_indexer(key) + self.array[key.tuple] = value def transpose(self, order): return self.array.transpose(order) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index cad48d0775a..2ac0c04d726 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -849,7 +849,7 @@ def __setitem__(self, key, value): value = np.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) - indexable[index_tuple] = value + indexing.set_with_indexer(indexable, index_tuple, value) @property def encoding(self) -> dict[Any, Any]: diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index c3989bbf23e..e650c454eac 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -23,6 +23,28 @@ B = IndexerMaker(indexing.BasicIndexer) +class TestIndexCallable: + def test_getitem(self): + def getter(key): + return key * 2 + + indexer = indexing.IndexCallable(getter) + assert indexer[3] == 6 + assert indexer[0] == 0 + assert indexer[-1] == -2 + + def test_setitem(self): + def getter(key): + return key * 2 + + def setter(key, value): + raise NotImplementedError("Setter not implemented") + + indexer = indexing.IndexCallable(getter, setter) + with pytest.raises(NotImplementedError): + indexer[3] = 6 + + class TestIndexers: def set_to_zero(self, x, i): x = x.copy() @@ -361,15 +383,8 @@ def test_vectorized_lazily_indexed_array(self) -> None: def check_indexing(v_eager, v_lazy, indexers): for indexer in indexers: - if isinstance(indexer, indexing.VectorizedIndexer): - actual = v_lazy.vindex[indexer] - expected = v_eager.vindex[indexer] - elif isinstance(indexer, indexing.OuterIndexer): - actual = v_lazy.oindex[indexer] - expected = v_eager.oindex[indexer] - else: - actual = v_lazy[indexer] - expected = v_eager[indexer] + actual = v_lazy[indexer] + expected = v_eager[indexer] assert expected.shape == actual.shape assert isinstance( actual._data, @@ -406,6 +421,41 @@ def check_indexing(v_eager, v_lazy, indexers): ] check_indexing(v_eager, v_lazy, indexers) + def test_lazily_indexed_array_vindex_setitem(self) -> None: + + lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30)) + + # vectorized indexing + indexer = indexing.VectorizedIndexer( + (np.array([0, 1]), np.array([0, 1]), slice(None, None, None)) + ) + with pytest.raises( + NotImplementedError, + match=r"Lazy item assignment with the vectorized indexer is not yet", + ): + lazy.vindex[indexer] = 0 + + @pytest.mark.parametrize( + "indexer_class, key, value", + [ + (indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10), + (indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10), + ], + ) + def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None: + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + lazy = indexing.LazilyIndexedArray(x) + + if indexer_class is indexing.BasicIndexer: + indexer = indexer_class(key) + lazy[indexer] = value + elif indexer_class is indexing.OuterIndexer: + indexer = indexer_class(key) + lazy.oindex[indexer] = value + + assert_array_equal(original[key], value) + class TestCopyOnWriteArray: def test_setitem(self) -> None: