vmap can't be used with equinox module as op #15

Open
dimitriye98 opened this issue Dec 22, 2024 · 2 comments · May be fixed by #16

Comments

@dimitriye98

Using einx.vmap with an Equinox module as the op argument crashes if the module has any parameters, because JAX arrays are unhashable.

This is particularly problematic because vmapping over a module is common in idiomatic Equinox code; the built-in Embedding module more or less requires it.
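To make the failure mode concrete, here is a minimal stdlib-only sketch. It assumes the module hashes by its fields, the way a frozen dataclass does. `FakeArray`, `Embedding`, and `cached_lookup` are hypothetical stand-ins, not einx or Equinox code; `FakeArray` only mimics the fact that a JAX array cannot be hashed.

```python
from dataclasses import dataclass


class FakeArray:
    """Hypothetical stand-in for a JAX array: has a shape, is unhashable."""
    __hash__ = None

    def __init__(self, shape):
        self.shape = shape


@dataclass(frozen=True)
class Embedding:
    """Toy module whose generated __hash__ hashes its fields."""
    weight: FakeArray

    def __call__(self, index):
        return index


cache = {}


def cached_lookup(op):
    """Toy cache keyed directly on op, so op must be hashable."""
    if op not in cache:
        cache[op] = "compiled graph"
    return cache[op]


try:
    cached_lookup(Embedding(FakeArray((10, 4))))
except TypeError as error:
    print("cache lookup failed:", error)
```

A module without array fields hashes fine, which matches the report that only modules with parameters trigger the crash.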

@dimitriye98
Author

After some additional investigation, here is a potential path towards a solution. When tracing with the JAX backend, before hitting the cache with the input shapes and function, check whether the function is a bound method. If it is, check whether the object it is bound to is a pytree. If so, map over the pytree, replacing all tensors with their shapes, and use the result together with the unbound method as the cache key instead of the bound method. I'll try to put together a PR in the coming days if I have time.
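The key construction described in this comment could look roughly like the sketch below. It is a toy under stated assumptions: `FakeArray`, `Embedding`, and `shape_cache_key` are hypothetical names, dataclass fields stand in for pytree leaves, and nothing here is taken from the einx codebase or the linked PR.

```python
from dataclasses import dataclass, fields


class FakeArray:
    """Hypothetical stand-in for a JAX array: has a shape, is unhashable."""
    __hash__ = None

    def __init__(self, shape):
        self.shape = shape


@dataclass(frozen=True)
class Embedding:
    """Toy module; a real Equinox module would be a registered pytree."""
    weight: FakeArray

    def __call__(self, index):
        return index


def shape_cache_key(op):
    """Build a hashable key: unbound function plus per-field shapes."""
    bound_to = getattr(op, "__self__", None)
    if bound_to is None:
        return op  # not a bound method: use the function itself

    def describe(value):
        # Replace array-like leaves by their shape, keep other leaves as-is.
        return value.shape if hasattr(value, "shape") else value

    leaves = tuple(
        (field.name, describe(getattr(bound_to, field.name)))
        for field in fields(bound_to)
    )
    return (op.__func__, leaves)


key_a = shape_cache_key(Embedding(FakeArray((10, 4))).__call__)
key_b = shape_cache_key(Embedding(FakeArray((10, 4))).__call__)
print(key_a == key_b)
```

Two instances with equally shaped parameters produce the same key, and that key is hashable even though the arrays are not.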

dimitriye98 linked a pull request Dec 23, 2024 that will close this issue
@fferflo
Owner

fferflo commented Jan 13, 2025

Thanks for bringing this up! I don't think the proposed solution would work here though. The op argument isn't passed through to the compiled function each time einx.vmap is called since it is assumed to be static. Instead, a reference to the first op argument for a given signature is stored in the cache, and used for all subsequent calls with the same signature. This isn't a problem if op is identical in all calls, but that wouldn't be the case here.

For example, if different instances of the same Equinox class are passed as op, the instance that einx.vmap is first called with (including the JAX arrays stored in it) will be used in all subsequent calls, even if a later call passes a different module instance with different JAX arrays (but the same shapes/signature).
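The stale-op behaviour can be reproduced with a toy cache. This is only an illustration of the mechanism described here, not einx's actual caching code: `Scale`, `toy_vmap`, and the key layout are hypothetical, and plain Python lists stand in for arrays.

```python
class Scale:
    """Toy module: multiplies a scalar by each entry of its weight."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, value):
        return [w * value for w in self.weight]


_cache = {}


def toy_vmap(op, xs):
    """Map op over xs, caching a 'compiled' function per signature."""
    # The key only sees the op's type and shapes, not its values.
    key = (type(op), len(op.weight), len(xs))
    if key not in _cache:
        captured = op  # the first op for this signature is baked in
        _cache[key] = lambda inputs: [captured(x) for x in inputs]
    return _cache[key](xs)


first = toy_vmap(Scale([1.0, 1.0]), [1.0, 2.0])
second = toy_vmap(Scale([10.0, 10.0]), [1.0, 2.0])
print(first)
print(second)
```

The second call hits the cache entry created by the first one, so it silently computes with the first instance's weights.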

I think the solution would be to also make op a traced argument (or at least the arrays stored in it), so that it isn't stored in the compiled function and is instead treated like any other dynamic argument.
