Equinox much slower than flax linen when passing pytrees #928

Closed

lockwo opened this issue Jan 6, 2025 · 5 comments
Labels: question (User queries)

@lockwo (Contributor) commented Jan 6, 2025

Did some digging on one part of #926 and found something that seemed unexpected to me. First, on CPU (Mac), the Equinox MLP seems a little slower than Flax. But if you pass the MLP as an argument, rather than capturing it by closure, the difference becomes much more substantial, like 100% slower(!). Maybe I am missing something, but these slowdowns seem unexpected for something as trivial as a few-layer MLP (this should just be three matrix-vector products and three adds). For the first case, looking at the jaxprs, the only difference seems to be the layout of the weight matrices, and it seems weird that that would matter.

eqx jaxpr:

let relu = { lambda ; a:f32[64]. let b:f32[64] = max a 0.0 in (b,) } in
{ lambda ; c:f32[4]. let
    d:f32[2] = pjit[
      name=f_eqx
      jaxpr={ lambda e:f32[64,4] f:f32[64] g:f32[64,64] h:f32[64] i:f32[2,64] j:f32[2]; k:f32[4]. let
          l:f32[64] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] e k
          m:f32[64] = add l f
          n:f32[64] = custom_jvp_call[
            call_jaxpr={ lambda ; o:f32[64]. let
                p:f32[64] = pjit[name=relu jaxpr=relu] o
              in (p,) }
            jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x115b6e320>
            num_consts=0
            symbolic_zeros=False
          ] m
          q:f32[64] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] g n
          r:f32[64] = add q h
          s:f32[64] = custom_jvp_call[
            call_jaxpr={ lambda ; t:f32[64]. let
                u:f32[64] = pjit[name=relu jaxpr=relu] t
              in (u,) }
            jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x1365f5d80>
            num_consts=0
            symbolic_zeros=False
          ] r
          v:f32[2] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] i s
          w:f32[2] = add v j
        in (w,) }
    ] c
  in (d,) }

flax jaxpr:

let relu = { lambda ; a:f32[64]. let b:f32[64] = max a 0.0 in (b,) } in
{ lambda ; c:f32[4]. let
    d:f32[2] = pjit[
      name=f_flax
      jaxpr={ lambda e:f32[4,64] f:f32[64] g:f32[64,64] h:f32[64] i:f32[64,2] j:f32[2]; k:f32[4]. let
          l:f32[64] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k e
          m:f32[64] = add l f
          n:f32[64] = custom_jvp_call[
            call_jaxpr={ lambda ; o:f32[64]. let
                p:f32[64] = pjit[name=relu jaxpr=relu] o
              in (p,) }
            jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x1365f6830>
            num_consts=0
            symbolic_zeros=False
          ] m
          q:f32[64] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n g
          r:f32[64] = add q h
          s:f32[64] = custom_jvp_call[
            call_jaxpr={ lambda ; t:f32[64]. let
                u:f32[64] = pjit[name=relu jaxpr=relu] t
              in (u,) }
            jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x1365f4f70>
            num_consts=0
            symbolic_zeros=False
          ] r
          v:f32[2] = dot_general[dimension_numbers=(([0], [0]), ([], []))] s i
          w:f32[2] = add v j
        in (w,) }
    ] c
  in (d,) }
Code:

import equinox as eqx
import jax
from jax import numpy as jnp
from flax import linen as nn
from functools import partial

class EqxMLP(eqx.Module):
    layers: list

    def __init__(self, key):
        keys = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(4, 64, key=keys[0]),
            eqx.nn.Linear(64, 64, key=keys[1]),
            eqx.nn.Linear(64, 2, key=keys[2]),
        ]

    def __call__(self, x):
        x = jax.nn.relu(self.layers[0](x))
        x = jax.nn.relu(self.layers[1](x))
        return self.layers[2](x)


class FlaxMLP(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = jax.nn.relu(nn.Dense(64)(x))
        x = jax.nn.relu(nn.Dense(64)(x))
        x = nn.Dense(2)(x)
        return x

key = jax.random.key(42)
key, subkey = jax.random.split(key)
init = jax.random.uniform(subkey, shape=(4,))

key, subkey = jax.random.split(key)
eqxmlp = EqxMLP(subkey)
flaxmlp = FlaxMLP()
params = flaxmlp.init(key, jnp.ones((4,)))

@jax.jit
def f_eqx(x):
    return eqxmlp(x)

@jax.jit
def f_flax(x):
    return flaxmlp.apply(params, x)

_ = jax.block_until_ready(f_eqx(init))
_ = jax.block_until_ready(f_flax(init))

print(jax.make_jaxpr(f_flax)(init))
print(jax.make_jaxpr(f_eqx)(init))

%%timeit
_ = jax.block_until_ready(f_eqx(init))

%%timeit
_ = jax.block_until_ready(f_flax(init))

yields

f_eqx:  14.8 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

f_flax: 12 µs ± 605 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

and

key = jax.random.key(42)
key, subkey = jax.random.split(key)
init = jax.random.uniform(subkey, shape=(4,))

key, subkey = jax.random.split(key)
eqxmlp = EqxMLP(subkey)
flaxmlp = FlaxMLP()
params = flaxmlp.init(key, jnp.ones((4,)))

@jax.jit
def f_eqx(x, model):
    return model(x)

@eqx.filter_jit
def filter_eqx(x, model):
    return model(x)

@partial(jax.jit, static_argnums=(1))
def f_flax(x, model, params):
    return model.apply(params, x)

_ = jax.block_until_ready(filter_eqx(init, eqxmlp))
_ = jax.block_until_ready(f_eqx(init, eqxmlp))
_ = jax.block_until_ready(f_flax(init, flaxmlp, params))

%%timeit
_ = jax.block_until_ready(filter_eqx(init, eqxmlp))

%%timeit
_ = jax.block_until_ready(f_eqx(init, eqxmlp))

%%timeit
_ = jax.block_until_ready(f_flax(init, flaxmlp, params))

yields

filter_eqx: 185 µs ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

f_eqx:      43.6 µs ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

f_flax:     19.6 µs ± 454 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

@dorjeduck

Thanks for looking into this. I am on the move right now but hope to look into it later this week.

@dorjeduck commented Jan 8, 2025

I get similar results on my Mac. Probably not crucial, but I also ran the code on Colab; here are the results.

T4 GPU

First benchmark (closed-over model):

f_eqx:  137 µs ± 19.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

f_flax: 131 µs ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Second benchmark (model passed as argument):

filter_eqx: 1.23 ms ± 135 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

f_eqx:      217 µs ± 4.62 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

f_flax:     202 µs ± 16.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

v2-8 TPU

First benchmark (closed-over model):

f_eqx:  123 µs ± 1.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

f_flax: 126 µs ± 5.79 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Second benchmark (model passed as argument):

filter_eqx: 1.31 ms ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

f_eqx:      181 µs ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

f_flax:     140 µs ± 3.84 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

I hope the experts here can share some ideas about this ...

@patrick-kidger (Owner) commented Jan 8, 2025

I've just taken a look at the first of @lockwo's examples. In this case the difference can be attributed to the different parameters they are each initialised with.

Step 1: initialise everything.

import equinox as eqx
import jax
from jax import numpy as jnp
from flax import linen as nn
from functools import partial

class EqxMLP(eqx.Module):
    layers: list

    def __init__(self, key):
        keys = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(4, 64, key=keys[0]),
            eqx.nn.Linear(64, 64, key=keys[1]),
            eqx.nn.Linear(64, 2, key=keys[2]),
        ]

    def __call__(self, x):
        x = jax.nn.relu(self.layers[0](x))
        x = jax.nn.relu(self.layers[1](x))
        return self.layers[2](x)

class FlaxMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = jax.nn.relu(nn.Dense(64)(x))
        x = jax.nn.relu(nn.Dense(64)(x))
        x = nn.Dense(2)(x)
        return x

key = jax.random.key(42)
key, subkey = jax.random.split(key)
init = jax.random.uniform(subkey, shape=(4,))

key, subkey = jax.random.split(key)
eqxmlp = EqxMLP(subkey)
flaxmlp = FlaxMLP()
params = flaxmlp.init(key, jnp.ones((4,)))

Step 2: create alternate versions of each using the other's randomly-initialised parameters:

params2 = {
    "params": {
        "Dense_0": {
            "kernel": eqxmlp.layers[0].weight.T,
            "bias": eqxmlp.layers[0].bias
        },
        "Dense_1": {
            "kernel": eqxmlp.layers[1].weight.T,
            "bias": eqxmlp.layers[1].bias
        },
        "Dense_2": {
            "kernel": eqxmlp.layers[2].weight.T,
            "bias": eqxmlp.layers[2].bias
        }
    }
}

def get(m):
    return (
        m.layers[0].weight,
        m.layers[0].bias,
        m.layers[1].weight,
        m.layers[1].bias,
        m.layers[2].weight,
        m.layers[2].bias,
    )
flat_params = (
    params["params"]["Dense_0"]["kernel"].T,
    params["params"]["Dense_0"]["bias"],
    params["params"]["Dense_1"]["kernel"].T,
    params["params"]["Dense_1"]["bias"],
    params["params"]["Dense_2"]["kernel"].T,
    params["params"]["Dense_2"]["bias"],
)
eqxmlp2 = eqx.tree_at(get, eqxmlp, flat_params)

Step 3: benchmark

@jax.jit
def f_eqx(x):
    return eqxmlp(x)

@jax.jit
def f_flax(x):
    return flaxmlp.apply(params, x)

@jax.jit
def f_eqx2(x):
    return eqxmlp2(x)

@jax.jit
def f_flax2(x):
    return flaxmlp.apply(params2, x)

_ = jax.block_until_ready(f_eqx(init))
_ = jax.block_until_ready(f_flax(init))
_ = jax.block_until_ready(f_eqx2(init))
_ = jax.block_until_ready(f_flax2(init))

%timeit jax.block_until_ready(f_eqx(init))

%timeit jax.block_until_ready(f_flax(init))

%timeit jax.block_until_ready(f_eqx2(init))

%timeit jax.block_until_ready(f_flax2(init))

Results (on the CPU of an M2 MacBook Air):

f_eqx   (Equinox model, Equinox init): 7.38 µs ± 87.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
f_flax  (Flax model, Flax init):       6.57 µs ± 28.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
f_eqx2  (Equinox model, Flax init):    6.78 µs ± 188 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
f_flax2 (Flax model, Equinox init):    7.33 µs ± 27.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

We clearly see that the random initialisation of the Flax parameters just happens to be such that it saves about a microsecond (for whatever reason), regardless of whether things are implemented in Flax or in Equinox.

So in this case at least, I feel comfortable saying that the two libraries are completely equivalent!

(NB I also tried a variant in which I adjusted Equinox's internals to store the transpose of its weight matrix -- for complete equivalence with Flax -- and this doesn't seem to change anything whatsoever -- I got the exact same numbers out.)
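(For anyone who wants to poke at the data-dependence angle further, here is a minimal raw-jnp sketch, with no library code at all, reusing the eqxmlp, params and init objects defined above. Any timing gap left between these two functions is then purely down to the weight values/layout, not to Flax or Equinox.)

@jax.jit
def raw_eqx(x):
    # Equinox convention: weight has shape (out, in), computes weight @ x + bias.
    h = jax.nn.relu(eqxmlp.layers[0].weight @ x + eqxmlp.layers[0].bias)
    h = jax.nn.relu(eqxmlp.layers[1].weight @ h + eqxmlp.layers[1].bias)
    return eqxmlp.layers[2].weight @ h + eqxmlp.layers[2].bias

@jax.jit
def raw_flax(x):
    # Flax convention: kernel has shape (in, out), computes x @ kernel + bias.
    p = params["params"]
    h = jax.nn.relu(x @ p["Dense_0"]["kernel"] + p["Dense_0"]["bias"])
    h = jax.nn.relu(h @ p["Dense_1"]["kernel"] + p["Dense_1"]["bias"])
    return h @ p["Dense_2"]["kernel"] + p["Dense_2"]["bias"]

_ = jax.block_until_ready(raw_eqx(init))
_ = jax.block_until_ready(raw_flax(init))

%timeit jax.block_until_ready(raw_eqx(init))

%timeit jax.block_until_ready(raw_flax(init))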


Okay, on to the next case @lockwo considers, which is when the parameters are passed as input (rather than being closed-over). In this case I observe that this:

@jax.jit
def f_eqx3(x, model):
    return model(x)

@jax.jit
def f_flax3(x, params):
    return flaxmlp.apply(params, x)

_ = jax.block_until_ready(f_eqx3(init, eqxmlp))
_ = jax.block_until_ready(f_flax3(init, params))

%timeit jax.block_until_ready(f_eqx3(init, eqxmlp))

%timeit jax.block_until_ready(f_flax3(init, params))

%timeit jax.block_until_ready(f_eqx3(init, eqxmlp2))

%timeit jax.block_until_ready(f_flax3(init, params2))

gives:

f_eqx3(init, eqxmlp):   29.6 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
f_flax3(init, params):  9.7 µs ± 201 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
f_eqx3(init, eqxmlp2):  30 µs ± 56.8 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
f_flax3(init, params2): 9.85 µs ± 212 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Oh no! Equinox seems to be about 20 microseconds slower. However....

...the overhead is coming from the fact that Equinox stores its parameters in a pytree that must be flattened as it crosses the JIT boundary! It is totally expected that this adds a small (microseconds) amount of overhead.

That is, this is an additive overhead, not a multiplicative one. Don't think of the above as 'three times slower', think of the above as '20 microseconds apart'. If you move to bigger and bigger computations (as we do in real life, outside of these microbenchmarks), you probably don't care about differences of 20 microseconds. :)
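(If you want to see that overhead in isolation, one rough way -- just a sketch, measured outside of jit entirely -- is to time the flattening step itself, since that is roughly what jax.jit has to do to each pytree argument on every call:)

import timeit

# Per-call cost of flattening each argument pytree.
n = 100_000
t_eqx = timeit.timeit(lambda: jax.tree.flatten(eqxmlp), number=n) / n
t_flax = timeit.timeit(lambda: jax.tree.flatten(params), number=n) / n
print(f"flatten eqxmlp: {t_eqx * 1e6:.2f} µs")
print(f"flatten params: {t_flax * 1e6:.2f} µs")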

And if you are in a case in which you really do need to save those 20 microseconds, then the following trick

flat_eqxmlp, treedef = jax.tree.flatten(eqxmlp)
flat_eqxmlp2, _ = jax.tree.flatten(eqxmlp2)

@jax.jit
def f_eqx4(x, flat_model):
    model = jax.tree.unflatten(treedef, flat_model)
    return model(x)

_ = jax.block_until_ready(f_eqx4(init, flat_eqxmlp))

%timeit jax.block_until_ready(f_eqx4(init, flat_eqxmlp))

%timeit jax.block_until_ready(f_eqx4(init, flat_eqxmlp2))

gives:

f_eqx4(init, flat_eqxmlp):  8.66 µs ± 352 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
f_eqx4(init, flat_eqxmlp2): 8.71 µs ± 171 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

and we're back to being in first place!

The extra flattening/unflattening lines of code here are a sort of Equinox analogue to the extra lines that Flax always has you write, with flaxmlp.init / flaxmlp.apply. By default they don't really matter (only 20 microseconds) so Equinox is written so that you don't need them... but this trick is still available if you do need it :)
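To make the analogy concrete, here are the two patterns next to each other (this is just the code from above, rearranged, not a new recommendation):

# Flax: the parameters/structure split is always explicit.
params = flaxmlp.init(key, jnp.ones((4,)))        # leaves only
out = flaxmlp.apply(params, init)                 # structure + leaves

# Equinox analogue, only needed if the last few microseconds matter:
leaves, treedef = jax.tree.flatten(eqxmlp)        # leaves only
out = jax.tree.unflatten(treedef, leaves)(init)   # structure + leaves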

@lockwo (Contributor, Author) commented Jan 8, 2025

Ah, these make sense. I didn't even consider data values in the first case since I was running on CPU (it's much more noticeable on GPU sometimes: https://www.thonking.ai/p/strangely-matrix-multiplications). The second point is also clear now. Is the scaling of the cost of a pytree crossing jit roughly O(1), or is it pretty dependent on the pytree (e.g. depth, number of nodes, flattening function, etc.)?

All in all, a good lesson in micro-benchmarking ML code that is usually optimized for the non-micro case.

@patrick-kidger (Owner) commented Jan 8, 2025

Is the scaling of the cost of a pytree crossing jit roughly O(1), or is it pretty dependent on the pytree (e.g. depth, number of nodes, flattening function, etc.)?

'Too small to have ever mattered for me' ;) but I expect roughly linear in the number of nodes.

FWIW I did optimize the flattening time in the early days of Equinox. If we ever needed to then I wouldn't be surprised if we could push it even further e.g. with codegen to generate a custom flattening function for each Equinox module.
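(A rough, standalone sketch, nothing Equinox-specific, for anyone who wants to check the "roughly linear in the number of nodes" intuition:)

import timeit
import jax
from jax import numpy as jnp

# Time jax.tree.flatten on nested lists of increasing size; the per-call cost
# should grow roughly linearly with the number of leaves.
for n_leaves in (10, 100, 1000):
    tree = [[jnp.zeros(()) for _ in range(10)] for _ in range(n_leaves // 10)]
    reps = 10_000
    t = timeit.timeit(lambda: jax.tree.flatten(tree), number=reps) / reps
    print(f"{n_leaves:5d} leaves: {t * 1e6:.2f} µs per flatten")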

All in all, a good lesson in micro-benchmarking ML code that is usually optimized for the non-micro case.

🎉

I'll close this issue for now, but feel free to reopen it if similar concerns later arise!
