Fix GPU sort for large arrays #1285

Merged
merged 3 commits into from
Jul 24, 2024

Conversation

jagrit06
Member

Proposed changes

Fixes #1254 - GPU sort should now support arbitrarily large arrays (as long as each dimension of their shape fits in an int)
Added relevant test
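
A check along these lines can be sketched as follows. This is not the test added in this PR (its body is not shown in the conversation); `check_large_sort` and `mlx_sort` are hypothetical helper names, and the default size assumes the old "2M" limit means 2**21 elements along the sort axis.

```python
import numpy as np


def check_large_sort(sort_fn, n=2**21 + 1, seed=0):
    """Sort n random float32 values with sort_fn and compare against np.sort.

    The default n assumes the old GPU limit was 2**21 elements.
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(size=n).astype(np.float32)
    out = np.asarray(sort_fn(x))
    # The result must keep the shape and match the NumPy reference exactly.
    return out.shape == x.shape and bool(np.array_equal(out, np.sort(x)))


def mlx_sort(x):
    """Hypothetical wrapper: sort with MLX on its default device."""
    import mlx.core as mx  # imported lazily so the helper above works without MLX

    return np.asarray(mx.sort(mx.array(x)))
```

On a machine with MLX installed, `check_large_sort(mlx_sort)` would exercise a sort axis just past the assumed old limit; passing `np.sort` instead gives a quick sanity check of the helper itself.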

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@jagrit06 jagrit06 requested review from awni and barronalex July 24, 2024 21:16
@@ -1840,6 +1840,15 @@ def test_sort(self):
        self.assertTrue(np.array_equal(c_np, c_mx))
        self.assertEqual(b_mx.dtype, c_mx.dtype)

        # Test very large array
        if mx.default_device() == mx.gpu:
Collaborator

Is it just too slow for CI on CPU?

Collaborator

@barronalex barronalex left a comment

Thanks for looking into this!

@@ -1785,15 +1785,6 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
    throw std::invalid_argument(msg.str());
  }

  // TODO: Fix GPU kernel
Collaborator

Is it worth a warning here for sorts greater than maxint?

Member Author

array::shape is made of ints, so each of those will be <= INT_MAX

Collaborator

Ah good point -- nice!
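
The reasoning in this exchange can be illustrated with a small sketch. It assumes a 32-bit C++ `int` (so `INT_MAX` is 2**31 - 1); `fits_int_shape` is a hypothetical name used only for illustration and is not part of MLX.

```python
INT_MAX = 2**31 - 1  # assumption: C++ int is 32 bits wide


def fits_int_shape(shape):
    """Return True if every dimension can be stored in a (32-bit) int."""
    return all(0 <= dim <= INT_MAX for dim in shape)


# A sort axis just past the old 2M limit is representable as an int ...
print(fits_int_shape((2**21 + 1,)))
# ... while a dimension above INT_MAX could never be stored in the shape at all.
print(fits_int_shape((INT_MAX + 1,)))
```

Because a dimension larger than `INT_MAX` cannot be represented in the shape in the first place, a runtime warning for that case would never fire.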

@jagrit06 jagrit06 merged commit 7f91436 into main Jul 24, 2024
3 checks passed
@jagrit06 jagrit06 deleted the sort_fix branch July 24, 2024 21:37
Member

awni commented Jul 24, 2024

Thanks for the quick fix 🚀


kemchenj commented Jul 26, 2024

I noticed that argsort still has the 2M limit. If I understand correctly, it should work for more than 2M elements now.

mlx/mlx/ops.cpp

Lines 1845 to 1847 in e9e5385

msg << "[argsort] GPU sort cannot handle sort axis of >= 2M elements,"
    << " got array with sort axis size " << a.shape(axis) << "."
    << " Please place this operation on the CPU instead.";

Member

awni commented Jul 26, 2024

That looks like an oversight. Will fix for the next release.

Member

awni commented Jul 26, 2024

(Fixed in #1289)
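
For anyone wanting to confirm the argsort behavior locally, a check could look roughly like the sketch below. The helper names `check_large_argsort` and `mlx_argsort` are hypothetical, and the default size again assumes that "2M" means 2**21 elements; nothing here is taken from the follow-up PR.

```python
import numpy as np


def check_large_argsort(argsort_fn, n=2**21 + 1, seed=0):
    """Check that argsort_fn returns a valid sorting permutation of n values."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(size=n).astype(np.float32)
    idx = np.asarray(argsort_fn(x)).astype(np.int64)
    # The indices must be a permutation of 0..n-1 ...
    if idx.shape != (n,) or not np.array_equal(np.sort(idx), np.arange(n)):
        return False
    # ... and gathering x with them must give a non-decreasing sequence.
    return bool(np.all(np.diff(x[idx]) >= 0))


def mlx_argsort(x):
    """Hypothetical wrapper: argsort with MLX on its default device."""
    import mlx.core as mx  # imported lazily so the helper above works without MLX

    return np.asarray(mx.argsort(mx.array(x)))
```

With an MLX build that includes the follow-up fix, `check_large_argsort(mlx_argsort)` should return `True` on the GPU instead of raising the error quoted above. Checking the permutation property, rather than comparing indices with `np.argsort`, avoids spurious failures when equal values are ordered differently.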

Development

Successfully merging this pull request may close these issues.

[Feature] Enable Metal argsort, sort for > 2M elements along an axis
4 participants