diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index e22205c..f67c8ec 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -102,12 +102,16 @@ def bitwise_right_shift(self, x, y): return binary_op( x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64 ) - MAX_POW = 63 - return ndx.where( - y >= MAX_POW, - ndx.where(x >= 0, 0, -1), - ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)), - ).astype(x.dtype) + elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)): + MAX_POW = 63 + pow2 = ndx.pow(ndx.asarray(2, ndx.int64), ndx.where(y > MAX_POW, 0, y)) + return ndx.where( + y >= MAX_POW, + ndx.where(x >= 0, 0, -1), + ndx.floor_divide(x, pow2), + ).astype(x.dtype) + else: + return NotImplemented @validate_core def bitwise_xor(self, x, y): diff --git a/tests/test_core.py b/tests/test_core.py index 9c1e81e..81a9547 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1043,3 +1043,21 @@ def test_argmaxmin_unsupported_kernels(func, x): with pytest.raises(TypeError): getattr(ndx, func.__name__)(ndx.asarray(x)) + + +@pytest.mark.parametrize( + "x, index", + [ + ( + ndx.asarray([1, 2, 3, 4, 5]), + ndx.asarray([[True, True, False, False, True]], dtype=ndx.bool), + ), + ( + ndx.asarray([1, 2, 3, 4, 5]), + ndx.asarray([True, False, False, True], dtype=ndx.bool), + ), + ], +) +def test_getitem_bool_raises(x, index): + with pytest.raises(IndexError): + x[index]