-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix GPU sort for large arrays #1285
Conversation
@@ -1840,6 +1840,15 @@ def test_sort(self): | |||
self.assertTrue(np.array_equal(c_np, c_mx)) | |||
self.assertEqual(b_mx.dtype, c_mx.dtype) | |||
|
|||
# Test very large array | |||
if mx.default_device() == mx.gpu: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it just too slow for CI on CPU?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking into this!
@@ -1785,15 +1785,6 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) { | |||
throw std::invalid_argument(msg.str()); | |||
} | |||
|
|||
// TODO: Fix GPU kernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is worth a warning here for sorts greater than maxint?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
array::shape
is made of ints, so each of those will be <= INT_MAX
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah good point -- nice!
Thanks for the quick fix 🚀 |
I noticed that the Lines 1845 to 1847 in e9e5385
|
That looks like an oversight. Will fix for the next release. |
(Fixed in #1289 ) |
Proposed changes
Fixes #1254 - GPU sort should now support arbitrarily large arrays (as long as their shape is contained in
int
)Added relevant test
Checklist
Put an
x
in the boxes that apply.pre-commit run --all-files
to format my code / installed pre-commit prior to committing changes