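`eqx.filter_vmap` adds some extra functionality on top of `jax.vmap` (here), and `jax.vmap` expects `in_axes` to be `int | None | Sequence[Any]`. Your `Timeseries` axis spec is none of these, so it has to be wrapped in the latter, i.e. a tuple with one entry per positional argument, to pass muster.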
I believe that the `CustomNode` mismatch error you are seeing is "created" here:
and I don't know whether this is required for other cases to work properly. In your case, it masks the underlying error, which is that `jax.vmap` would not accept a `Timeseries` as `in_axes` anyway. (If you try your code above with `jax.vmap` instead, you will get a `TypeError` saying as much.)
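To make this concrete, here is a minimal sketch using plain `jax.vmap`. The original example code was not preserved, so `Timeseries` below is a hypothetical stand-in. It is registered with `jax.tree_util.register_pytree_node_class` to play the role of the `eqx.Module` from the issue, and its fields `value` and `time` are assumptions. Passing the bare instance as `in_axes` should be rejected with a `TypeError`. Wrapping it in a one-element tuple should work, because the tuple matches the single positional argument.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the issue's Timeseries (an eqx.Module there):
# a pytree with a per-member `value` field and a shared `time` field.
@jax.tree_util.register_pytree_node_class
@dataclass
class Timeseries:
    value: object
    time: object

    def tree_flatten(self):
        # Both fields are pytree children; there is no static aux data.
        return (self.value, self.time), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def member_mean(ts):
    # A single positional argument, as in the issue.
    # `time` is unbatched, so ts.time is the full (10,) array here.
    return jnp.mean(ts.value) + ts.time[0]


ensemble = Timeseries(value=jnp.ones((4, 10)), time=jnp.arange(10.0))
axes = Timeseries(value=0, time=None)  # batch over ensemble members only

# Bare pytree as in_axes: jax.vmap only accepts int, None, or a tuple here.
try:
    jax.vmap(member_mean, in_axes=axes)(ensemble)
    bare_error = None
except TypeError as err:
    bare_error = err

# One-element tuple: one axis spec per positional argument.
out = jax.vmap(member_mean, in_axes=(axes,))(ensemble)

print(type(bare_error).__name__)  # TypeError
print(out.shape)  # (4,)
```

In the wrapped call, `value` is mapped along axis 0 (4 members), while `time` is passed through unchanged to every member.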
Indeed, I realize that it falls under the `Sequence[Any]` type condition, so it makes sense to have to pass `(in_axes,)`.
I see, thank you for the clarification. The `CustomNode` error sent me in the wrong direction.
Hi!

I have a question about the tree structure of the `in_axes` argument of `filter_vmap`. Consider the following example:

where I want to `vmap` over the members of my ensemble (so the first dimension of `value`). It puzzles me a bit that I have to use `in_axes=(in_axes,)` rather than `in_axes=in_axes`, as the `vmap`ed function only takes a single argument. I was wondering whether this is intended or some sort of bug. (I read in the docs "Its tree structure should either be: a. a prefix of the input tuple of args. ...", so it might be intended, but I'm not quite sure.)

Thanks a lot for this great library!
Vadim