
Questions on eqx.filter_vmap: how to batch multiple functions #963

timnotavailable opened this issue Feb 27, 2025 · 7 comments

@timnotavailable

Hello, I have a question regarding eqx.filter_vmap (perhaps also diffrax): how do I effectively wrap a list of eqx.Module so that each element can be batched over?

I have built a sophisticated ODE system; the structure is like this:

The class Spin is the basic class:

class Spin(eqx.Module):
    # ... lots of attributes

    def bloch(self, t, y, extra_args):
        """This is a single ODE step."""
        return ...

The class Voxel contains a list of Spin:

class Voxel(eqx.Module):
    # ... lots of attributes
    spin_list: List[Spin]

  1. Now I tried to wrap the function so that each element (Spin) in spin_list can be parallelized and batched. I tried to batch this way in the Voxel class, but it does not work:
    @eqx.filter_jit
    def bloch(self, t, y, **args):

        def spin_bloch(spin, y_spin):
            return spin.bloch(t, y_spin, **args)

        return eqx.filter_vmap(spin_bloch)(self.spin_list, y)

So my question is: does Equinox have functionality that can parallelize or batch over an eqx object? If not, what should I do to parallelize it? I know jax.vmap requires a jnp.array and a function to parallelize over, but after wrapping as shown above it does not seem to work.
(What I want to do is parallelize over Spin.bloch; should I change my data structure?)

  2. I plan to use a diffrax solver to simulate the ODEs. Do I need to initialize a solver in each Spin class, or simply initialize one solver and batch over it?
@johannahaffner

An Equinox module is a pytree of arrays (and optionally other things); the arrays are what you can vmap over. eqx.filter_vmap is a relatively thin wrapper over jax.vmap, so pretty much the same constraints apply.

If I understand you correctly, then you'd like your bloch method to be the thing that gets vmapped. In this case you should change your data structure, similar to how it is done in the ensembling example, and only call the model in a vmapped region. The method performing your ODE integration is a callable that will be jitted, and if you have many of these the vmap principle is not met (which is that you have one function and apply it many times).
Would that work for your use case? Avoiding many instances of potentially costly functions would definitely be good practice.
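
A minimal sketch of that restructuring, following the pattern of the Equinox ensembling example (the relaxation field and the toy dynamics below are made-up placeholders, not your actual model):

import equinox as eqx
import jax
import jax.numpy as jnp


class Spin(eqx.Module):
    # Hypothetical per-spin parameter; replace with your real attributes.
    relaxation: jax.Array

    def bloch(self, t, y, args):
        # Toy dynamics standing in for the real Bloch equations.
        return -self.relaxation * y


@eqx.filter_vmap
def make_spins(relaxation):
    # Constructing the spins inside filter_vmap stacks their arrays along a
    # leading axis, giving one batched Spin pytree instead of a Python list.
    return Spin(relaxation)


@eqx.filter_jit
def voxel_bloch(spins, t, y, args):
    # One callable, vmapped over the stacked spins and the per-spin states.
    def spin_bloch(spin, y_spin):
        return spin.bloch(t, y_spin, args)

    return eqx.filter_vmap(spin_bloch)(spins, y)


spins = make_spins(jnp.linspace(0.1, 1.0, 8))  # 8 spins, stacked into one pytree
y0 = jnp.ones((8, 3))                          # one state vector per spin
dy = voxel_bloch(spins, 0.0, y0, None)         # shape (8, 3)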

@patrick-kidger added the question label on Feb 27, 2025
@timnotavailable

timnotavailable commented Feb 27, 2025

Hello Johanna,
thanks for the quick reply. I've checked the example you provided. Yes, your understanding is correct: I want to vmap the bloch method of each Spin class and define this in a Voxel method.
In essence this defines an ODE system; I looked for examples in the diffrax docs and discussions but did not find a suitable one. Some of these ODEs (Spin instances) share common parameters stored in the Voxel. Another thorny point is that these ODEs have different extra_args, which makes changing the data structure a bit harder.

If I understand correctly, I may have to rewrite the whole data structure. One complication is that Voxel is not the top-level class: I also defined a 3D array-like datatype whose elements are Voxels (nested). That's the hard part I have to cope with. And what do you mean by changing the structure? Currently I store all the ODEs in a list and intend to vmap over each object's method (spin.bloch).
I might have to change it according to your example. In the Spin class:

@eqx.filter_vmap(in_axes=0)
def bloch(...):
    ...
    return ...

and change the method in Voxel accordingly.

@johannahaffner

If your bloch method is the same for each spin, then you really want to avoid creating multiple instances of it: they will all get traced separately, which will be really bad for your compilation times.

So you simulate something in some space that is partitioned into voxels, each of which contains several spins, and all of these spins evolve according to some ODE. Is that ODE the same for all? If it's not, can you easily express the differences with an extra argument (such as a binary mask)?

If the ensemble modelling approach is impractical, an alternative you may consider is making bloch a function that takes a spin object instead of making it a method, and calling it from elsewhere as appropriate (maybe from Voxel?).
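
For example, something along these lines (with a made-up relaxation attribute standing in for whatever a Spin actually holds):

import equinox as eqx


def bloch(spin, t, y, args):
    # A free function rather than a method: one callable, traced once,
    # applied to whichever Spin pytree it is handed.
    return -spin.relaxation * y


def voxel_rhs(stacked_spins, t, y, args):
    # Called e.g. from Voxel, vmapping over the stacked spins and their states.
    return eqx.filter_vmap(bloch, in_axes=(0, None, 0, None))(stacked_spins, t, y, args)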

@johannahaffner

"Currently I store all the ODEs in a list and intend to vmap over each object's method (spin.bloch)"

PS: vmap parallelises over arguments to a callable, not over multiple callables and their arguments.
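
A small illustration of the distinction, with a trivial stand-in function:

import jax
import jax.numpy as jnp


def f(x):
    return x ** 2


ys = jax.vmap(f)(jnp.arange(4.0))  # one callable, a batch of arguments
# There is no vmap over a Python list of distinct callables such as
# [spin.bloch for spin in spin_list]; that list has to be restructured first.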

@timnotavailable

Hello Johanna,
The ODE is the same for all; however, the dynamics (meaning the extra argument) of the ODEs differ. The tricky part is that not only does one Voxel have many attributes, but also the Spins in the same voxel share some attributes, which makes them hard to batch (vmap) with a single array. (The reason I let them share attributes is that I want those attributes of the Spins in a voxel to have the same gradients.)
From this discussion I think the first priority for solving the problem is to use one ODE system function and vmap over that function. Another possibility is, as you said, making bloch a function that takes the spin object. This is what I tried to do in the first post
(in the Voxel class I defined this wrapper method):
@eqx.filter_jit
def bloch(self, t, y, **args):

    def spin_bloch(spin, y_spin):
        return spin.bloch(t, y_spin, **args)

    return eqx.filter_vmap(spin_bloch)(self.spin_list, y)

Do you have some contact/email? Perhaps we could cooperate and share co-authorship... Thanks for discussing this software architecture with me; it may be beyond the scope of the issue.

@johannahaffner

johannahaffner commented Feb 28, 2025

Hi Tim,

that sounds like a tricky problem; I'm guessing you have some boundary conditions too. Unfortunately I don't know of an example to point you to, and I'm afraid this is best designed at a whiteboard, to really map out what the individual parts need to do and how much they need to know about each other. That is well beyond the scope of a GitHub issue :)

Regarding some technical points:
You can vmap over any of the arguments (provided they're the right type); vmap is not restricted to the first argument at all. So you can absolutely do

def solve(t, y, args):
    ...

eqx.filter_vmap(solve, in_axes=(None, 0, 0))

which will vmap both the state y and the arguments over their first axis (you can also specify another axis). If you need to aggregate multiple arguments that may vary from spin to spin, you can also do something like

def solve(t, y, constant_args, varying_args):
    args = (constant_args, varying_args)
    ...

eqx.filter_vmap(solve, in_axes=(None, 0, None, 0))

or vmap over different axes of the constant and varying arguments.

Additionally, you can nest vmap as you please: there is no need to have only one of these handling your entire problem. You can place them where appropriate to parallelise at different levels, and JAX will take care of resolving them.
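
A small sketch of what nesting could look like here (the shapes and the toy rhs are made up; the idea is an inner vmap over the spins within a voxel and an outer one over the voxels):

import equinox as eqx
import jax.numpy as jnp


def rhs(t, y, args):
    # Toy right-hand side for a single spin state.
    return -args * y


per_voxel = eqx.filter_vmap(rhs, in_axes=(None, 0, 0))         # over spins in a voxel
per_volume = eqx.filter_vmap(per_voxel, in_axes=(None, 0, 0))  # over voxels

y = jnp.ones((4, 8, 3))        # 4 voxels, 8 spins each, state dimension 3
args = jnp.ones((4, 8))        # one varying argument per spin
dy = per_volume(0.0, y, args)  # shape (4, 8, 3)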

With regards to your wrapper method:

def spin_bloch(spin, y_spin):
    return spin.bloch(t, y_spin, **args)

This is not a good design. It makes each spin object carry a method that solves an ODE, and since each spin object is its own instance, each of these methods (and its calls to diffrax) is its own callable. This will a) wreck your compilation time and b) is not what vmap is meant to do. In vmap, you take a single callable and vectorise it so that it takes many arguments and processes them in a batched fashion. In your design, since every spin carries its own copy of the callable, no vectorisation takes place.
Does that make sense?
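
On the second question from your original post: one solver is enough; the whole diffrax solve can be vmapped over a batch of initial states and per-spin arguments. A rough sketch, with a toy vector field standing in for the real Bloch term:

import diffrax
import equinox as eqx
import jax.numpy as jnp


def vector_field(t, y, args):
    # Stand-in for the Bloch equations; `args` carries the per-spin parameters.
    return -args * y


def solve_one(y0, args):
    term = diffrax.ODETerm(vector_field)
    solver = diffrax.Tsit5()  # one solver definition, reused for every spin
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.01, y0=y0, args=args)
    return sol.ys


batched_solve = eqx.filter_vmap(solve_one)  # batch the whole solve over the leading axis

y0 = jnp.ones((8, 3))             # 8 spins, state dimension 3
args = jnp.linspace(0.1, 1.0, 8)  # one varying parameter per spin
ys = batched_solve(y0, args)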

I hope that helps - I have too many projects myself at the moment, and can't do more than replying to technical questions here.

@timnotavailable

Hello Johanna,
the answer is great and the information is enough for me to design this properly; thanks very much for your detailed instructions and help! I will at least mention this in the acknowledgements.
