Questions on eqx.filter_vmap: How to batch multiple functions #963
An Equinox module is a PyTree of arrays (and optionally other things); the arrays are what you can vmap over. If I understand you correctly, then you'd like your …
Hello Johanna, if I understand correctly, I may have to rewrite the whole data structure. One complication is that `Voxel` is not the top-level class: I also defined a 3D array-like data type whose elements are `Voxel`s (nested). That is the hard part I have to cope with. And what do you mean by changing the structure? Right now I save all the ODEs in a list and intend to vmap over each object's method (`spin.bloch`), changing the method in `Voxel` accordingly.
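One possible reading of "changing the structure", sketched here under assumptions: rather than a 3D container of `Voxel` objects that each hold a list of `Spin`s, keep every per-spin quantity in one array whose leading axes are `(nx, ny, nz, n_spins)`. The right-hand side `spin_rhs` (a simplified rotating-frame Bloch equation) and the parameters `omega`, `t1`, `t2` are hypothetical stand-ins, not the model from this thread. A single per-spin function is vmapped over the flattened container axes, and the result is reshaped back:

```python
import jax
import jax.numpy as jnp


# Hypothetical per-spin right-hand side: simplified rotating-frame Bloch
# equation with off-resonance omega and relaxation times t1, t2 (M0 = 1).
def spin_rhs(m, omega, t1, t2):
    mx, my, mz = m
    return jnp.stack([
        omega * my - mx / t2,
        -omega * mx - my / t2,
        (1.0 - mz) / t1,
    ])


def grid_rhs(m, omega, t1, t2):
    """m: (nx, ny, nz, n_spins, 3); omega, t1, t2: (nx, ny, nz, n_spins)."""
    n_lead = omega.ndim  # container axes: three grid axes plus the spin axis

    def flatten(x):
        # Merge all container axes into one batch axis, keep trailing axes.
        return x.reshape((-1,) + x.shape[n_lead:])

    dm = jax.vmap(spin_rhs)(flatten(m), flatten(omega), flatten(t1), flatten(t2))
    return dm.reshape(m.shape)  # restore (nx, ny, nz, n_spins, 3)


# Usage: a 2 x 3 x 4 grid of voxels with 5 spins each.
key_m, key_w = jax.random.split(jax.random.PRNGKey(0))
shape = (2, 3, 4, 5)
m = jax.random.normal(key_m, shape + (3,))
omega = jax.random.normal(key_w, shape)
t1 = jnp.full(shape, 1.0)
t2 = jnp.full(shape, 0.1)
dm = grid_rhs(m, omega, t1, t2)
```

Because every spin goes through the same function, there is exactly one callable to trace, however deeply the voxels are nested.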
If your … So you simulate something in a space that is partitioned into voxels, each of which contains several spins, and all of these spins evolve according to some ODE. Is that ODE the same for all of them? If not, can you easily express the differences with an extra argument (such as a binary mask)? If the ensemble modelling approach is impractical, an alternative you may consider is making …
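A minimal sketch of the mask idea, assuming (hypothetically) that the only difference between two kinds of spins is whether longitudinal (T1) recovery applies. The Bloch-type right-hand side and every name below are illustrative; the point is that one vector field with an extra boolean argument replaces several distinct ODEs:

```python
import jax
import jax.numpy as jnp


# Hypothetical: one vector field for every spin. A boolean `recovers` mask
# switches T1 recovery on or off, instead of giving each spin its own ODE.
def spin_rhs(m, omega, t1, t2, recovers):
    mx, my, mz = m
    dmz = jnp.where(recovers, (1.0 - mz) / t1, 0.0)  # zero where mask is False
    return jnp.stack([omega * my - mx / t2, -omega * mx - my / t2, dmz])


n = 4
m = jnp.zeros((n, 3))
omega = jnp.zeros(n)
t1 = jnp.ones(n)
t2 = jnp.ones(n)
recovers = jnp.array([True, False, True, False])

# One callable, vmapped over all per-spin arguments including the mask.
dm = jax.vmap(spin_rhs)(m, omega, t1, t2, recovers)
```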
PS: vmap parallelises over arguments to a callable, not over multiple callables and their arguments.
Hello Johanna,

```python
def spin_bloch(spin, y_spin):
    ...
return eqx.filter_vmap(spin_bloch)(self.spins, y)
```

Do you have a contact email? Perhaps we could collaborate and share coauthorship. Thanks for discussing these software-architecture questions with me, which may be beyond the scope of this issue.
Hi Tim, that sounds like a tricky problem; I'm guessing you have some boundary conditions too. Unfortunately I don't know of an example to point you to, and I'm afraid this thing is best designed at a whiteboard, to really map out what the individual parts need to do and how much they need to know about each other. That is well beyond the scope of a GitHub issue :)

Regarding some technical points:

```python
def solve(t, y, args):
    ...

eqx.filter_vmap(solve, in_axes=(None, 0, 0))
```

which will vmap both the state `y` and `args`. Alternatively, split the arguments:

```python
def solve(t, y, constant_args, varying_args):
    args = (constant_args, varying_args)
    ...

eqx.filter_vmap(solve, in_axes=(None, 0, None, 0))
```

or vmap over different axes of the constant and varying arguments. Additionally, you can nest vmap as you please: there is no need to have only one of these handling your entire problem. You can place them where appropriate to parallelise at different levels, and JAX will take care of resolving them.

With regards to your wrapper method:
This is not a good design. It makes each spin object carry a method that solves an ODE, and since each spin object is its own instance, each of these methods and calls to diffrax is its own callable. This (a) will wreck your compilation time and (b) is not what vmap is meant to do. In vmap, you take a single callable and vectorise it to accept many arguments and process them in a batched fashion. In your design, since every spin carries its own copy of the callable, no vectorisation takes place.

I hope that helps. I have too many projects myself at the moment and can't do more than reply to technical questions here.
Hello Johanna,
Hello, I have a question regarding `eqx.filter_vmap` (and perhaps also diffrax): how do I effectively wrap a list of `eqx.Module`s so that …
I have built a sophisticated ODE system; its structure is like this:

`Spin` is the basic class:

```python
import equinox as eqx


class Spin(eqx.Module):
    # ... lots of attributes

    def bloch(self, t, y, extra_args):
        """This is an ODE single step."""
        return ...
```
`Voxel` contains a list of `Spin`s:

```python
from typing import List


class Voxel(eqx.Module):
    # ... lots of attributes
    spin_list: List[Spin]
```
Now I have tried to wrap the method so that each element (`Spin`) in `spin_list` can be parallelized and batched. I tried to batch it this way in the `Voxel` class, but it does not work:

```python
@eqx.filter_jit
def bloch(self, t, y, **args):
    ...
```
So my question is: does Equinox have functionality to parallelize or batch over an `eqx.Module` object? If not, what should I do to parallelize it? I know that `jax.vmap` requires a function and `jnp` arrays to parallelize over, but after wrapping it as shown above, it does not seem to work.
(What I want is to parallelize over `Spin.bloch`. Should I change my data structure?)