filter_vmap in_axes tree structure #952

Open
vadmbertr opened this issue Feb 19, 2025 · 2 comments
Labels
question User queries

Comments

@vadmbertr

Hi!

I have a question about the tree structure of the in_axes argument of filter_vmap.
Consider the following example:

```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array


class Timeseries(eqx.Module):
    value: Array
    time: Array
    name: str


class TimeseriesEnsemble(eqx.Module):
    members: Timeseries


ts_ens = TimeseriesEnsemble(
    members=Timeseries(
        value=jnp.ones((10, 20)),  # 10 members, 20 time steps
        time=jnp.arange(20),
        name="some timeseries"
    )
)

# Copy the structure of the members with every leaf set to None (= do not vmap),
# then map only `value` along its first axis.
in_axes = eqx.filter(ts_ens.members, False)
in_axes = eqx.tree_at(lambda x: x.value, in_axes, 0, is_leaf=lambda x: x is None)
res = eqx.filter_vmap(lambda x: jnp.mean(x.value), in_axes=(in_axes,))(ts_ens.members)
```

where I want to vmap over the members of my ensemble (i.e. over the first dimension of value).
It puzzles me a bit that I have to use in_axes=(in_axes,) rather than in_axes=in_axes, since the vmapped function takes only a single argument.

I was wondering whether this is intended or some sort of bug. (The docs say "Its tree structure should either be: a. a prefix of the input tuple of args. ...", so it might be intended, but I'm not quite sure.)

Thanks a lot for this great library!
Vadim

@johannahaffner

Hi Vadim,

eqx.filter_vmap provides some extra functionality while wrapping jax.vmap (here). And jax.vmap will expect int | None | Sequence[Any]. Your time series has to be the latter to pass muster.

I believe that the CustomNode mismatch error you are seeing is "created" here (equinox/equinox/_vmap_pmap.py, line 171 in 1b507b9):

in_axes=(in_axes,),

and I don't know if this is required for other cases to work properly. In your case, it masks the underlying error, which is that jax.vmap would not accept Timeseries as an axis anyway. (If you try your code above with jax.vmap instead, you will get a TypeError stating as much.)

@vadmbertr (author)

Hi Johanna,

Thanks for the reply!

> eqx.filter_vmap provides some extra functionality while wrapping jax.vmap (here). And jax.vmap will expect int | None | Sequence[Any]. Your time series has to be the latter to pass muster.

Indeed, I realize that it falls under the Sequence[Any] case, so having to pass (in_axes,) makes sense.

> I believe that the CustomNode mismatch error you are seeing is "created" here:
>
> equinox/equinox/_vmap_pmap.py, line 171 in 1b507b9:
> in_axes=(in_axes,),
>
> and I don't know if this is required for other cases to work properly. In your case, it masks the underlying error, which is that jax.vmap would not accept Timeseries as an axis anyway. (If you try your code above with jax.vmap instead, you will get a TypeError stating as much.)

I see, thank you for the clarification. The CustomNode error led me in the wrong direction.

@patrick-kidger added the question (User queries) label Feb 20, 2025

3 participants