Question: States, vmap-construction-of-layers and lax.scan'ing. #934
Comments
Hi Jed, welcome to Equinox :) Your code runs if the last lines are something like this:

```python
m, state = eqx.nn.make_with_state(OddModel)(key)
y, state = jax.vmap(m)(x, state)
print(y, y.shape)
```
Hi Johanna! Thank you for the reply! I might be missing something here... If I change the batch size, e.g. … It seems like the state (returned by … This means when I … Cheers!
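For illustration, the batch-size coincidence can be reproduced in plain JAX. `jax.vmap` with the default `in_axes=0` maps every argument along its leading axis, so a per-layer state stacked to shape (3, 4) is silently treated as a batch of 3 and only lines up when the data batch also has size 3. The function `step` below is a hypothetical stand-in for the model, not Equinox code.

```python
import jax
import jax.numpy as jnp


def step(x, state):
    # Hypothetical stand-in for one layer: add the stored state to the input
    # and return the input as the "new state".
    return x + state, x


# A state stacked over 3 layers, as produced by vmap-constructing 3 layers.
stacked_state = jnp.zeros((3, 4))

# Batch of 3: vmap maps x and stacked_state together along axis 0, so this runs,
# but example i is paired with layer i's state rather than with the whole stack.
y, new_state = jax.vmap(step)(jnp.ones((3, 4)), stacked_state)
print(y.shape, new_state.shape)  # (3, 4) (3, 4)

# Batch of 7: the mapped axis sizes disagree (7 vs 3), so vmap raises ValueError.
try:
    jax.vmap(step)(jnp.ones((7, 4)), stacked_state)
except ValueError as err:
    print("ValueError:", str(err).splitlines()[0])
```

So the call above only appears to work because the batch size happens to equal the number of layers.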
A few points, based on my understanding of states. First, the substate doesn't do anything here (you get the substate based on the layers, which already contain the whole state). Second, the key issue is that you are scanning with the state, but not actually over it: you have a (3, 4) state and pass the whole thing in at each step, when you really want to pass in a (4,) state each time. The docs example vmaps over the states, whereas here you are scanning over them, so it's slightly different. I wrote something that approaches it from that direction and seems to work (I dislike the manual state interaction, it seems brutally inelegant, so I will update it if I think of something more clever):

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx


class WeirdLayer(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    index: eqx.nn.StateIndex

    def __init__(self, key):
        self.linear1 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.linear2 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.index = eqx.nn.StateIndex(jnp.zeros((4,)))

    def __call__(self, x, state):
        y = state.get(self.index)
        print("y", y.shape)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = x + y
        print("x", x.shape)  # This becomes 3, 4 when this layer is used in OddModel, and therefore breaks!
        new_state = state.set(self.index, x)
        x = self.linear2(x)
        return x, new_state


class Model(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """
        This model does the normal building of layers with a loop:
        a list of layers and a for loop in __call__.
        """
        self.layers = [WeirdLayer(key) for _ in range(3)]

    def __call__(self, x, state):
        for l in self.layers:
            x, state = l(x, state)
        return x, state


key = jax.random.key(42)

# Use the normal model, see what happens
x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(Model)(key)
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()
y, state = jax.vmap(m)(x, states)
print("Out:", y.shape)  # All is well


class OddModel(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """
        This model uses vmap'd construction of layers
        and a lax.scan over them in __call__.
        """
        keys = jr.split(key, 3)
        self.layers = eqx.filter_vmap(lambda key: WeirdLayer(key))(keys)

    def __call__(self, x, state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x_and_state, params):
            x, state, layer_ind = x_and_state
            layer = eqx.combine(params, static)
            # Pick out this layer's slice of the stacked state.
            substate = jax.tree.map(lambda x: x[layer_ind], state)
            x, substate = layer(x, substate)
            # Write the updated slice back into the stacked state.
            substate = jax.tree.map(lambda x, y: x.at[layer_ind].set(y), state, substate)
            state = state.update(substate)
            return (x, state, layer_ind + 1), None

        (x, state, _), _ = jax.lax.scan(_step, (x, state, 0), all_params)
        return x, state


x = jnp.ones((7, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)
print(state)
y = jax.vmap(m, in_axes=(0, None))(x, state)  # No Error!
print(y)
```
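The distinction between scanning *with* a state and scanning *over* it can be seen in a plain-JAX sketch. Anything placed in the `lax.scan` carry is passed whole to every step, while anything placed in `xs` is sliced along its leading axis, one slice per step. The per-layer weights `ws` and states `states` below are hypothetical arrays standing in for the vmap-constructed layers and their stacked state.

```python
import jax
import jax.numpy as jnp

n_layers, width = 3, 4
ws = jnp.stack([jnp.eye(width) * (i + 1) for i in range(n_layers)])  # (3, 4, 4)
states = jnp.zeros((n_layers, width))  # (3, 4): one state row per layer


def carry_step(carry, w):
    x, state = carry
    # In the carry, the state arrives with its full stacked shape at every step.
    print("state in carry:", state.shape)  # printed at trace time
    return (x, state), None


def xs_step(x, w_and_state):
    w, state = w_and_state
    # In xs, the state arrives as this layer's own (4,) slice.
    print("state in xs:", state.shape)  # printed at trace time
    return w @ x + state, None


x0 = jnp.ones(width)
jax.lax.scan(carry_step, (x0, states), ws)
out, _ = jax.lax.scan(xs_step, x0, (ws, states))
```

Putting the stacked state in `xs` rather than the carry is what removes the need for the manual indexing with `layer_ind`.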
@lockwo - thank you for this insight! This was what I was trying just now (I didn't work out the tree mapping; I've been confused by tracers and whatnot...). If it's useful, I tried partitioning the state in the same way as the layers (e.g. … I guess that's quite obvious, but you can't use an … Thank you both! I will also let you know if I get this fixed :)
This surprises me. I think this is the correct solution! You have a different state for each iteration of the scan. I think @lockwo's example is nearly correct, and you just need to modify his final `__call__`:

```python
    def __call__(self, x, all_state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x, params__state):
            # Each step receives one layer's parameters and one layer's state.
            params, state = params__state
            layer = eqx.combine(params, static)
            x, state = layer(x, state)
            return x, state

        # Scan over parameters and states together; the per-step states are restacked.
        x, all_state = jax.lax.scan(_step, x, (all_params, all_state))
        return x, all_state
```

On which note, heads-up @lockwo that if you were to take a …
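A plain-JAX analogue of this pattern, assuming each layer's parameters and state are stacked along a leading layer axis: the step function receives one layer's weights and one layer's state, and returns the updated state as the scan's per-step output, which `lax.scan` restacks into the original (n_layers, width) shape. `layer`, `model`, `all_w` and `all_state` are hypothetical names, not Equinox APIs.

```python
import jax
import jax.numpy as jnp

n_layers, width = 3, 4


def layer(w, x, state):
    # Hypothetical layer loosely mirroring WeirdLayer: add the stored state,
    # then store the result as the new state.
    h = jax.nn.relu(w @ x) + state
    return h, h  # (output, new state)


def model(all_w, x, all_state):
    def _step(x, w_and_state):
        w, state = w_and_state  # one layer's weights (4, 4) and state (4,)
        x, state = layer(w, x, state)
        return x, state  # the new state becomes this step's output

    # Scan over weights and states together; outputs are restacked to (n_layers, width).
    x, all_state = jax.lax.scan(_step, x, (all_w, all_state))
    return x, all_state


all_w = jnp.stack([jnp.eye(width)] * n_layers)  # (3, 4, 4)
all_state = jnp.zeros((n_layers, width))  # (3, 4)

x, new_state = model(all_w, jnp.ones(width), all_state)
print(x.shape, new_state.shape)  # (4,) (3, 4)

# Any batch size works when weights and initial state are shared across the batch.
xb = jnp.ones((7, width))
yb, sb = jax.vmap(model, in_axes=(None, 0, None))(all_w, xb, all_state)
print(yb.shape, sb.shape)  # (7, 4) (7, 3, 4)
```

Because the state is broadcast with `in_axes=None`, each example gets its own updated copy of the stacked state in the output.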
Ah yes, that's good, I forgot you don't need to call … Thanks for the bug reminder too ;), hopefully by the time I remember all the bugs in XLA I will also be putting bounties on them.
Awesome stuff, I was hoping you'd see this @patrick-kidger :). Yeah, I tried the param … Thanks everyone, I learnt a lot here!
Hello!

I'm having some trouble with states and the construction of a model that `filter_vmap`-constructs layers. I've checked your FAQ and stateful-ops page, but I can't get this model working as I expect...

In the normal `Model`, I build a model as you'd expect. I batch the states for application to a batch of data. All good.

In the `OddModel`, I try to do the same thing, but I `vmap`-construct my layers and `lax.scan` over them during my `__call__`. This causes an error because the state is already batched when it is returned by the initial `eqx.nn.make_with_state`.

This error is expected, but how do I ensure my `OddModel` works as expected? I've tried sub-stating, but I'm confused because I don't actually need to `vmap` the layers (which, as I understand it, is what a substate is for).

What am I missing? Thanks as always!

(This is an MWE of a very large transformer-based model that uses k-v caches parameterised with `eqx.nn.State`s.)
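To see why the state comes back batched, here is a plain-JAX sketch (an analogy, not Equinox's implementation): vmapping a per-layer initialiser over split keys stacks every array it returns, including a state that does not even depend on the key. We assume `eqx.filter_vmap`-constructed layers stack the initial value inside each `StateIndex` the same way, which is consistent with the (3, 4) state described in the replies. `init_layer` is a hypothetical function.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_layer(key):
    # Hypothetical per-layer initialiser: a weight matrix plus a zero initial state,
    # loosely mirroring WeirdLayer's Linear weights and StateIndex(jnp.zeros((4,))).
    return {"w": jr.normal(key, (4, 4)), "state": jnp.zeros((4,))}


keys = jr.split(jr.key(42), 3)
layers = jax.vmap(init_layer)(keys)

# Every array leaf gains a leading layer axis of size 3, including the initial state.
print(layers["w"].shape, layers["state"].shape)  # (3, 4, 4) (3, 4)
```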