Einsum op #240
Conversation
I think this PR is quite cool. I have some performance concerns, though. Should we use something like opt_einsum for computing the optimal einsum contraction path? If we go this route, we may only need to implement tensordot and similar operators.
Not too familiar with this lib, but it looks like opt_einsum builds on top of einsum from various libraries. Is it intended to be used for an einsum implementation?
Optimized einsum is agnostic to the backend. This means we could just implement a wrapper that calls the MLX backend. It is used in the JAX einsum implementation (https://jax.readthedocs.io/en/latest/_modules/jax/_src/numpy/lax_numpy.html#einsum).
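For context, a minimal sketch (Python, with NumPy arrays purely as placeholder operands) of how opt_einsum exposes the contraction plan separately from the computation, which is what would let an MLX-backed einsum reuse it:

```python
import numpy as np
import opt_einsum

# Placeholder operands; only their shapes matter for path planning.
a = np.random.rand(8, 16)
b = np.random.rand(16, 32)
c = np.random.rand(32, 4)

# contract_path does no arithmetic: it returns the pairwise contraction
# order, so a backend only needs primitives like tensordot/matmul to follow it.
path, info = opt_einsum.contract_path("ij,jk,kl->il", a, b, c)
print(path)   # e.g. [(0, 1), (0, 1)] -- contract the cheapest pair first
print(info)   # human-readable cost summary for the chosen path
```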
This will be great to have! Thanks for adding it! I left a few comments as a start.
I think the main thing is if/when/how we dispatch to matmul. Otherwise einsum will be really slow in cases where it should not be.
i += 1;
}

auto acc = ones_like(inputs_arr.at(0), s);
Start with inputs_arr[0] rather than including a new array for the accumulation. There should be at least one input, right?
Actually, it would be good to check for that at the top and throw if it's not the case, and maybe add a test case that checks that we throw for that.
Strangely, this does not work. If I switch it to use the first array and then accumulate, it gets some strange results on the test cases.
Did you ever get to the bottom of this? Let me know, I can take a look as well.
That would be helpful. I messed with it a little but didn't find anything concrete. I tried calling eval on inputs_arr.at(0) to see if maybe certain ops didn't take effect, but it didn't resolve it.
mlx/ops.cpp
for (auto arr : inputs_arr) {
  acc = multiply(acc, arr, s);
}
return transpose(sum(acc, sum_axis, false, s), rhs_order, s);
In some/many common cases we should dispatch to matmul rather than multiply and sum, as the latter will be a lot slower and memory inefficient. Is it possible to include that logic (or some of it) now?
After messing with this a little more, there are still some edge cases that this doesn't cover. Will close for now and re-open when it's in a better spot.
@dc-dc-dc FYI I think it was a great PR! You could have changed it to draft to keep working on it in the open and get feedback. @awni let me know what you think and sorry to both that I didn't get to reviewing this earlier. I started a review yesterday and looked into what other frameworks implement, what opt_einsum expects etc. My two cents regarding how I would approach it are
Just to make myself clearer, for the following contraction for instance
I think summing over axes that only appear on one input (and not the output) is straightforward. Subsequently, we have the following contractions. As an aside, if you want to tackle something simpler to start with (again, I think your PR was great!), you could implement
I would've preferred to switch to draft but didn't see an option to do so, so I closed it for now. But as I was messing with it more I noticed some more missed cases that need to be covered. Might take a quick stab at dot / tensordot and come back to this. But in the meantime, if someone wants to take this as a base and continue forward, you have my approval 😄
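As a sketch of that simpler starting point, a generic tensordot can be lowered to a single matmul by permuting the contracted axes together and flattening each operand to 2-D. This is illustrative Python against the MLX Python API, not a proposal for the final implementation:

```python
import mlx.core as mx

def tensordot_sketch(a, b, axes_a, axes_b):
    # Move the contracted axes of `a` to the end and those of `b` to the
    # front, flatten each side to 2-D, contract with one matmul, and restore
    # the free axes in the output shape.
    free_a = [i for i in range(a.ndim) if i not in axes_a]
    free_b = [i for i in range(b.ndim) if i not in axes_b]
    at = mx.transpose(a, free_a + list(axes_a))
    bt = mx.transpose(b, list(axes_b) + free_b)
    k = 1
    for ax in axes_a:
        k *= a.shape[ax]
    out = mx.matmul(mx.reshape(at, (at.size // k, k)),
                    mx.reshape(bt, (k, bt.size // k)))
    out_shape = [a.shape[i] for i in free_a] + [b.shape[i] for i in free_b]
    return mx.reshape(out, out_shape)

# Example: contract axis 2 of a with axis 0 of b.
a = mx.random.normal((2, 3, 4))
b = mx.random.normal((4, 5))
print(tensordot_sketch(a, b, (2,), (0,)).shape)  # (2, 3, 5)
```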
Proposed changes
Adds einsum op
A step closer to adding einops support, as mentioned in #172.
Checklist
Put an x in the boxes that apply.
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes