
Question: States, vmap-construction-of-layers and lax.scan'ing. #934

Open
homerjed opened this issue Jan 16, 2025 · 7 comments

@homerjed
Contributor

homerjed commented Jan 16, 2025

Hello!

I'm having some trouble with states in a model whose layers are constructed with filter_vmap.

I've checked the FAQs and the stateful-operations page, but I can't get this model working as I expect...

  • In the normal Model I build a model as you'd expect. I batch the states for application to a batch of data. All good.

  • In the OddModel I try to do the same thing, but filter_vmap-construct my layers and lax.scan over them in my __call__. This causes an error, because the state returned by the initial eqx.nn.make_with_state is already batched.

This error is expected, but how do I ensure my OddModel works as expected? I've tried using substates, but I'm confused, because I don't actually need to vmap the layers at call time (which, as I understand it, is what a substate is for).

What am I missing? Thanks as always!

(This is an MWE of a very large transformer-based model that uses k-v caches parameterised with eqx.nn.State.)

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx 

class WeirdLayer(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    index: eqx.nn.StateIndex

    def __init__(self, key):
        self.linear1 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.linear2 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.index = eqx.nn.StateIndex(jnp.zeros((4,)))

    def __call__(self, x, state):
        y = state.get(self.index)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = x + y
        print("x", x.shape) # This becomes 3, 4 when this layer is used in OddModel, and therefore breaks!
        new_state = state.set(self.index, x)
        x = self.linear2(x)
        return x, new_state

class Model(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model does the normal building of layers with a loop
            list of layers and for loop in __call__
        """
        self.layers = [WeirdLayer(key) for _ in range(3)]
    
    def __call__(self, x, state):
        for l in self.layers:
            x, state = l(x, state)
        return x, state

# Use the normal model, see what happens
key = jr.key(0)  # PRNG key for initialisation
x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(Model)(key)
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()
y, state = jax.vmap(m)(x, states)
print("Out:", y.shape) # All is well

class OddModel(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model uses vmap'd construction of layers 
            and a lax.scan over them in __call__
        """
        keys = jr.split(key, 3)
        self.layers = eqx.filter_vmap(lambda key: WeirdLayer(key))(keys)

    def __call__(self, x, state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x_and_state, params):
            x, state = x_and_state
            layer = eqx.combine(params, static)
            substate = state.substate(self.layers)
            x, substate = layer(x, substate)
            state = state.update(substate)
            return (x, state), None

        (x, state), _ = jax.lax.scan(_step, (x, state), all_params)
        return x, state

x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)

jax.debug.print("{}", state) # Batched, as expected

states = eqx.filter_vmap(lambda: state, axis_size=len(x))()

y = jax.vmap(m)(x, states) # Error!
print(y)
@johannahaffner

johannahaffner commented Jan 16, 2025

Hi Jed,

welcome to Equinox :)

The vmap operation after creating OddModel is redundant, since your model is already batched when you create it. That means the state is batched too.

Your code runs if the last lines are changed to something like this:

...
m, state = eqx.nn.make_with_state(OddModel)(key)
y, state = jax.vmap(m)(x, state) 
print(y, y.shape)

@homerjed
Contributor Author

Hi Johanna!

Thank you for the reply! I might be missing something here... if I change the batch size, e.g. to x = jnp.ones((7, 4)), the error returns.

It seems like the state returned by m, state = eqx.nn.make_with_state(OddModel)(key) is a batched state of a single module: its batch axis has 3 elements, one for each WeirdLayer in OddModel. I'd expect that, given the vmap'd construction of the layers.

This means that when I vmap over the states, the batch axes of the data x and the states are not aligned. I would do this vmap since my application requires it (it is also what you'd do for e.g. eqx.nn.BatchNorm during training).
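
Concretely, a minimal sketch of the mismatch (reusing key and OddModel from above):

m, state = eqx.nn.make_with_state(OddModel)(key)
# Each state leaf is stacked over the 3 layers:
print(jax.tree.leaves(eqx.filter(state, eqx.is_array))[0].shape)  # (3, 4)

# Stacking again for a batch of 7 gives (7, 3, 4) state leaves, while x is
# (7, 4), so inside __call__ each example still sees a (3, 4) state.
states = eqx.filter_vmap(lambda: state, axis_size=7)()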

Cheers!

@lockwo
Contributor

lockwo commented Jan 16, 2025

A few points, based on my understanding of states. First, the substate doesn't do anything here (since you take the substate based on the layers, which already contain the whole state). Second, the key issue is that you are scanning over the layers but not actually scanning over their states: you have a (3, 4) state and pass the whole thing in at each step, when you really want to pass in a (4,) state each time. The docs example vmaps over them, whereas here you are scanning over them, which is slightly different. I wrote something that approaches it from that direction and seems to work (I dislike the manual state interaction, it seems brutally inelegant, so I will update it if I think of something more clever).

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx 

class WeirdLayer(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    index: eqx.nn.StateIndex

    def __init__(self, key):
        self.linear1 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.linear2 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.index = eqx.nn.StateIndex(jnp.zeros((4,)))

    def __call__(self, x, state):
        y = state.get(self.index)
        print("y", y.shape)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = x + y
        print("x", x.shape) # This becomes 3, 4 when this layer is used in OddModel, and therefore breaks!
        new_state = state.set(self.index, x)
        x = self.linear2(x)
        return x, new_state

class Model(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model does the normal building of layers with a loop
            list of layers and for loop in __call__
        """
        self.layers = [WeirdLayer(key) for _ in range(3)]
    
    def __call__(self, x, state):
        for l in self.layers:
            x, state = l(x, state)
        return x, state

key = jax.random.key(42)
# Use the normal model, see what happens
x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(Model)(key)
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()
y, state = jax.vmap(m)(x, states)
print("Out:", y.shape) # All is well

class OddModel(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model uses vmap'd construction of layers 
            and a lax.scan over them in __call__
        """
        keys = jr.split(key, 3)
        self.layers = eqx.filter_vmap(lambda key: WeirdLayer(key))(keys)

    def __call__(self, x, state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x_and_state, params):
            x, state, layer_ind = x_and_state
            layer = eqx.combine(params, static)
            substate = jax.tree.map(lambda x: x[layer_ind], state)
            x, substate = layer(x, substate)
            substate = jax.tree.map(lambda x, y: x.at[layer_ind].set(y), state, substate)
            state = state.update(substate)
            return (x, state, layer_ind + 1), None

        (x, state, _), _ = jax.lax.scan(_step, (x, state, 0), all_params)
        return x, state

x = jnp.ones((7, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)

print(state)

y, state = jax.vmap(m, in_axes=(0, None))(x, state) # No error!
print(y)

@homerjed
Contributor Author

homerjed commented Jan 16, 2025

@lockwo - thank you for this insight!

This was what I was trying just now (I didn't work out the tree mapping; I've been confused by tracers and whatnot...).

If it's useful: I tried partitioning the state in the same way as the layers (e.g. state = eqx.combine(state_params, static_state)), but lax.scan doesn't like iterating over (all_params, all_states) instead of all_params.

I guess that's quite obvious, but you can't use xs=jnp.arange(n_layers) to index the state/params, due to the tracers involved (I'm 99% sure I tried exactly that; I've got a long notebook of trials...).

Thank you both! I will also let you know if I get this fixed :)

@patrick-kidger
Owner

patrick-kidger commented Jan 16, 2025

lax.scan doesn't like iterating over (all_params, all_states) instead of all_params.

This surprises me. I think this is the correct solution! You have a different state for each iteration of the scan.

I think @lockwo's example is nearly correct, and you just need to modify his final __call__ to look like this:

def __call__(self, x, all_state):
    all_params, static = eqx.partition(self.layers, eqx.is_array)

    def _step(x, params__state):
        params, state = params__state
        layer = eqx.combine(params, static)
        x, state = layer(x, state)
        return x, state

    x, all_state = jax.lax.scan(_step, x, (all_params, all_state))
    return x, all_state
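
With that change, usage would look something like this (a sketch, reusing x, key and OddModel from @lockwo's example above):

x = jnp.ones((7, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)
y, state = jax.vmap(m, in_axes=(0, None))(x, state)
print(y.shape)  # (7, 4); the scan unstacks the layer axis of params and state together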

On which note, heads-up @lockwo that if you were to take a @jax.grad of the scan you have, then you'd run into the same XLA bug I mentioned here ('And sadly, XLA has a longstanding bug in which grad-of-loop-of-inplace will make copies of that buffer during the backward pass!') Specifically you shouldn't do grad-of-loop-of-inplace if you read from the buffer you're also writing to. (But if it's never read during the loop, and only written to, then it's fine.)
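
For illustration, a minimal toy sketch of that problematic pattern (a hypothetical example, not code from this thread): the scan below both reads from and writes to the same buffer, which is exactly the grad-of-loop-of-inplace case described above.

import jax
import jax.numpy as jnp

def loss(buf):
    def step(buf, i):
        read = buf[i]                    # read from the buffer...
        buf = buf.at[i].set(read * 2.0)  # ...that is also written to in place
        return buf, read

    buf, reads = jax.lax.scan(step, buf, jnp.arange(buf.shape[0]))
    return reads.sum()

grads = jax.grad(loss)(jnp.ones(8))  # the backward pass may copy the buffer at each step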

@lockwo
Contributor

lockwo commented Jan 16, 2025

Ah yes, that's good. I forgot you don't need to call .update since the substate wasn't meaningful (and you can therefore make it more elegant by scanning over both).

Thanks for the bug reminder too ;). Hopefully by the time I can remember all the bugs in XLA, I will also be putting bounties on them.

@homerjed
Contributor Author

Awesome stuff, I was hoping you'd see this @patrick-kidger :).

Yeah, I tried the xs argument of lax.scan for this, but I must have had a hidden bug or something in my long, long notebook. It makes sense as just another pytree there.

Thanks everyone, I learnt a lot here!
