Skip to content

Commit

Permalink
Comments after code-review
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Jan 14, 2025
1 parent 3e22b77 commit dbc9304
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
16 changes: 10 additions & 6 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,16 @@ def bitwise_right_shift(self, x, y):
return binary_op(
x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
)
MAX_POW = 63
return ndx.where(
y >= MAX_POW,
ndx.where(x >= 0, 0, -1),
ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)),
).astype(x.dtype)
elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)):
MAX_POW = 63
pow2 = ndx.pow(ndx.asarray(2, ndx.int64), ndx.where(y > MAX_POW, 0, y))
return ndx.where(
y >= MAX_POW,
ndx.where(x >= 0, 0, -1),
ndx.floor_divide(x, pow2),
).astype(x.dtype)
else:
return NotImplemented

@validate_core
def bitwise_xor(self, x, y):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,3 +1043,21 @@ def test_argmaxmin_unsupported_kernels(func, x):

with pytest.raises(TypeError):
getattr(ndx, func.__name__)(ndx.asarray(x))


@pytest.mark.parametrize(
"x, index",
[
(
ndx.asarray([1, 2, 3, 4, 5]),
ndx.asarray([[True, True, False, False, True]], dtype=ndx.bool),
),
(
ndx.asarray([1, 2, 3, 4, 5]),
ndx.asarray([True, False, False, True], dtype=ndx.bool),
),
],
)
def test_getitem_bool_raises(x, index):
with pytest.raises(IndexError):
x[index]

0 comments on commit dbc9304

Please sign in to comment.