Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix weekly failures #94

Merged
merged 5 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -95,15 +95,19 @@ def bitwise_invert(self, x):
def bitwise_or(self, x, y):
return binary_op(x, y, opx.bitwise_or)

# TODO: ONNX standard -> not cyclic
@validate_core
def bitwise_right_shift(self, x, y):
return binary_op(
x,
y,
lambda x, y: opx.bit_shift(x, y, direction="RIGHT"),
dtypes.uint64,
)
# Since we need to perform arithmetic right-shift we have to be a bit more careful
if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
return binary_op(
x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
)
MAX_POW = 63
return ndx.where(
y >= MAX_POW,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that where executes both branches so you may still run into a division by zero issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which division by zero are you referring to?

Copy link
Collaborator

@cbourjau cbourjau Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 109 due to an overflow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point, I suspected the same but second-guessed it when the test looked like it passed! The explanation for why it silently succeed looks to be that ndonnx (or really onnxruntime) and NumPy overflowed quite differently :/.

(Pdb) np.__version__
'2.1.3'
(Pdb) ndx.__version__
'0.0.post75+g3e22b77'
(Pdb) np.pow(np.asarray(2, np.int64), 70)
np.int64(0)
(Pdb) ndx.pow(ndx.asarray(2, ndx.int64), 70)
Array(9223372036854775807, dtype=Int64)

But yes, let's not depend on this @neNasko1. I'm pretty sure you can just extract the sign bit and manually roll this with the logical right shift. Additionally, this should simply not work for floating point arrays. We'll improve this with typed arrays but for now you'll need to guard that.

I think the following should help avoid division.

-        MAX_POW = 63
-        return ndx.where(
-            y >= MAX_POW,
-            ndx.where(x >= 0, 0, -1),
-            ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)),
-        ).astype(x.dtype)
+        elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)):
+            mask = ndx.where(
+                x >= 0,
+                ndx.asarray(0, dtype=x.dtype),
+                (((ndx.asarray(1, dtype=ndx.uint64) << y) - 1) << (ndx.iinfo(x.dtype).bits-y)).astype(x.dtype)
+            )
+            logic_shift = binary_op(
+                x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
+            ) | mask
+            MAX_POW = 63
+            return ndx.where(
+                y >= MAX_POW,
+                ndx.where(x >= 0, 0, -1),
+                logic_shift,
+            ).astype(x.dtype)
+        else:
+            return NotImplemented

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote the solution a bit, can you take a look?

ndx.where(x >= 0, 0, -1),
ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)),
).astype(x.dtype)

@validate_core
def bitwise_xor(self, x, y):
Expand Down Expand Up @@ -271,7 +275,7 @@ def positive(self, x):
def pow(self, x, y):
x, y = ndx.asarray(x), ndx.asarray(y)
dtype = ndx.result_type(x, y)
if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
if isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral)):
return binary_op(x, y, opx.pow, dtypes.int64)
else:
return binary_op(x, y, opx.pow)
Expand Down Expand Up @@ -848,6 +852,7 @@ def all(self, x, *, axis=None, keepdims: bool = False):
x = ndx.where(x.null, True, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(True, dtype=ndx.bool)
x = x.astype(ndx.bool)
return ndx.min(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype(
ndx.bool
)
Expand All @@ -858,6 +863,7 @@ def any(self, x, *, axis=None, keepdims: bool = False):
x = ndx.where(x.null, False, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(False, dtype=ndx.bool)
x = x.astype(ndx.bool)
return ndx.max(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype(
ndx.bool
)
Expand Down
4 changes: 3 additions & 1 deletion ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -60,6 +60,8 @@ def broadcast_to(self, x, shape):
return x._transmute(lambda corearray: opx.expand(corearray, shape))

def expand_dims(self, x, axis):
if axis < -x.ndim - 1 or axis > x.ndim:
raise IndexError(f"Axis must be in [-{x.ndim}, {x.ndim}]")
return x._transmute(
lambda corearray: opx.unsqueeze(
corearray, axes=opx.const([axis], dtype=dtypes.int64)
Expand Down
7 changes: 5 additions & 2 deletions ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -547,7 +547,10 @@ def getitem(
) -> _CoreArray:
if isinstance(index, _CoreArray):
if get_dtype(index) == np.bool_:
if get_rank(corearray) < get_rank(index):
if get_rank(corearray) < get_rank(index) or any(
isinstance(xs, int) and isinstance(ks, int) and ks not in (xs, 0)
for xs, ks in zip(get_shape(corearray), get_shape(index))
):
raise IndexError("Indexing with boolean array cannot happen")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a short test to ensure we IndexError in the correct situations with boolean indexing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

return getitem_null(corearray, index)
else:
Expand Down
33 changes: 19 additions & 14 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -996,18 +996,22 @@ def test_no_unsafe_cumulative_sum_cast():


@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("func", [np.argmax, np.argmin])
@pytest.mark.parametrize(
"func, x",
"x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)),
np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32),
np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32),
# Test all types
np.array([1, 2, 3, 4, 5], dtype=np.int8),
np.array([1, 2, 3, 4, 5], dtype=np.int16),
np.array([1, 2, 3, 4, 5], dtype=np.int32),
np.array([1, 2, 3, 4, 5], dtype=np.int32),
np.array([1, 2, 3, 4, 5], dtype=np.uint8),
np.array([1, 2, 3, 4, 5], dtype=np.uint16),
# (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.uint64)), -> onnxruntime
np.array([1, 2, 3, 4, 5], dtype=np.float32),
np.array([1, 2, 3, 4, 5], dtype=np.float64),
],
)
def test_argmaxmin(func, x, keepdims):
Expand All @@ -1020,11 +1024,12 @@ def test_argmaxmin(func, x, keepdims):

# Pending ORT 1.19 conda-forge release before this becomes supported:
# https://github.com/conda-forge/onnxruntime-feedstock/pull/128
@pytest.mark.parametrize("func", [np.argmax, np.argmin])
@pytest.mark.parametrize(
"func, x",
"x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
np.array([1, 2, 3, 4, 5], dtype=np.int64),
# np.array([1, 2, 3, 4, 5], dtype=np.uint32), -> onnxruntime
],
)
def test_argmaxmin_unsupported_kernels(func, x):
Expand Down
4 changes: 4 additions & 0 deletions xfails.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# These are not implemented for uint64, but are tested internally
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_argmax
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved

array_api_tests/test_constants.py::test_newaxis
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_meshgrid
Expand Down
Loading