From b5e0d51b666630c1e633796076eca0dc9b082592 Mon Sep 17 00:00:00 2001 From: neNasko1 Date: Tue, 7 Jan 2025 19:09:33 +0200 Subject: [PATCH 1/4] Fix weekly failures Signed-off-by: neNasko1 --- ndonnx/_core/_numericimpl.py | 17 +++++++++-------- ndonnx/_core/_shapeimpl.py | 2 ++ ndonnx/_opset_extensions.py | 5 ++++- xfails.txt | 4 ++++ 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index b8efde7..557d005 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -95,15 +95,14 @@ def bitwise_invert(self, x): def bitwise_or(self, x, y): return binary_op(x, y, opx.bitwise_or) - # TODO: ONNX standard -> not cyclic @validate_core def bitwise_right_shift(self, x, y): - return binary_op( - x, - y, - lambda x, y: opx.bit_shift(x, y, direction="RIGHT"), - dtypes.uint64, - ) + # Since we need to perform arithmetic right-shift we have to be a bit more careful + return ndx.where( + y >= 63, + ndx.where(x >= 0, 0, -1), + ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)), + ).astype(x.dtype) @validate_core def bitwise_xor(self, x, y): @@ -271,7 +270,7 @@ def positive(self, x): def pow(self, x, y): x, y = ndx.asarray(x), ndx.asarray(y) dtype = ndx.result_type(x, y) - if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): + if isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral)): return binary_op(x, y, opx.pow, dtypes.int64) else: return binary_op(x, y, opx.pow) @@ -848,6 +847,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) + x = x.astype(ndx.bool) return ndx.min(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype( ndx.bool ) @@ -858,6 +858,7 @@ def any(self, x, *, axis=None, keepdims: bool = False): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) + x = x.astype(ndx.bool) return ndx.max(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype( ndx.bool ) diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index 56a3c8d..d7b78aa 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -60,6 +60,8 @@ def broadcast_to(self, x, shape): return x._transmute(lambda corearray: opx.expand(corearray, shape)) def expand_dims(self, x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + raise IndexError(f"Axis must be in [-{x.ndim}, {x.ndim}]") return x._transmute( lambda corearray: opx.unsqueeze( corearray, axes=opx.const([axis], dtype=dtypes.int64) diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index 1ac10df..0c2829a 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -547,7 +547,10 @@ def getitem( ) -> _CoreArray: if isinstance(index, _CoreArray): if get_dtype(index) == np.bool_: - if get_rank(corearray) < get_rank(index): + if get_rank(corearray) < get_rank(index) or any( + isinstance(xs, int) and isinstance(ks, int) and ks not in (xs, 0) + for xs, ks in zip(get_shape(corearray), get_shape(index)) + ): raise IndexError("Indexing with boolean array cannot happen") return getitem_null(corearray, index) else: diff --git a/xfails.txt b/xfails.txt index 9665406..0f06fd8 100644 --- a/xfails.txt +++ b/xfails.txt @@ -1,3 +1,7 @@ +# These are not implemented for uint64, but are tested internally +array_api_tests/test_searching_functions.py::test_argmin +array_api_tests/test_searching_functions.py::test_argmax + array_api_tests/test_constants.py::test_newaxis array_api_tests/test_creation_functions.py::test_eye array_api_tests/test_creation_functions.py::test_meshgrid From 6bd7b43a729caf02dafbb122b84f6f2f14c08a9c Mon Sep 17 00:00:00 2001 From: neNasko1 Date: Tue, 7 Jan 2025 19:14:56 +0200 Subject: [PATCH 2/4] Update tests for argmin and argmax Signed-off-by: neNasko1 --- ndonnx/_core/_numericimpl.py | 2 +- ndonnx/_core/_shapeimpl.py | 2 +- ndonnx/_opset_extensions.py | 2 +- tests/test_core.py | 31 ++++++++++++++++++------------- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 557d005..7301708 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index d7b78aa..ff8142e 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index 0c2829a..5056d0a 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations diff --git a/tests/test_core.py b/tests/test_core.py index a9bfe0f..b64eea8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -996,18 +996,22 @@ def test_no_unsafe_cumulative_sum_cast(): @pytest.mark.parametrize("keepdims", [True, False]) +@pytest.mark.parametrize("func", [np.argmax, np.argmin]) @pytest.mark.parametrize( - "func, x", + "x", [ - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)), - (np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)), - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)), - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)), - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)), - (np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)), + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), + np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32), + # Test all types + np.array([1, 2, 3, 4, 5], dtype=np.int8), + np.array([1, 2, 3, 4, 5], dtype=np.int16), + np.array([1, 2, 3, 4, 5], dtype=np.int32), + np.array([1, 2, 3, 4, 5], dtype=np.int32), + np.array([1, 2, 3, 4, 5], dtype=np.uint8), + np.array([1, 2, 3, 4, 5], dtype=np.uint16), + # (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.uint64)), -> onnxruntime + np.array([1, 2, 3, 4, 5], dtype=np.float32), + np.array([1, 2, 3, 4, 5], dtype=np.float64), ], ) def test_argmaxmin(func, x, keepdims): @@ -1020,11 +1024,12 @@ def test_argmaxmin(func, x, keepdims): # Pending ORT 1.19 conda-forge release before this becomes supported: # https://github.com/conda-forge/onnxruntime-feedstock/pull/128 +@pytest.mark.parametrize("func", [np.argmax, np.argmin]) @pytest.mark.parametrize( - "func, x", + "x", [ - (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)), - (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)), + np.array([1, 2, 3, 4, 5], dtype=np.int64), + # np.array([1, 2, 3, 4, 5], dtype=np.uint32), -> onnxruntime ], ) def test_argmaxmin_unsupported_kernels(func, x): From 3e22b77ee174e0e352a58db3e8e6e8cefb7f36c8 Mon Sep 17 00:00:00 2001 From: neNasko1 Date: Tue, 7 Jan 2025 19:26:17 +0200 Subject: [PATCH 3/4] Optimise and remove magic constants Signed-off-by: neNasko1 --- ndonnx/_core/_numericimpl.py | 7 ++++++- tests/test_core.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 7301708..e22205c 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -98,8 +98,13 @@ def bitwise_or(self, x, y): @validate_core def bitwise_right_shift(self, x, y): # Since we need to perform arithmetic right-shift we have to be a bit more careful + if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): + return binary_op( + x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64 + ) + MAX_POW = 63 return ndx.where( - y >= 63, + y >= MAX_POW, ndx.where(x >= 0, 0, -1), ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)), ).astype(x.dtype) diff --git a/tests/test_core.py b/tests/test_core.py index b64eea8..9c1e81e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations From dbc93040f392e74b26d4a5e7e24392786934165d Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Tue, 14 Jan 2025 09:24:46 +0200 Subject: [PATCH 4/4] Comments after code-review --- ndonnx/_core/_numericimpl.py | 16 ++++++++++------ tests/test_core.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index e22205c..f67c8ec 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -102,12 +102,16 @@ def bitwise_right_shift(self, x, y): return binary_op( x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64 ) - MAX_POW = 63 - return ndx.where( - y >= MAX_POW, - ndx.where(x >= 0, 0, -1), - ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)), - ).astype(x.dtype) + elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)): + MAX_POW = 63 + pow2 = ndx.pow(ndx.asarray(2, ndx.int64), ndx.where(y > MAX_POW, 0, y)) + return ndx.where( + y >= MAX_POW, + ndx.where(x >= 0, 0, -1), + ndx.floor_divide(x, pow2), + ).astype(x.dtype) + else: + return NotImplemented @validate_core def bitwise_xor(self, x, y): diff --git a/tests/test_core.py b/tests/test_core.py index 9c1e81e..81a9547 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1043,3 +1043,21 @@ def test_argmaxmin_unsupported_kernels(func, x): with pytest.raises(TypeError): getattr(ndx, func.__name__)(ndx.asarray(x)) + + +@pytest.mark.parametrize( + "x, index", + [ + ( + ndx.asarray([1, 2, 3, 4, 5]), + ndx.asarray([[True, True, False, False, True]], dtype=ndx.bool), + ), + ( + ndx.asarray([1, 2, 3, 4, 5]), + ndx.asarray([True, False, False, True], dtype=ndx.bool), + ), + ], +) +def test_getitem_bool_raises(x, index): + with pytest.raises(IndexError): + x[index]