-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix weekly failures #94
Conversation
Signed-off-by: neNasko1 <nasko119@gmail.com>
Signed-off-by: neNasko1 <nasko119@gmail.com>
Signed-off-by: neNasko1 <nasko119@gmail.com>
ndonnx/_core/_numericimpl.py
Outdated
x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64 | ||
) | ||
MAX_POW = 63 | ||
return ndx.where( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please note that where
executes both branches so you may still run into a division by zero issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which division by zero are you referring to?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line 109 due to an overflow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair point, I suspected the same but second-guessed it when the test looked like it passed! The explanation for why it silently succeed looks to be that ndonnx
(or really onnxruntime) and NumPy
overflowed quite differently :/.
(Pdb) np.__version__
'2.1.3'
(Pdb) ndx.__version__
'0.0.post75+g3e22b77'
(Pdb) np.pow(np.asarray(2, np.int64), 70)
np.int64(0)
(Pdb) ndx.pow(ndx.asarray(2, ndx.int64), 70)
Array(9223372036854775807, dtype=Int64)
But yes, let's not depend on this @neNasko1. I'm pretty sure you can just extract the sign bit and manually roll this with the logical right shift. Additionally, this should simply not work for floating point arrays. We'll improve this with typed arrays but for now you'll need to guard that.
I think the following should help avoid division.
- MAX_POW = 63
- return ndx.where(
- y >= MAX_POW,
- ndx.where(x >= 0, 0, -1),
- ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)),
- ).astype(x.dtype)
+ elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)):
+ mask = ndx.where(
+ x >= 0,
+ ndx.asarray(0, dtype=x.dtype),
+ (((ndx.asarray(1, dtype=ndx.uint64) << y) - 1) << (ndx.iinfo(x.dtype).bits-y)).astype(x.dtype)
+ )
+ logic_shift = binary_op(
+ x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
+ ) | mask
+ MAX_POW = 63
+ return ndx.where(
+ y >= MAX_POW,
+ ndx.where(x >= 0, 0, -1),
+ logic_shift,
+ ).astype(x.dtype)
+ else:
+ return NotImplemented
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I rewrote the solution a bit, can you take a look?
ndonnx/_core/_numericimpl.py
Outdated
x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64 | ||
) | ||
MAX_POW = 63 | ||
return ndx.where( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair point, I suspected the same but second-guessed it when the test looked like it passed! The explanation for why it silently succeed looks to be that ndonnx
(or really onnxruntime) and NumPy
overflowed quite differently :/.
(Pdb) np.__version__
'2.1.3'
(Pdb) ndx.__version__
'0.0.post75+g3e22b77'
(Pdb) np.pow(np.asarray(2, np.int64), 70)
np.int64(0)
(Pdb) ndx.pow(ndx.asarray(2, ndx.int64), 70)
Array(9223372036854775807, dtype=Int64)
But yes, let's not depend on this @neNasko1. I'm pretty sure you can just extract the sign bit and manually roll this with the logical right shift. Additionally, this should simply not work for floating point arrays. We'll improve this with typed arrays but for now you'll need to guard that.
I think the following should help avoid division.
- MAX_POW = 63
- return ndx.where(
- y >= MAX_POW,
- ndx.where(x >= 0, 0, -1),
- ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)),
- ).astype(x.dtype)
+ elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)):
+ mask = ndx.where(
+ x >= 0,
+ ndx.asarray(0, dtype=x.dtype),
+ (((ndx.asarray(1, dtype=ndx.uint64) << y) - 1) << (ndx.iinfo(x.dtype).bits-y)).astype(x.dtype)
+ )
+ logic_shift = binary_op(
+ x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
+ ) | mask
+ MAX_POW = 63
+ return ndx.where(
+ y >= MAX_POW,
+ ndx.where(x >= 0, 0, -1),
+ logic_shift,
+ ).astype(x.dtype)
+ else:
+ return NotImplemented
if get_rank(corearray) < get_rank(index) or any( | ||
isinstance(xs, int) and isinstance(ks, int) and ks not in (xs, 0) | ||
for xs, ks in zip(get_shape(corearray), get_shape(index)) | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add a short test to ensure we IndexError
in the correct situations with boolean indexing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
This PR fixes the weekly failures of the array-api(#92) by:
argmin
andargmax
on the test suite and improving our tests asonnxruntime
does not implement them forUInt64
all
andany
for non-boolean inputsexpand_dims
andgetitem_null