Equinox much slower than flax linen when passing pytrees #928
Thanks for looking into this. I am on the move right now but hope to look into it later in the week.
I get similar results on my Mac. I assume it's not crucial, but I also ran the code on Colab - here are the results. I hope the experts here can give some ideas about it.
I've just taken a look at the first of @lockwo's examples. In this case the difference can be attributed to the different parameters they are each initialised with.

Step 1: initialise everything.

import equinox as eqx
import jax
from jax import numpy as jnp
from flax import linen as nn
from functools import partial


class EqxMLP(eqx.Module):
    layers: list

    def __init__(self, key):
        keys = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(4, 64, key=keys[0]),
            eqx.nn.Linear(64, 64, key=keys[1]),
            eqx.nn.Linear(64, 2, key=keys[2]),
        ]

    def __call__(self, x):
        x = jax.nn.relu(self.layers[0](x))
        x = jax.nn.relu(self.layers[1](x))
        return self.layers[2](x)


class FlaxMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = jax.nn.relu(nn.Dense(64)(x))
        x = jax.nn.relu(nn.Dense(64)(x))
        x = nn.Dense(2)(x)
        return x


key = jax.random.key(42)
key, subkey = jax.random.split(key)
init = jax.random.uniform(subkey, shape=(4,))
key, subkey = jax.random.split(key)
eqxmlp = EqxMLP(subkey)
flaxmlp = FlaxMLP()
params = flaxmlp.init(key, jnp.ones((4,)))

Step 2: create alternate versions of each using the other's randomly-initialised parameters:

params2 = {
    "params": {
        "Dense_0": {
            "kernel": eqxmlp.layers[0].weight.T,
            "bias": eqxmlp.layers[0].bias
        },
        "Dense_1": {
            "kernel": eqxmlp.layers[1].weight.T,
            "bias": eqxmlp.layers[1].bias
        },
        "Dense_2": {
            "kernel": eqxmlp.layers[2].weight.T,
            "bias": eqxmlp.layers[2].bias
        }
    }
}


def get(m):
    return (
        m.layers[0].weight,
        m.layers[0].bias,
        m.layers[1].weight,
        m.layers[1].bias,
        m.layers[2].weight,
        m.layers[2].bias,
    )


flat_params = (
    params["params"]["Dense_0"]["kernel"].T,
    params["params"]["Dense_0"]["bias"],
    params["params"]["Dense_1"]["kernel"].T,
    params["params"]["Dense_1"]["bias"],
    params["params"]["Dense_2"]["kernel"].T,
    params["params"]["Dense_2"]["bias"],
)

eqxmlp2 = eqx.tree_at(get, eqxmlp, flat_params)

Step 3: benchmark.

@jax.jit
def f_eqx(x):
    return eqxmlp(x)


@jax.jit
def f_flax(x):
    return flaxmlp.apply(params, x)


@jax.jit
def f_eqx2(x):
    return eqxmlp2(x)


@jax.jit
def f_flax2(x):
    return flaxmlp.apply(params2, x)


_ = jax.block_until_ready(f_eqx(init))
_ = jax.block_until_ready(f_flax(init))
_ = jax.block_until_ready(f_eqx2(init))
_ = jax.block_until_ready(f_flax2(init))

%timeit jax.block_until_ready(f_eqx(init))
%timeit jax.block_until_ready(f_flax(init))
%timeit jax.block_until_ready(f_eqx2(init))
%timeit jax.block_until_ready(f_flax2(init))

Results (on the CPU of an M2 MacBook Air):

7.38 µs ± 87.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
6.57 µs ± 28.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
6.78 µs ± 188 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.33 µs ± 27.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

We clearly see that the random initialisation of the Flax parameters just happens to be such that it saves about a microsecond (for whatever reason), whether things are implemented in Flax or in Equinox. So in this case at least, I feel comfortable saying that the two libraries are completely equivalent! (NB: I also tried a variant in which I adjusted Equinox's internals to store the transpose of its weight matrix -- for complete equivalence with Flax -- and this doesn't seem to change anything whatsoever; I got exactly the same numbers out.)

Okay, on to the next case @lockwo considers, which is when the parameters are passed as input (rather than being closed over). In this case I observe that this:

@jax.jit
def f_eqx3(x, model):
    return model(x)


@jax.jit
def f_flax3(x, params):
    return flaxmlp.apply(params, x)


_ = jax.block_until_ready(f_eqx3(init, eqxmlp))
_ = jax.block_until_ready(f_flax3(init, params))

%timeit jax.block_until_ready(f_eqx3(init, eqxmlp))
%timeit jax.block_until_ready(f_flax3(init, params))
%timeit jax.block_until_ready(f_eqx3(init, eqxmlp2))
%timeit jax.block_until_ready(f_flax3(init, params2))

gives:

Oh no! Equinox seems to be about 20 microseconds slower. However... the overhead is coming from the fact that Equinox stores its parameters in a pytree that must be flattened as it crosses the JIT boundary! It is totally expected that this adds a small (microseconds) amount of overhead. That is, this is an additive overhead, not a multiplicative one. Don't think of the above as 'three times slower'; think of the above as '20 microseconds apart'. If you move to bigger and bigger computations (as we do in real life, outside of these microbenchmarks), you probably don't care about a difference of 20 microseconds. :) And if you are in a case in which you really do need to save those 20 microseconds, then the following trick:

flat_eqxmlp, treedef = jax.tree.flatten(eqxmlp)
flat_eqxmlp2, _ = jax.tree.flatten(eqxmlp2)


@jax.jit
def f_eqx4(x, flat_model):
    model = jax.tree.unflatten(treedef, flat_model)
    return model(x)


_ = jax.block_until_ready(f_eqx4(init, flat_eqxmlp))

%timeit jax.block_until_ready(f_eqx4(init, flat_eqxmlp))
%timeit jax.block_until_ready(f_eqx4(init, flat_eqxmlp2))

gives:

and we're back to being in first place! The extra flattening/unflattening lines of code here are a sort of Equinox analogue to the extra lines that Flax always has you write.
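As an aside, here is a minimal sketch of how one might check the "additive, not multiplicative" point for themselves. This is illustrative code only, not from the thread: the WideMLP class, the chosen widths, and the bench helper are all made up. The idea is to grow the hidden width and compare the closed-over and passed-in call times.

import timeit

import equinox as eqx
import jax
import jax.numpy as jnp


class WideMLP(eqx.Module):
    layers: list

    def __init__(self, width, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(4, width, key=k1),
            eqx.nn.Linear(width, width, key=k2),
            eqx.nn.Linear(width, 2, key=k3),
        ]

    def __call__(self, x):
        x = jax.nn.relu(self.layers[0](x))
        x = jax.nn.relu(self.layers[1](x))
        return self.layers[2](x)


def bench(fn, *args, n=10_000):
    jax.block_until_ready(fn(*args))  # compile once before timing
    return timeit.timeit(lambda: jax.block_until_ready(fn(*args)), number=n) / n


x = jnp.ones((4,))
for width in (64, 1024):
    mlp = WideMLP(width, jax.random.key(0))
    closed = jax.jit(lambda x: mlp(x))    # module closed over: nothing extra crosses the JIT boundary
    passed = jax.jit(lambda x, m: m(x))   # module passed in: its pytree is flattened on every call
    gap = bench(passed, x, mlp) - bench(closed, x)
    print(f"width={width}: overhead ≈ {gap * 1e6:.1f} µs")

If the overhead really is additive, the printed gap should stay roughly constant across widths even though the matrix multiplications themselves get more expensive.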
Ah, these make sense. I didn't even consider the data values in the first case since I was running on CPU (it's much more noticeable on GPU sometimes: https://www.thonking.ai/p/strangely-matrix-multiplications). The second point is also clear now. Is the scaling of the cost of a pytree crossing jit roughly O(1), or is it pretty dependent on the pytree (e.g. depth, number of nodes, flattening function, etc.)? All in all, a good lesson in micro-benchmarking ML code that is usually optimized for the non-micro case.
'Too small to have ever mattered for me' ;) but I expect roughly linear in the number of nodes. FWIW, I did optimize the flattening time in the early days of Equinox. If we ever needed to, I wouldn't be surprised if we could push it even further, e.g. with codegen to generate a custom flattening function for each Equinox module.
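To make the "roughly linear in the number of leaves" expectation concrete, here is a small sketch (again illustrative, not from the thread; the tree shapes are arbitrary) that times jax.tree.flatten on dictionaries of increasing size:

import timeit

import jax
import jax.numpy as jnp

for n_layers in (5, 50, 500):
    # A nested dict with 2 array leaves per "layer", so 2 * n_layers leaves in total.
    tree = {f"layer_{i}": {"w": jnp.ones((4, 4)), "b": jnp.ones(4)} for i in range(n_layers)}
    t = timeit.timeit(lambda: jax.tree.flatten(tree), number=1_000) / 1_000
    print(f"{2 * n_layers:5d} leaves: {t * 1e6:.1f} µs per flatten")

Flattening an actual eqx.Module goes through its own registered flatten rule rather than the plain dict rule, so absolute numbers will differ, but the scaling trend is the relevant part.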
🎉 I'll close this issue for now, but feel free to reopen it if similar concerns later arise!
Did some digging on one part of #926 and found something that seemed unexpected to me. First, on CPU (Mac), the Equinox MLP seems a little slower than the Flax one. But if you pass the MLP as the argument, rather than capturing it, there is a much more substantial difference, like 100% slower(!). Maybe I am missing something, but these slowdowns seem unexpected for something as trivial as a few-layer MLP (this should just be 3 matrix-vector products and 3 adds). For the first case, looking at the jaxprs it seems like it's just the order of the weight matrix, but it seems weird that that would matter.
(Collapsed code blocks: jaxpr, eqx, flax.)
yields
14.8 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
12 µs ± 605 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
and
yields
185 µs ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
43.6 µs ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
19.6 µs ± 454 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
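For orientation, here is a minimal sketch of the capture-vs-pass pattern being benchmarked above. It does not reproduce the collapsed code: the single stand-in layer, the names, and the shapes below are placeholders for the original three-layer MLP.

import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.key(0)
mlp = eqx.nn.Linear(4, 2, key=key)  # placeholder for the original three-layer MLP
x = jnp.ones((4,))


@jax.jit
def captured(x):
    return mlp(x)  # module closed over: only the input crosses the JIT boundary


@jax.jit
def passed(x, model):
    return model(x)  # module passed as an argument: its pytree is flattened on every call


_ = jax.block_until_ready(captured(x))
_ = jax.block_until_ready(passed(x, mlp))
# %timeit jax.block_until_ready(captured(x))
# %timeit jax.block_until_ready(passed(x, mlp))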