
Einsum op #240

Closed
wants to merge 10 commits

Conversation

dc-dc-dc
Contributor

Proposed changes

Adds einsum op

A step closer to adding einops support, as mentioned in #172.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@gboduljak
Contributor

I think this PR is quite cool. I have some performance concerns. Should we use something like opt_einsum to compute the optimal einsum contraction path? If we go this route, we may only need to implement tensordot and similar operators.
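
For reference, here is a rough sketch of using opt_einsum purely as a path planner, with numpy arrays standing in for MLX arrays (the contraction string and sizes are just examples):

```python
# Sketch: ask opt_einsum for a (near-)optimal pairwise contraction order.
# Assumes the opt_einsum package; numpy arrays are only stand-ins for shapes.
import numpy as np
import opt_einsum as oe

a = np.random.rand(8, 16, 32)   # ijk
b = np.random.rand(32, 64)      # km
c = np.random.rand(32)          # k

path, info = oe.contract_path("ijk,km,k->im", a, b, c, optimize="greedy")
print(path)  # list of operand-index pairs to contract, in order, e.g. [(1, 2), (0, 1)]
print(info)  # summary with per-step estimated cost and intermediate shapes
```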

@dc-dc-dc
Contributor Author

dc-dc-dc commented Dec 22, 2023

Not too familiar with this lib, but it looks like opt_einsum builds on top of einsum from various libraries. Is it intended to be used for an einsum implementation?

@gboduljak
Contributor

gboduljak commented Dec 24, 2023

> Not too familiar with this lib, but it looks like opt_einsum builds on top of einsum from various libraries. Is it intended to be used for an einsum implementation?

Optimized einsum is agnostic to the backend. This means we could just implement a wrapper that calls the MLX backend. It is used in JAX's einsum implementation (https://jax.readthedocs.io/en/latest/_modules/jax/_src/numpy/lax_numpy.html#einsum).
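
As a rough illustration of that backend-agnostic split (planning vs. execution), opt_einsum can build a contraction expression from shapes alone and evaluate it later against a chosen backend; numpy stands in here for what a hypothetical MLX backend would provide:

```python
# Sketch: plan once from shapes, execute against a pluggable backend.
# Assumes opt_einsum; "numpy" stands in for a hypothetical MLX backend.
import numpy as np
import opt_einsum as oe

expr = oe.contract_expression("ijk,km,k->im", (8, 16, 32), (32, 64), (32,))

a = np.random.rand(8, 16, 32)
b = np.random.rand(32, 64)
c = np.random.rand(32)
out = expr(a, b, c, backend="numpy")  # an MLX backend would plug in here
print(out.shape)  # (8, 64)
```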


@awni left a comment (Member)


This will be great to have! Thanks for adding it! I left a few comments as a start.

I think the main thing is if/when/how we dispatch to matmul. Otherwise einsum will be really slow in cases where it should not be.

  i += 1;
}

// Accumulator initialized to ones with the shape of the first input.
auto acc = ones_like(inputs_arr.at(0), s);
Member

Start with inputs_arr[0] rather than including a new array for accumulation. There should be at least one input, right?
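
For context, a minimal Python sketch (numpy as a stand-in, not the PR's C++ code) of the naive strategy being discussed, with the accumulator starting from the first operand instead of an array of ones:

```python
# Sketch of the naive einsum strategy for "ij,jk->ik" (numpy stand-in).
import numpy as np

a = np.random.rand(4, 5)   # ij
b = np.random.rand(5, 6)   # jk

acc = a[:, :, None]        # start from the first operand, aligned to (i, j, k)
acc = acc * b[None, :, :]  # broadcast-multiply the remaining operand in
out = acc.sum(axis=1)      # sum over the contracted axis j

assert np.allclose(out, np.einsum("ij,jk->ik", a, b))
```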

Member


Actually, it would be good to check for that at the top and throw if it isn't the case, and maybe add a test case that checks that we throw for that.

Contributor Author


Strangely, this does not work. If I switch it to use the first array and then accumulate, it gets some strange results on the test cases.

Member


Did you ever get to the bottom of this? Let me know, I can take a look as well.

Contributor Author


That would be helpful. I messed with it a little but didn't find anything concrete.
I tried calling eval on inputs_arr.at(0) to see if maybe certain ops didn't take effect, but it didn't resolve it.

mlx/ops.cpp Outdated
Comment on lines 2844 to 2886
// Broadcast-multiply every input into the accumulator, then sum over the
// contracted axes and transpose into the requested output order.
for (auto arr : inputs_arr) {
  acc = multiply(acc, arr, s);
}
return transpose(sum(acc, sum_axis, false, s), rhs_order, s);
Member


In some/many common cases we should dispatch to matmul rather than multiply-and-sum, as the latter will be a lot slower and memory-inefficient. Is it possible to include that logic (or some of it) now?
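
To make the concern concrete, here is a small numpy sketch (not MLX code) of why a matmul fast path matters for patterns like ij,jk->ik: the multiply-and-sum route materializes a full (i, j, k) intermediate, which matmul avoids:

```python
# Sketch: multiply-and-sum vs. dispatching to matmul for "ij,jk->ik".
import numpy as np

a = np.random.rand(128, 256)  # ij
b = np.random.rand(256, 64)   # jk

# Naive route: builds a (128, 256, 64) intermediate before reducing.
naive = (a[:, :, None] * b[None, :, :]).sum(axis=1)

# Fast path: recognize the pattern and call matmul directly.
fast = a @ b

assert np.allclose(naive, fast)
```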

@dc-dc-dc
Contributor Author

After messing with this a little more, there are still some edge cases that this doesn't cover:

  • fast path for matmul
  • ij,jk

Will close for now and re-open when it's in a better spot.

@dc-dc-dc closed this Dec 28, 2023
@angeloskath
Member

@dc-dc-dc FYI, I think it was a great PR! You could have changed it to a draft to keep working on it in the open and get feedback. @awni, let me know what you think, and sorry to both of you that I didn't get to reviewing this earlier.

I started a review yesterday and looked into what other frameworks implement, what opt_einsum expects, etc. My two cents regarding how I would approach it:

  • split it into pairwise contractions
  • sort the contractions based on some cost, then run them
  • almost all implementations use a greedy search, which yields near-optimal results and is much, much simpler to implement

To make myself clearer: for the contraction ijk,km,k->im, for instance, the ideal einsum would perform the following

op0 = op0.sum(1)
op1 = op1 * op2[:, None]  # this could also be op0 * op2[None] depending on the sizes of op0, op1
return op0 @ op1

I think summing over axes that only appear in one input (and not the output) is straightforward. Subsequently, we have the following candidate contractions: ik,km->ikm, ik,k->ik, km,k->km. Each contraction should keep all axes that appear in the result or in the other arguments. From these it is obvious that the last two are the fastest. We could use a naive FLOPS estimator to do the sorting (namely, the product of the sizes of all axes in the contraction).
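
To make the cost-based ordering concrete, here is a rough Python sketch (made-up dimension sizes, not MLX code) of a greedy planner that scores each candidate pair by the product of its axis sizes and contracts the cheapest pair first:

```python
# Sketch: greedy pairwise contraction ordering with a naive FLOPS estimate.
from itertools import combinations

dim = {"i": 64, "j": 32, "k": 128, "m": 16}   # made-up sizes
operands = ["ik", "km", "k"]                  # after summing j out of "ijk"
output = "im"

def contract_output(a, b, others, output):
    # Keep every axis that still appears in the output or in another operand.
    keep = set(output) | set("".join(others))
    return "".join(ax for ax in dict.fromkeys(a + b) if ax in keep)

def cost(a, b):
    # Naive FLOPS estimate: product of the sizes of all axes in the contraction.
    c = 1
    for ax in set(a + b):
        c *= dim[ax]
    return c

while len(operands) > 1:
    i, j = min(
        combinations(range(len(operands)), 2),
        key=lambda pair: cost(operands[pair[0]], operands[pair[1]]),
    )
    a, b = operands[i], operands[j]
    rest = [op for idx, op in enumerate(operands) if idx not in (i, j)]
    new = contract_output(a, b, rest, output)
    print(f"contract {a},{b}->{new}")
    operands = rest + [new]
# prints: contract km,k->km, then contract ik,km->im
```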

As an aside, if you want to tackle something simpler to start with (again, I think your PR was great!), you could implement tensordot. Then you would already have a baseline for einsum from opt_einsum, since we could very easily implement a backend for it to test against our own einsum.
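
For what it's worth, tensordot itself is commonly built from transpose, reshape, and matmul; a minimal numpy sketch (not MLX code) for a single contracted axis:

```python
# Sketch: tensordot over one pair of axes via transpose/reshape/matmul.
import numpy as np

def tensordot_1axis(a, b, a_axis, b_axis):
    a_moved = np.moveaxis(a, a_axis, -1)   # contracted axis of `a` to the end
    b_moved = np.moveaxis(b, b_axis, 0)    # contracted axis of `b` to the front
    out_shape = a_moved.shape[:-1] + b_moved.shape[1:]
    a2d = a_moved.reshape(-1, a_moved.shape[-1])
    b2d = b_moved.reshape(b_moved.shape[0], -1)
    return (a2d @ b2d).reshape(out_shape)

a = np.random.rand(3, 4, 5)
b = np.random.rand(5, 6)
assert np.allclose(tensordot_1axis(a, b, 2, 0), np.tensordot(a, b, axes=([2], [0])))
```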

@dc-dc-dc
Contributor Author

I would've preferred to switch it to a draft, but I didn't see an option to do so, so I closed it for now. As I was messing with it more, I noticed some more missed cases that need to be covered. I might take a quick stab at dot / tensordot and come back to this. In the meantime, if someone wants to take this as a base and continue forward, you have my approval 😄
