Fix weekly failures #94

Merged on Jan 16, 2025 (5 commits)

@neNasko1 neNasko1 commented Jan 7, 2025

This PR fixes the weekly array-api test failures (#92) by:

  • disabling argmin and argmax in the test suite and improving our own tests, since onnxruntime does not implement them for UInt64
  • reimplementing right shift, since the array-api defines it as an arithmetic right shift
  • fixing all and any for non-boolean inputs
  • raising errors where possible in expand_dims and getitem_null
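
For readers unfamiliar with the distinction in the second bullet, here is a minimal plain-Python sketch of how an arithmetic and a logical right shift differ on a negative 64-bit value. The helper names are hypothetical and are not part of ndonnx or onnxruntime; Python integers stand in for int64 values.

```python
MASK64 = (1 << 64) - 1  # all 64 bits set


def logical_right_shift(x: int, n: int) -> int:
    """Shift the 64-bit two's-complement pattern of x right, filling with zeros."""
    return (x & MASK64) >> n


def arithmetic_right_shift(x: int, n: int) -> int:
    """Shift right while replicating the sign bit (Python's >> on ints does this)."""
    return x >> n


print(arithmetic_right_shift(-8, 1))  # -4
print(logical_right_shift(-8, 1))     # 9223372036854775804
```

For non-negative inputs the two agree; they only diverge once the sign bit is set.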

Signed-off-by: neNasko1 <nasko119@gmail.com>
Signed-off-by: neNasko1 <nasko119@gmail.com>
Signed-off-by: neNasko1 <nasko119@gmail.com>
x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
)
MAX_POW = 63
return ndx.where(

Collaborator

Please note that where executes both branches, so you may still run into a division-by-zero issue.
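
A small sketch of the concern, using plain Python lists as a stand-in for arrays (the `where` helper below is hypothetical and only mimics the element-wise selection): both branch arguments are fully computed before the selection happens, so an unsafe operation in the discarded branch still runs.

```python
def where(cond, if_true, if_false):
    """Element-wise select over plain lists; a stand-in for an array `where`."""
    return [t if c else f for c, t, f in zip(cond, if_true, if_false)]


def guarded_reciprocal_naive(values, mask):
    # Both argument lists are built before `where` runs, so 1 / 0 is
    # evaluated even though the mask would never select that element.
    return where(mask, [1 / v for v in values], [0.0] * len(values))


def guarded_reciprocal_safe(values, mask):
    # Replace unsafe inputs first so the discarded branch cannot fail.
    patched = [v if m else 1 for v, m in zip(values, mask)]
    return where(mask, [1 / v for v in patched], [0.0] * len(values))


x = [4, 8, 0]
nonzero = [v != 0 for v in x]

try:
    guarded_reciprocal_naive(x, nonzero)
except ZeroDivisionError:
    print("naive version divided by zero")

print(guarded_reciprocal_safe(x, nonzero))  # [0.25, 0.125, 0.0]
```

The usual remedy is the one in the second function: sanitise the inputs of the risky branch rather than relying on the selection to hide the failure.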

Member

Which division by zero are you referring to?

Collaborator

@cbourjau cbourjau Jan 9, 2025

Line 109 due to an overflow.

Member

Fair point, I suspected the same but second-guessed it when the test looked like it passed! The reason it silently succeeded appears to be that ndonnx (or really onnxruntime) and NumPy overflow quite differently :/.

(Pdb) np.__version__
'2.1.3'
(Pdb) ndx.__version__
'0.0.post75+g3e22b77'
(Pdb) np.pow(np.asarray(2, np.int64), 70)
np.int64(0)
(Pdb) ndx.pow(ndx.asarray(2, ndx.int64), 70)
Array(9223372036854775807, dtype=Int64)
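
As a reference point, the two values above match two common overflow behaviours: two's-complement wraparound and clamping to the largest int64. The sketch below reproduces both numbers in plain Python. It only shows that the outputs are consistent with those behaviours; it makes no claim about how NumPy or onnxruntime compute them internally, and the helper names are hypothetical.

```python
INT64_MAX = 2**63 - 1
INT64_MIN = -(2**63)


def wrap_int64(value: int) -> int:
    """Reduce modulo 2**64 and reinterpret as signed (two's-complement wraparound)."""
    return (value - INT64_MIN) % 2**64 + INT64_MIN


def saturate_int64(value: int) -> int:
    """Clamp to the representable int64 range."""
    return max(INT64_MIN, min(INT64_MAX, value))


print(wrap_int64(2**70))      # 0
print(saturate_int64(2**70))  # 9223372036854775807
```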

But yes, let's not depend on this @neNasko1. I'm pretty sure you can just extract the sign bit and manually roll this with the logical right shift. Additionally, this should simply not work for floating point arrays. We'll improve this with typed arrays but for now you'll need to guard that.

I think the following should help avoid division.

-        MAX_POW = 63
-        return ndx.where(
-            y >= MAX_POW,
-            ndx.where(x >= 0, 0, -1),
-            ndx.floor_divide(x, ndx.pow(ndx.asarray(2, ndx.int64), y)),
-        ).astype(x.dtype)
+        elif isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)):
+            mask = ndx.where(
+                x >= 0,
+                ndx.asarray(0, dtype=x.dtype),
+                (((ndx.asarray(1, dtype=ndx.uint64) << y) - 1) << (ndx.iinfo(x.dtype).bits-y)).astype(x.dtype)
+            )
+            logic_shift = binary_op(
+                x, y, lambda a, b: opx.bit_shift(a, b, direction="RIGHT"), dtypes.uint64
+            ) | mask
+            MAX_POW = 63
+            return ndx.where(
+                y >= MAX_POW,
+                ndx.where(x >= 0, 0, -1),
+                logic_shift,
+            ).astype(x.dtype)
+        else:
+            return NotImplemented
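
The sign-extension idea behind this suggestion can be sketched in plain Python, independent of ndonnx: do a logical shift on the unsigned bit pattern, then OR in a mask that sets the top bits when the input is negative. This is an illustrative sketch under the assumption of 64-bit two's-complement values. The function names are hypothetical, and the cutoff mirrors the `MAX_POW = 63` guard in the diff.

```python
BITS = 64
MASK = (1 << BITS) - 1


def to_signed(u: int) -> int:
    """Reinterpret a 64-bit unsigned pattern as a signed integer."""
    return u - (1 << BITS) if u >> (BITS - 1) else u


def arithmetic_shift_via_logical(x: int, n: int) -> int:
    """Arithmetic right shift of an int64 value built from a logical shift."""
    if n < 0:
        raise ValueError("shift amount must be non-negative")
    if n >= BITS - 1:
        # Every value bit is shifted out; only the sign survives.
        return 0 if x >= 0 else -1
    shifted = (x & MASK) >> n  # logical shift: zeros enter from the left
    if x < 0:
        # Set the top n bits to replicate the sign bit (no bits when n == 0).
        shifted |= ((1 << n) - 1) << (BITS - n)
    return to_signed(shifted)


print(arithmetic_shift_via_logical(-9, 1))  # -5
print(arithmetic_shift_via_logical(9, 1))   # 4
```

Note that with fixed-width integers the `n == 0` case needs care, because the mask would be shifted by the full bit width; Python's arbitrary-precision integers hide that issue here.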

Contributor Author

I rewrote the solution a bit, can you take a look?


if get_rank(corearray) < get_rank(index) or any(
isinstance(xs, int) and isinstance(ks, int) and ks not in (xs, 0)
for xs, ks in zip(get_shape(corearray), get_shape(index))
):

Member

Could you please add a short test to ensure we raise IndexError in the correct situations with boolean indexing?
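
A sketch of the kind of check and test being asked for, written against plain shape tuples rather than ndonnx arrays. The function `check_boolean_index_shape` is a hypothetical stand-in that mirrors the condition in the snippet above (rank comparison, then per-axis comparison that skips non-integer, i.e. unknown, dimensions); it is not the library's API.

```python
def check_boolean_index_shape(array_shape, mask_shape):
    """Raise IndexError if a boolean mask cannot index an array of this shape."""
    if len(array_shape) < len(mask_shape):
        raise IndexError("boolean index has more dimensions than the array")
    for axis, (xs, ks) in enumerate(zip(array_shape, mask_shape)):
        # Only compare dimensions that are statically known integers.
        if isinstance(xs, int) and isinstance(ks, int) and ks not in (xs, 0):
            raise IndexError(
                f"boolean index size {ks} does not match array size {xs} on axis {axis}"
            )


def raises_index_error(array_shape, mask_shape):
    try:
        check_boolean_index_shape(array_shape, mask_shape)
    except IndexError:
        return True
    return False


print(raises_index_error((3, 4), (3,)))    # False: mask matches the leading axis
print(raises_index_error((3, 4), (2,)))    # True: size mismatch
print(raises_index_error((3,), (3, 4)))    # True: mask rank exceeds array rank
print(raises_index_error(("N", 4), (5,)))  # False: unknown dimension is not checked
```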

Contributor Author

Done!

@adityagoel4512 adityagoel4512 merged commit 8a8b74f into Quantco:main Jan 16, 2025
16 checks passed