Slow Initialization #961

Open · nstarman opened this issue Feb 26, 2025 · 9 comments
Labels: question (User queries)
@nstarman

Hi @patrick-kidger. I've recently been speed testing some code and found that Equinox is around 2x slower than a custom pytree.

import jax
import jax.numpy as jnp
import equinox as eqx

def func(x, y):
    return (x - y) / (x + y)

x = jnp.linspace(0.0, 1.0, 10)
y = jnp.linspace(1.0, 2.0, 10)

print(jax.make_jaxpr(func)(x, y))
# { lambda ; a:f32[10] b:f32[10]. let
#     c:f32[10] = sub a b
#     d:f32[10] = add a b
#     e:f32[10] = div c d
#   in (e,) }

f = jax.jit(func)
f(x, y)
# %timeit f(x, y)
# 2.87 µs ± 95 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Now with a custom Module

from jax import Array

class MyEqxArray(eqx.Module):
    array: Array

    def __add__(self, other):
        return jax.tree.map(jnp.add, self, other)

    def __sub__(self, other):
        return jax.tree.map(jnp.subtract, self, other)

    def __truediv__(self, other):
        return jax.tree.map(jnp.divide, self, other)


mx = MyEqxArray(x)
my = MyEqxArray(y)

func(mx, my)

print(jax.make_jaxpr(func)(mx, my))  # same jaxpr 👍 
# { lambda ; a:f32[10] b:f32[10]. let
#     c:f32[10] = sub a b
#     d:f32[10] = add a b
#     e:f32[10] = div c d
#   in (e,) }

f = jax.jit(func)
f(mx, my)
%timeit f(mx, my)
# 8.4 µs ± 969 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Now with a custom PyTree

from dataclasses import dataclass
from typing import Any

@jax.tree_util.register_pytree_node_class
@dataclass
class CustomArray:
    array: Array

    def __add__(self, other):
        return jax.tree.map(jnp.add, self, other)

    def __sub__(self, other):
        return jax.tree.map(jnp.subtract, self, other)

    def __truediv__(self, other):
        return jax.tree.map(jnp.divide, self, other)

    def tree_flatten(self) -> tuple[tuple[Any], Any]:
        return (self.array,), None

    @classmethod
    def tree_unflatten(cls, aux_data: Any, children: tuple[Any]) -> "CustomArray":
        return cls(*children)

cx = CustomArray(x)
cy = CustomArray(y)

f = jax.jit(func)
f(cx, cy)
# %timeit f(cx, cy)
# 4.02 µs ± 960 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

So the timings (in µs) are plain array : Module : custom pytree = 2.87 : 8.4 : 4.02.

Is there any way to speed up Module?

@lockwo
Contributor

lockwo commented Feb 26, 2025

Running this code in a fresh Colab environment (and with some block_until_readys), I see the plain array is fastest, but the custom pytree is actually slower than the Module:

24.5 µs ± 9.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
31.6 µs ± 5.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33.7 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

However, the overhead doesn't seem strictly additive. If I increase the array size to x = jnp.linspace(0.0, 1.0, 1_000_000):

838 µs ± 388 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.02 ms ± 506 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.38 ms ± 588 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

@nstarman
Author

I also repeated this with block_until_ready and x = jnp.linspace(0.0, 1.0, 10_000):

JAX: 8.09 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Equinox: 16.9 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
PyTree: 10.2 µs ± 2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

The Equinox overhead still seems large.

@nstarman
Author

nstarman commented Feb 26, 2025

Just spitballing here, but there seem to be three options:

  1. It's possible to speed up Equinox to match the speed of the custom PyTree 🎉. This is my preferred solution!
  2. It isn't possible, and this is a won't-fix 😢
  3. It isn't possible because of some of the fancy stuff in Module (e.g. locking/unlocking __init__), but there's interest in speed-ups. Then maybe a good solution would be to add an ABC — AbstractModule — and also vendor a faster bare-bones FastModule that skips the fancy slow stuff (except field(converter=...), which I've tested to be fast). I've thought along these lines in dataclassish, and a speed test with that custom dataclass + converter gave results identical to the PyTree case in this issue; a rough sketch of the idea is below. With an ABC, ecosystem tools that expect a Module could be trivially adapted to work with FastModule.
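For concreteness, a minimal sketch of the kind of bare-bones class I have in mind. (FastModule is a hypothetical name, not an existing Equinox class; the converter handling here is just one possible design.)

import dataclasses
from typing import Any

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class FastModule:  # hypothetical; not part of Equinox
    array: Any

    def __post_init__(self):
        # Keep the one cheap "fancy" feature: a converter, applied once at
        # construction rather than on every flatten/unflatten.
        object.__setattr__(self, "array", jnp.asarray(self.array))

    def tree_flatten(self):
        # Hardcoded: no iteration over fields, no static/dynamic sorting.
        return (self.array,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ (and the converter) on the unflatten path.
        obj = object.__new__(cls)
        object.__setattr__(obj, "array", children[0])
        return obj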

@patrick-kidger
Owner

I think what you're measuring here is an additive overhead of microseconds in the flattening and unflattening. This is expected/known, and rarely ever troublesome. Your computational work has to be essentially negligible for this to affect real-world results.
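One way to sanity-check this is to time the flatten/unflatten round-trip by itself, since that round-trip happens on every jitted call. A rough sketch, reusing mx (Module) and cx (custom pytree) from above:

# Isolate the flatten/unflatten cost itself, outside any computation.
m_leaves, m_treedef = jax.tree_util.tree_flatten(mx)
c_leaves, c_treedef = jax.tree_util.tree_flatten(cx)

%timeit jax.tree_util.tree_flatten(mx)
%timeit jax.tree_util.tree_unflatten(m_treedef, m_leaves)
%timeit jax.tree_util.tree_flatten(cx)
%timeit jax.tree_util.tree_unflatten(c_treedef, c_leaves)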

However, the timing doesn't seem strictly additive

I think you're measuring the noise in the actual operation itself here. Those standard deviations are pretty large, and overlap quite a lot!

@patrick-kidger patrick-kidger added the question User queries label Feb 26, 2025
@nstarman
Author

nstarman commented Feb 26, 2025

in the flattening and unflattening

Where does Module do that differently than the jax.tree_util.register_pytree_node_class PyTree in the example above?

This may be the important point! CustomArray is a PyTree just like MyEqxArray. Why is MyEqxArray slower / where is the flattening overhead?

Your computational work has to be essentially negligible for this to affect real-world results.

This appears to be non-negligible in quax-derived objects, where the overhead happens many times. At least, this is what I've found so far in trying to figure out why unxt and coordinax operations are slow. I purposefully haven't jitted any of the quaxify(jax.foo) in https://github.com/GalacticDynamics/quaxed nor the dunder methods in https://github.com/GalacticDynamics/quax-blocks. But when I use units, things are much slower:

import jax
import jax.numpy as jnp
import unxt as u

def convert_cart2d_to_polar(params, aux):
    x, y = params["x"], params["y"]
    r = jnp.sqrt(x**2 + y**2)
    theta = jnp.arctan2(y, x)
    return {"r": r, "theta": theta}, aux

params = {"x": u.Quantity(jnp.array([1.0, 2.0]), "m"), "y": u.Quantity(jnp.array([3.0, 4.0]), "m")}
aux = {}

jac, aux = jax.jacfwd(convert_cart2d_to_polar, has_aux=True)(params, aux)
jac 
# {'r': Quantity['length']({'x': Quantity['length'](Array([[0.31622777, 0.        ],
#         [0.        , 0.4472136 ]], dtype=float64), unit='m'), 'y': Quantity['length'](Array([[0.9486833 , 0.        ],
#         [0.        , 0.89442719]], dtype=float64), unit='m')}, unit='m'),
#  'theta': Quantity['angle']({'x': Quantity['length'](Array([[-0.3,  0. ],
#         [ 0. , -0.2]], dtype=float64), unit='m'), 'y': Quantity['length'](Array([[0.1, 0. ],
#         [0. , 0.1]], dtype=float64), unit='m')}, unit='rad')}
# NOTE: the weird nesting is something else I'm trying to correct,
#              but it appears to be because jax.jacfwd doesn't have an `is_leaf`

func = jax.jit(jax.jacfwd(convert_cart2d_to_polar, has_aux=True))

func(params, aux)
%timeit jax.block_until_ready(func(params, aux))
# 26.6 µs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# vs
params2 = {k: v.value for k, v in params.items()}
func(params2, aux)
%timeit jax.block_until_ready(func(params2, aux))
# 9.65 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

So it's 2.75x faster not to use a Quantity. Without jit, the gap grows to 8x.

@johannahaffner

in the flattening and unflattening

Where does Module do that differently than the above example jax.tree_util.register_pytree_node_class PyTree?
[…] where is the flattening overhead?

It handles different types of fields and includes some checks:

def _flatten_module(module: "Module", with_keys: bool):

def _unflatten_module(cls: type["Module"], aux: _FlattenedData, dynamic_field_values):
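A small sketch of one such difference: static fields must be separated into the aux data at flatten time, which CustomArray's hardcoded tree_flatten never has to consider.

# Static fields go into the treedef, not the leaves, so Equinox's flatten
# has to sort dynamic from static fields on every call.
class WithStatic(eqx.Module):
    array: Array
    name: str = eqx.field(static=True)

leaves, treedef = jax.tree_util.tree_flatten(WithStatic(x, "demo"))
print(leaves)  # only the dynamic field shows up as a leaf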

@johannahaffner

johannahaffner commented Feb 26, 2025

I took a look at your MWE - it looks like we're seeing the same additive overhead as above. If you have many small modules and do tiny operations on all of these, you might indeed be in the regime in which this starts to matter. Not sure if you can batch instead? But I'm unfamiliar with quax, so I can't really say anything about that!

The only thing I can add:

# NOTE: the weird nesting is something else I'm trying to correct,
#              but it appears to be because jax.jacfwd doesn't have an `is_leaf`

your Jacobian actually looks exactly as expected. You have two elements in your $r$ array and two angles, corresponding to two different points in Cartesian space. Both $r$ and $\theta$ have derivatives with respect to each of the elements of $x$ and $y$; these happen to be diagonal matrices because you take the Pythagorean sum element-wise.
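Writing the entries out explicitly (from the definitions of $r$ and $\theta$ above):

$$\frac{\partial r}{\partial x} = \frac{x}{r}, \qquad \frac{\partial r}{\partial y} = \frac{y}{r}, \qquad \frac{\partial \theta}{\partial x} = -\frac{y}{r^2}, \qquad \frac{\partial \theta}{\partial y} = \frac{x}{r^2}.$$

For the first point $(x, y) = (1, 3)$ we have $r = \sqrt{10}$, so $\partial r/\partial x = 1/\sqrt{10} \approx 0.316$ and $\partial \theta/\partial x = -3/10 = -0.3$, matching the printed values.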

To make this more readable and intuitive, you might want to try wl.pprint from the wadler_lindig library that Equinox now uses under the hood.
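For example (assuming the wadler_lindig package is installed):

import wadler_lindig as wl

wl.pprint(jac)  # pretty-prints the nested Quantity/dict structure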

@nstarman
Author

I took a look at your MWE - it looks like we're seeing the same additive overhead as above. If you have many small modules and do tiny operations on all of these, you might indeed be in the regime in which this starts to matter. Not sure if you can batch instead? But I'm unfamiliar with quax, so I can't really say anything about that!

Unfortunately batching isn't possible. Yes, tiny operations on many small modules appear to be exactly the situation with quax.

@patrick-kidger Is it possible in quax (or maybe via equinox) to provide a custom tree_flatten / tree_unflatten and have Module use that, overriding the default behavior?

The only thing I can add -

# NOTE: the weird nesting is something else I'm trying to correct,
#              but it appears to be because jax.jacfwd doesn't have an `is_leaf`

your Jacobian actually looks exactly as expected.

Yes, but note that the inner dicts are inside outer Quantity objects. This is because jacfwd can't stop at a certain leaf level, e.g. Quantity, instead of recursing down to the underlying arrays.
But that's a separate issue.

@patrick-kidger
Owner

@johannahaffner thank you for identifying that this is still the additive overhead!

@nstarman I think what'd probably be most desirable is if we can just speed up the existing flattening/unflattening implementation along the lines of whatever alternative you have in mind!

Taking unflattening as an example, pretty much the only difference between what we already have and something that just assigns attributes (self.foo = foo; self.bar = bar) is that the latter hardcodes the attributes rather than iterating over them.

If this is indeed the source of the overhead then it wouldn't be very hard to dynamically generate such a 'hardcoded' function for each new Module.
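For concreteness, a rough sketch of what that dynamic generation could look like (hypothetical, not Equinox's actual implementation; it's the same trick the stdlib dataclasses module uses to build __init__):

# Hypothetical sketch: generate, per class, an unflatten function whose
# attribute assignments are unrolled instead of looping over fields.
def make_hardcoded_unflatten(cls, field_names):
    assignments = "\n".join(
        f"    object.__setattr__(obj, {name!r}, values[{i}])"
        for i, name in enumerate(field_names)
    )
    src = (
        "def unflatten(aux_data, values):\n"
        "    obj = object.__new__(cls)\n"
        f"{assignments}\n"
        "    return obj\n"
    )
    namespace = {"cls": cls}
    exec(src, namespace)  # compiles the per-class function once
    return namespace["unflatten"]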
