Slow Initialization #961
Running this code in a fresh Colab environment (and with some `block_until_ready`s), I see that array is faster, but custom is slower.

However, the timing doesn't seem strictly additive if I increase […].

I also repeated this with […]. The Equinox overhead still seems large.

Just spitballing here, but there seem to be three options: […]
I think what you're measuring here is an additive overhead of microseconds in the flattening and unflattening. This is expected/known, and rarely ever troublesome. Your computational work has to be essentially negligible for this to affect real-world results.
I think you're measuring the noise in the actual operation itself here. Those standard deviations are pretty large, and overlap quite a lot!
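The additive overhead being discussed can be seen directly with a minimal custom pytree. This is a sketch, not Equinox's implementation: the `Point` class and the sizes are made-up examples, and it times only a flatten/unflatten round trip with no numerical work.

```python
import timeit

import jax.numpy as jnp
from jax import tree_util


class Point:
    """A minimal custom pytree: two array leaves, no static metadata."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),        # flatten: children + aux data
    lambda aux, children: Point(*children),  # unflatten
)

p = Point(jnp.ones(2), jnp.ones(2))


def round_trip(tree):
    # The overhead under discussion: one full flatten/unflatten cycle.
    leaves, treedef = tree_util.tree_flatten(tree)
    return tree_util.tree_unflatten(treedef, leaves)


n = 10_000
t = timeit.timeit(lambda: round_trip(p), number=n) / n
print(f"flatten+unflatten: {t * 1e6:.2f} µs per call")
```

Swapping `Point` for an `eqx.Module` subclass with the same two fields isolates how much of the gap comes from (un)flattening alone, independent of the jitted computation.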
Where does Module do that differently than the above example? This may be the important point!
This appears to be non-negligible in the following example:

```python
import jax
import jax.numpy as jnp
import unxt as u

jax.config.update("jax_enable_x64", True)  # outputs below are float64


def convert_cart2d_to_polar(params, aux):
    x, y = params["x"], params["y"]
    r = jnp.sqrt(x**2 + y**2)
    theta = jnp.arctan2(y, x)
    return {"r": r, "theta": theta}, aux


params = {
    "x": u.Quantity(jnp.array([1.0, 2.0]), "m"),
    "y": u.Quantity(jnp.array([3.0, 4.0]), "m"),
}
aux = {}

jac, aux = jax.jacfwd(convert_cart2d_to_polar, has_aux=True)(params, aux)
jac
# {'r': Quantity['length']({'x': Quantity['length'](Array([[0.31622777, 0.        ],
#        [0.        , 0.4472136 ]], dtype=float64), unit='m'),
#                           'y': Quantity['length'](Array([[0.9486833 , 0.        ],
#        [0.        , 0.89442719]], dtype=float64), unit='m')}, unit='m'),
#  'theta': Quantity['angle']({'x': Quantity['length'](Array([[-0.3,  0. ],
#        [ 0. , -0.2]], dtype=float64), unit='m'),
#                              'y': Quantity['length'](Array([[0.1, 0. ],
#        [0. , 0.1]], dtype=float64), unit='m')}, unit='rad')}
# NOTE: the weird nesting is something else I'm trying to correct,
# but it appears to be because jax.jacfwd doesn't have an `is_leaf`

func = jax.jit(jax.jacfwd(convert_cart2d_to_polar, has_aux=True))
func(params, aux)  # compile once
%timeit jax.block_until_ready(func(params, aux))
# 26.6 µs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# vs. stripping the units first:
params2 = {k: v.value for k, v in params.items()}
func(params2, aux)  # compile once
%timeit jax.block_until_ready(func(params2, aux))
# 9.65 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```

So it's about 2.75x faster not to use a Quantity.
It handles different types of fields and includes some checks; see line 908 and line 953 at commit 8191b11.
I took a look at your MWE: it looks like we're seeing the same additive overhead as above. If you have many small modules and do tiny operations on all of them, you might indeed be in the regime in which this starts to matter. Not sure if you can batch instead? But I'm unfamiliar with quax, so I can't really say anything about that!

The only thing I can add: your Jacobian actually looks exactly as expected. You have two elements in your radians array, and two angles, corresponding to two different points in Cartesian space. Both […]. To make this more readable and intuitive, you might want to try […].
Unfortunately batching isn't possible. Yes, tiny operations on many small modules does appear to be the case with quax. @patrick-kidger Is it possible in quax (or maybe via equinox) to provide a custom flatten/unflatten implementation?
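For reference, plain JAX does let you hand-write the flatten/unflatten pair for a class via `jax.tree_util.register_pytree_node_class`. A sketch of what such a "custom" implementation could look like; the `Params` class is hypothetical, and this deliberately bypasses everything Equinox's Module machinery does for you (field typing, static fields, checks):

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Params:
    """Hypothetical container with hand-written (un)flattening."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def tree_flatten(self):
        # Children first, then static aux data (none here).
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        # No field checks, no conversions: just reassemble.
        return cls(*children)


@jax.jit
def apply(p, x):
    return p.weight @ x + p.bias


p = Params(jnp.eye(2), jnp.zeros(2))
out = apply(p, jnp.array([1.0, 2.0]))
print(out)  # [1. 2.]
```

Whether this composes cleanly with quax's array-ish types is a separate question, per the discussion above.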
Yes, but see that the inner dicts are inside outer Quantity objects. This is because `jax.jacfwd` doesn't have an `is_leaf` argument.
@johannahaffner thank you for identifying that this is still the additive overhead!

@nstarman I think what'd probably be most desirable is if we can just speed up the existing flattening/unflattening implementation, along the lines of whatever alternative you have in mind! Taking unflattening as an example, pretty much the only difference between what we already have and something that just assigns attributes […]. If this is indeed the source of the overhead, then it wouldn't be very hard to dynamically generate such a 'hardcoded' function for each new Module.
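A sketch of the "dynamically generate a hardcoded unflatten" idea. All names here are hypothetical (a stand-in `Linear` class with a known field list), and Equinox's real unflattening also has to restore static fields and wrapped values, which this ignores:

```python
import jax.numpy as jnp
from jax import tree_util


class Linear:
    """Hypothetical module-like class with a fixed set of fields."""

    field_names = ("weight", "bias")

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias


def make_fast_unflatten(cls):
    """Generate an unflatten specialized to cls's fields: it bypasses
    __init__ (and any checks there) and just assigns attributes."""
    names = cls.field_names

    def unflatten(aux, children):
        obj = object.__new__(cls)  # skip __init__ entirely
        for name, leaf in zip(names, children):
            object.__setattr__(obj, name, leaf)
        return obj

    return unflatten


tree_util.register_pytree_node(
    Linear,
    lambda m: (tuple(getattr(m, n) for n in Linear.field_names), None),
    make_fast_unflatten(Linear),
)

m = Linear(jnp.ones((2, 2)), jnp.zeros(2))
leaves, treedef = tree_util.tree_flatten(m)
m2 = tree_util.tree_unflatten(treedef, leaves)
print(type(m2).__name__, [leaf.shape for leaf in leaves])
```

Generating `unflatten` once per class means the per-call work is just attribute assignment, with the field iteration resolved at class-registration time rather than on every unflatten.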
Hi @patrick-kidger. I've recently been speed testing some code and found that Equinox is around 2x slower than a custom pytree.

Now with a custom Module: […]

Now with a custom PyTree: […]

So the timings are array : Module : custom = 2.87 : 8.4 : 4.02.
Is there any way to speed up Module?
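For anyone reproducing this, the shape of the benchmark is roughly the following. This is a hedged reconstruction, not the original code: the operation, array sizes, and the `Wrapper` class are placeholders, and absolute timings will differ from the numbers quoted above.

```python
import timeit

import jax
import jax.numpy as jnp
from jax import tree_util


# Plain array baseline.
@jax.jit
def f_array(x):
    return x + 1.0


# The same operation routed through a small custom pytree.
class Wrapper:
    def __init__(self, x):
        self.x = x


tree_util.register_pytree_node(
    Wrapper,
    lambda w: ((w.x,), None),
    lambda aux, children: Wrapper(*children),
)


@jax.jit
def f_custom(w):
    return Wrapper(w.x + 1.0)


x = jnp.ones(4)
w = Wrapper(x)
f_array(x)   # compile once
f_custom(w)  # compile once

n = 10_000
ta = timeit.timeit(lambda: jax.block_until_ready(f_array(x)), number=n) / n
tc = timeit.timeit(lambda: jax.block_until_ready(f_custom(w)), number=n) / n
print(f"array:  {ta * 1e6:.2f} µs")
print(f"custom: {tc * 1e6:.2f} µs")
```

Substituting an `eqx.Module` subclass for `Wrapper` gives the third column of the comparison; the gap between the two pytree variants is the per-call (un)flattening cost discussed above.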