
What is the difference between jax.lax.stop_gradient and eqx.static_field=True? #958

Open
timnotavailable opened this issue Feb 24, 2025 · 7 comments
Labels
question User queries

Comments

@timnotavailable

Hello,
I'm going to build an ODE simulator for the Bloch equation, and I would like to define a module which has some attributes of array type:

import equinox as eqx
from jaxtyping import Array, Float

class BlochModule(eqx.Module):
    array1: Float[Array, "x y z"]
    array2: Float[Array, "x y z"]

What I would like to do is: when it is initialized with a np.array, it is static and does not need a gradient; when it is initialized with a jnp.array, it is marked dynamic in JAX (requires a gradient). I have read through the whole documentation, and there are two ways of doing this: eqx.static_field = True or jax.lax.stop_gradient. Which one should I choose, and how should I wrap this into the initialization?

I've also checked some issues, where it is said that the eqx.static_field = True approach is dangerous. What should I do to achieve this?

@johannahaffner

johannahaffner commented Feb 24, 2025

I think what you want here is custom filtering, e.g. as done in this example (look for the filter_spec definition in the main function).

Would that work for your use case? It would allow you to control what ends up in the dynamic and static partition of your model, so you can compute gradients only with respect to the arrays that you want gradients for.
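For instance, a minimal sketch of that pattern (the Model class and loss here are hypothetical, not from the linked example):

import equinox as eqx
import jax
import jax.numpy as jnp

class Model(eqx.Module):
    trainable: jax.Array
    frozen: jax.Array

model = Model(jnp.ones(3), jnp.zeros(3))

# Start with every leaf marked non-differentiable, then opt specific leaves in.
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(lambda m: m.trainable, filter_spec, replace=True)

def loss(diff_model, static_model):
    model = eqx.combine(diff_model, static_model)
    return jnp.sum(model.trainable * model.frozen)

diff_model, static_model = eqx.partition(model, filter_spec)
grads = eqx.filter_grad(loss)(diff_model, static_model)  # gradients only for `trainable`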

@patrick-kidger
Owner

You want jax.lax.stop_gradient. This will block autodiff from operating through that variable.

eqx.field(static=True) does something totally different -- it marks a dataclass field as not being part of the pytree structure.
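Concretely, a minimal sketch of a static field (the WithStatic class is hypothetical):

import equinox as eqx
import jax
import jax.numpy as jnp

class WithStatic(eqx.Module):
    weight: jax.Array
    size: int = eqx.field(static=True)  # stored in the treedef, not as a leaf

m = WithStatic(jnp.ones(2), 2)
print(jax.tree_util.tree_leaves(m))  # only `weight` appears; `size` is structure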

As for how to use jax.lax.stop_gradient, usually something like this:

import equinox as eqx
import jax

class AddWithoutGradient(eqx.Module):
    x: jax.Array

    def __call__(self, y: jax.Array):
        x = jax.lax.stop_gradient(self.x)  # block autodiff through self.x
        return x + y

@patrick-kidger patrick-kidger added the question User queries label Feb 25, 2025
@johannahaffner

jax.lax.stop_gradient "gets lost" when the input to a module is wrapped, and it seems to require doing the wrapping directly inside the bound methods. This is why I didn't recommend it yesterday - although we might be able to fix it through the module metaclass?

Take this example, which I would naively expect to work:

import equinox as eqx
import jax
import jax.numpy as jnp
import wadler_lindig as wl

class ArrayContainer(eqx.Module):
    x: jax.Array

def parabolic_loss(tree):
    return jnp.sum(tree.x**2)

some_vector = 2.0 * jnp.ones(3)

a = ArrayContainer(some_vector)
wl.pprint(jax.grad(parabolic_loss)(a), short_arrays=False)  # Has gradient

b = ArrayContainer(jax.lax.stop_gradient(some_vector))
wl.pprint(jax.grad(parabolic_loss)(b), short_arrays=False)  # Also has gradient

@patrick-kidger
Owner

@johannahaffner -- so this is a misunderstanding of how jax.lax.stop_gradient works. It has to be called from within the region that has jax.grad applied to it. It doesn't control a property of the array it is called on -- it's a function that is called within a computation graph (and this function is the identity function with zero gradient).
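In other words, a minimal sketch of the intended usage (not from the original comment):

import jax
import jax.numpy as jnp

def loss(x):
    # stop_gradient is applied *inside* the function being differentiated.
    return jnp.sum(jax.lax.stop_gradient(x) ** 2)

x = 2.0 * jnp.ones(3)
print(jax.grad(loss)(x))  # [0. 0. 0.] -- the gradient is blocked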

@johannahaffner

Thank you for clearing that up!

FWIW, here are some collected related issues; these four are directly relevant:

#909, #710, #31 (including this comment), #214

@timnotavailable
Author

timnotavailable commented Feb 25, 2025

Hello, thank you for all the feedback and comments, @johannahaffner @patrick-kidger!
I've checked all the aforementioned issues and comments, and thought about how I should organize the software structure. What I'm currently building is a very complicated simulation of some complicated physics. To keep the software manageable, I plan to make variables of jnp.array type have gradients, while other types (numpy, Python built-in types) are non-trainable.

My naive ideas after reviewing all the posts:

Option 1: Using filter_spec.
The idea is: when designing the software, mark all of the fixed structure (str and similar datatypes) as static fields (they will never need gradients!), so that static arguments give a slight performance boost. The attributes that might need gradients I annotate as Union[Float[Array, ""], float] (for a scalar) or Array. Scalars are easy to handle because Equinox treats Python built-in types as static. For Array types, my solution using filter_spec is to first mark all leaves as non-trainable, then mark all the jnp.array-typed attributes as trainable, as in the sketch below.
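A minimal sketch of that split (the Params class is hypothetical): eqx.is_inexact_array is True for floating-point JAX arrays but False for NumPy arrays and Python scalars, so partitioning on it directly yields "jnp.array trainable, everything else non-trainable" without marking leaves one by one:

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

class Params(eqx.Module):
    trainable: jax.Array
    fixed: np.ndarray

model = Params(jnp.ones(3), np.zeros(3))

# `diff` keeps the jnp array; `static` keeps the NumPy array.
diff, static = eqx.partition(model, eqx.is_inexact_array)
grads = eqx.filter_grad(lambda d, s: jnp.sum(eqx.combine(d, s).trainable ** 2))(diff, static)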

Option 2: Make the attribute a @property or use is_leaf, as mentioned in #31, like requires_grad=False in PyTorch.

However, I'm not sure whether using a @property would cause some overhead in Diffrax; as far as I know, a @property is not JIT-compiled.

Not sure whether my understanding is correct, but thanks for all the replies again!

@patrick-kidger
Owner

I think option 1 sounds like a good choice! Although FWIW you probably won't need to use static fields at all -- using eqx.filter_{jit, grad, ...} is generally a better choice. (Static fields are really an advanced feature that I try to have people avoid using unless they're familiar with the details of e.g. how JIT caching interacts with pytree semantics.)

Note that JAX totally will compile properties though. JAX uses a tracing compiler, see point 2 here.
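As an illustration, a minimal sketch (the Scaled class is hypothetical):

import equinox as eqx
import jax
import jax.numpy as jnp

class Scaled(eqx.Module):
    x: jax.Array

    @property
    def doubled(self):
        return 2.0 * self.x

@eqx.filter_jit
def f(model):
    # The property access is traced like any other Python code,
    # so it is compiled along with the rest of the function.
    return jnp.sum(model.doubled)

print(f(Scaled(jnp.ones(3))))  # 6.0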
