diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index b8efde7..f67c8ec 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -95,15 +95,23 @@ def bitwise_invert(self, x): def bitwise_or(self, x, y): return binary_op(x, y, opx.bitwise_or) - # TODO: ONNX standard -> not cyclic @validate_core def bitwise_right_shift(self, x, y): - return binary_op( - x, - y, - lambda x, y: opx.bit_shift(x, y, direction="RIGHT"), - dtypes.uint64, - ) + # Since we need to perform arithmetic right-shift we have to be a bit more careful + if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): + return binary_op( + x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64 + ) + elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)): + MAX_POW = 63 + pow2 = ndx.pow(ndx.asarray(2, ndx.int64), ndx.where(y > MAX_POW, 0, y)) + return ndx.where( + y >= MAX_POW, + ndx.where(x >= 0, 0, -1), + ndx.floor_divide(x, pow2), + ).astype(x.dtype) + else: + return NotImplemented @validate_core def bitwise_xor(self, x, y): @@ -271,7 +279,7 @@ def positive(self, x): def pow(self, x, y): x, y = ndx.asarray(x), ndx.asarray(y) dtype = ndx.result_type(x, y) - if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): + if isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral)): return binary_op(x, y, opx.pow, dtypes.int64) else: return binary_op(x, y, opx.pow) @@ -848,6 +856,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) + x = x.astype(ndx.bool) return ndx.min(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype( ndx.bool ) @@ -858,6 +867,7 @@ def any(self, x, *, axis=None, keepdims: bool = False): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) + x = x.astype(ndx.bool) return ndx.max(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype( ndx.bool ) diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index 56a3c8d..ff8142e 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -60,6 +60,8 @@ def broadcast_to(self, x, shape): return x._transmute(lambda corearray: opx.expand(corearray, shape)) def expand_dims(self, x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + raise IndexError(f"Axis must be in [-{x.ndim}, {x.ndim}]") return x._transmute( lambda corearray: opx.unsqueeze( corearray, axes=opx.const([axis], dtype=dtypes.int64) diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index 1ac10df..5056d0a 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -547,7 +547,10 @@ def getitem( ) -> _CoreArray: if isinstance(index, _CoreArray): if get_dtype(index) == np.bool_: - if get_rank(corearray) < get_rank(index): + if get_rank(corearray) < get_rank(index) or any( + isinstance(xs, int) and isinstance(ks, int) and ks not in (xs, 0) + for xs, ks in zip(get_shape(corearray), get_shape(index)) + ): raise IndexError("Indexing with boolean array cannot happen") return getitem_null(corearray, index) else: diff --git a/tests/test_core.py b/tests/test_core.py index a9bfe0f..81a9547 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -996,18 +996,22 @@ def test_no_unsafe_cumulative_sum_cast(): @pytest.mark.parametrize("keepdims", [True, False]) +@pytest.mark.parametrize("func", [np.argmax, np.argmin]) @pytest.mark.parametrize( - "func, x", + "x", [ - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)), - (np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)), - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)), - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)), - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)), - (np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)), + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), + np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32), + # Test all types + np.array([1, 2, 3, 4, 5], dtype=np.int8), + np.array([1, 2, 3, 4, 5], dtype=np.int16), + np.array([1, 2, 3, 4, 5], dtype=np.int32), + np.array([1, 2, 3, 4, 5], dtype=np.int32), + np.array([1, 2, 3, 4, 5], dtype=np.uint8), + np.array([1, 2, 3, 4, 5], dtype=np.uint16), + # (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.uint64)), -> onnxruntime + np.array([1, 2, 3, 4, 5], dtype=np.float32), + np.array([1, 2, 3, 4, 5], dtype=np.float64), ], ) def test_argmaxmin(func, x, keepdims): @@ -1020,11 +1024,12 @@ def test_argmaxmin(func, x, keepdims): # Pending ORT 1.19 conda-forge release before this becomes supported: # https://github.com/conda-forge/onnxruntime-feedstock/pull/128 +@pytest.mark.parametrize("func", [np.argmax, np.argmin]) @pytest.mark.parametrize( - "func, x", + "x", [ - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)), + np.array([1, 2, 3, 4, 5], dtype=np.int64), + # np.array([1, 2, 3, 4, 5], dtype=np.uint32), -> onnxruntime ], ) def test_argmaxmin_unsupported_kernels(func, x): @@ -1038,3 +1043,21 @@ def test_argmaxmin_unsupported_kernels(func, x): with pytest.raises(TypeError): getattr(ndx, func.__name__)(ndx.asarray(x)) + + +@pytest.mark.parametrize( + "x, index", + [ + ( + ndx.asarray([1, 2, 3, 4, 5]), + ndx.asarray([[True, True, False, False, True]], dtype=ndx.bool), + ), + ( + ndx.asarray([1, 2, 3, 4, 5]), + ndx.asarray([True, False, False, True], dtype=ndx.bool), + ), + ], +) +def test_getitem_bool_raises(x, index): + with pytest.raises(IndexError): + x[index] diff --git a/xfails.txt b/xfails.txt index 9665406..0f06fd8 100644 --- a/xfails.txt +++ b/xfails.txt @@ -1,3 +1,7 @@ +# These are not implemented for uint64, but are tested internally +array_api_tests/test_searching_functions.py::test_argmin +array_api_tests/test_searching_functions.py::test_argmax + array_api_tests/test_constants.py::test_newaxis array_api_tests/test_creation_functions.py::test_eye array_api_tests/test_creation_functions.py::test_meshgrid