vmap can't be used with equinox module as op #15
After some additional investigation, here is a potential path towards a solution. When tracing with the JAX backend, before looking up the cache with the input shapes and the function, check whether the function is a bound method. If it is, check whether the object it is bound to is a pytree. If so, map over the pytree, replacing all tensors with their shapes, and use the result together with the unbound method as the cache key instead of the bound method. I'll try to put together a PR in the coming days if I have time.
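A minimal sketch of the cache key described above, assuming the cache accepts any hashable tuple. The helper names (op_cache_key, _tree_key, _abstract_leaf) and the Scale class are hypothetical and not einx's actual API. An object that is not a registered pytree is simply treated as a single leaf and hashed as-is.

```python
import inspect

import jax
import jax.numpy as jnp


def _abstract_leaf(leaf):
    # Replace JAX arrays by (shape, dtype) so the key stays hashable.
    if isinstance(leaf, jax.Array):
        return ("array", leaf.shape, str(leaf.dtype))
    return leaf


def _tree_key(tree):
    # Hashable summary of a pytree: its structure plus abstracted leaves.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return treedef, tuple(_abstract_leaf(leaf) for leaf in leaves)


def op_cache_key(op, *args):
    """Hypothetical cache key for op(*args); not einx's actual API."""
    if inspect.ismethod(op):
        # Bound method: use the unbound function plus an abstracted view
        # of the object it is bound to, instead of the bound method itself.
        return (op.__func__, _tree_key(op.__self__), _tree_key(args))
    return (op, _tree_key(args))


@jax.tree_util.register_pytree_node_class
class Scale:
    """Hypothetical pytree class holding one parameter array."""

    def __init__(self, w):
        self.w = w

    def apply(self, x):
        return self.w * x

    def tree_flatten(self):
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


x = jnp.zeros(3)
key_a = op_cache_key(Scale(jnp.ones(3)).apply, x)
key_b = op_cache_key(Scale(jnp.full(3, 2.0)).apply, x)
```

Here key_a and key_b compare equal: two instances with different weights but equal shapes map to the same key, so whatever is cached under that key would need to receive the arrays as inputs rather than capture them.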
Thanks for bringing this up! I don't think the proposed solution would work here though. The For example, if there are different instances of the same equinox class that are passed to I think the solution would be to also make
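A hypothetical illustration of how different instances of the same class can collide under a key built only from the class and shapes: if the cached function closes over the first instance's arrays, a second instance with equal shapes silently reuses the first one's parameters. Passing the arrays as explicit inputs to the cached function avoids this. Scale, apply_naive and apply_explicit are made-up names for this sketch, not einx's cache.

```python
import jax
import jax.numpy as jnp


class Scale:
    """Hypothetical module-like object holding one parameter array."""

    def __init__(self, w):
        self.w = w


_naive_cache = {}
_explicit_cache = {}


def apply_naive(module, x):
    # Key ignores parameter values: only the class and the shapes.
    key = (type(module), module.w.shape, x.shape)
    if key not in _naive_cache:
        # The cached function captures *this* module's weights as constants.
        _naive_cache[key] = jax.jit(lambda y: module.w * y)
    return _naive_cache[key](x)


def apply_explicit(module, x):
    # Same key, but the weights are passed in as an argument on every call.
    key = (type(module), module.w.shape, x.shape)
    if key not in _explicit_cache:
        _explicit_cache[key] = jax.jit(lambda w, y: w * y)
    return _explicit_cache[key](module.w, x)


a = Scale(jnp.ones(3))
b = Scale(jnp.full(3, 2.0))
x = jnp.ones(3)

apply_naive(a, x)
stale = apply_naive(b, x)     # reuses a's weights
fresh = apply_explicit(b, x)  # uses b's weights
```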
Attempting to use einx.vmap with an equinox module as the op argument will crash if the module has any parameters (as JAX arrays are unhashable). This is particularly problematic, as vmapping over a module is quite common in idiomatic equinox code, with things like the built-in Embedding module more or less requiring it.
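A small sketch of both halves of the problem, using a hypothetical frozen-dataclass stand-in (ToyEmbedding) rather than an actual equinox class. A per-index lookup is batched with plain jax.vmap, but any cache that needs hash(op) fails, because the dataclass-generated __hash__ hashes the fields and a JAX array raises TypeError when hashed.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyEmbedding:
    """Hypothetical stand-in for a module that holds a parameter array."""

    weight: jax.Array  # shape (num_embeddings, embedding_size)

    def __call__(self, index):
        # Looks up a single row; batching over indices is left to vmap.
        return self.weight[index]


emb = ToyEmbedding(jnp.arange(12.0).reshape(4, 3))
ids = jnp.array([0, 2, 3])

# Batch the single-index lookup over all ids.
rows = jax.vmap(lambda i: emb(i))(ids)  # shape (3, 3)

# A cache keyed on the op object needs hash(emb). The dataclass-generated
# __hash__ hashes the fields, and JAX arrays raise TypeError when hashed.
try:
    hash(emb)
    hash_error = None
except TypeError as err:
    hash_error = err
```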