Skip to content

Commit

Permalink
Fix weekly failures (#94)
Browse files Browse the repository at this point in the history
Signed-off-by: neNasko1 <nasko119@gmail.com>
  • Loading branch information
neNasko1 authored Jan 16, 2025
1 parent bcde9f0 commit 8a8b74f
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 26 deletions.
28 changes: 19 additions & 9 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -95,15 +95,23 @@ def bitwise_invert(self, x):
def bitwise_or(self, x, y):
return binary_op(x, y, opx.bitwise_or)

# TODO: ONNX standard -> not cyclic
@validate_core
def bitwise_right_shift(self, x, y):
return binary_op(
x,
y,
lambda x, y: opx.bit_shift(x, y, direction="RIGHT"),
dtypes.uint64,
)
# Since we need to perform arithmetic right-shift we have to be a bit more careful
if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
return binary_op(
x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
)
elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)):
MAX_POW = 63
pow2 = ndx.pow(ndx.asarray(2, ndx.int64), ndx.where(y > MAX_POW, 0, y))
return ndx.where(
y >= MAX_POW,
ndx.where(x >= 0, 0, -1),
ndx.floor_divide(x, pow2),
).astype(x.dtype)
else:
return NotImplemented

@validate_core
def bitwise_xor(self, x, y):
Expand Down Expand Up @@ -271,7 +279,7 @@ def positive(self, x):
def pow(self, x, y):
x, y = ndx.asarray(x), ndx.asarray(y)
dtype = ndx.result_type(x, y)
if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
if isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral)):
return binary_op(x, y, opx.pow, dtypes.int64)
else:
return binary_op(x, y, opx.pow)
Expand Down Expand Up @@ -848,6 +856,7 @@ def all(self, x, *, axis=None, keepdims: bool = False):
x = ndx.where(x.null, True, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(True, dtype=ndx.bool)
x = x.astype(ndx.bool)
return ndx.min(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype(
ndx.bool
)
Expand All @@ -858,6 +867,7 @@ def any(self, x, *, axis=None, keepdims: bool = False):
x = ndx.where(x.null, False, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(False, dtype=ndx.bool)
x = x.astype(ndx.bool)
return ndx.max(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype(
ndx.bool
)
Expand Down
4 changes: 3 additions & 1 deletion ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -60,6 +60,8 @@ def broadcast_to(self, x, shape):
return x._transmute(lambda corearray: opx.expand(corearray, shape))

def expand_dims(self, x, axis):
if axis < -x.ndim - 1 or axis > x.ndim:
raise IndexError(f"Axis must be in [-{x.ndim}, {x.ndim}]")
return x._transmute(
lambda corearray: opx.unsqueeze(
corearray, axes=opx.const([axis], dtype=dtypes.int64)
Expand Down
7 changes: 5 additions & 2 deletions ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -547,7 +547,10 @@ def getitem(
) -> _CoreArray:
if isinstance(index, _CoreArray):
if get_dtype(index) == np.bool_:
if get_rank(corearray) < get_rank(index):
if get_rank(corearray) < get_rank(index) or any(
isinstance(xs, int) and isinstance(ks, int) and ks not in (xs, 0)
for xs, ks in zip(get_shape(corearray), get_shape(index))
):
raise IndexError("Indexing with boolean array cannot happen")
return getitem_null(corearray, index)
else:
Expand Down
51 changes: 37 additions & 14 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -996,18 +996,22 @@ def test_no_unsafe_cumulative_sum_cast():


@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("func", [np.argmax, np.argmin])
@pytest.mark.parametrize(
"func, x",
"x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)),
np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32),
np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32),
# Test all types
np.array([1, 2, 3, 4, 5], dtype=np.int8),
np.array([1, 2, 3, 4, 5], dtype=np.int16),
np.array([1, 2, 3, 4, 5], dtype=np.int32),
np.array([1, 2, 3, 4, 5], dtype=np.int32),
np.array([1, 2, 3, 4, 5], dtype=np.uint8),
np.array([1, 2, 3, 4, 5], dtype=np.uint16),
# (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.uint64)), -> onnxruntime
np.array([1, 2, 3, 4, 5], dtype=np.float32),
np.array([1, 2, 3, 4, 5], dtype=np.float64),
],
)
def test_argmaxmin(func, x, keepdims):
Expand All @@ -1020,11 +1024,12 @@ def test_argmaxmin(func, x, keepdims):

# Pending ORT 1.19 conda-forge release before this becomes supported:
# https://github.com/conda-forge/onnxruntime-feedstock/pull/128
@pytest.mark.parametrize("func", [np.argmax, np.argmin])
@pytest.mark.parametrize(
"func, x",
"x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
np.array([1, 2, 3, 4, 5], dtype=np.int64),
# np.array([1, 2, 3, 4, 5], dtype=np.uint32), -> onnxruntime
],
)
def test_argmaxmin_unsupported_kernels(func, x):
Expand All @@ -1038,3 +1043,21 @@ def test_argmaxmin_unsupported_kernels(func, x):

with pytest.raises(TypeError):
getattr(ndx, func.__name__)(ndx.asarray(x))


@pytest.mark.parametrize(
"x, index",
[
(
ndx.asarray([1, 2, 3, 4, 5]),
ndx.asarray([[True, True, False, False, True]], dtype=ndx.bool),
),
(
ndx.asarray([1, 2, 3, 4, 5]),
ndx.asarray([True, False, False, True], dtype=ndx.bool),
),
],
)
def test_getitem_bool_raises(x, index):
with pytest.raises(IndexError):
x[index]
4 changes: 4 additions & 0 deletions xfails.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# These are not implemented for uint64, but are tested internally
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_argmax

array_api_tests/test_constants.py::test_newaxis
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_meshgrid
Expand Down

0 comments on commit 8a8b74f

Please sign in to comment.