What is the difference between jax.lax.stop_gradient and eqx.static_field=True ? #958
Comments
I think what you want here is custom filtering, e.g. as done in this example (look for the Would that work for your use case? It would allow you to control what ends up in the dynamic and static partition of your model, so you can compute gradients only with respect to the arrays that you want gradients for. |
You want
As for how to use it:
import equinox as eqx
import jax

class AddWithoutGradient(eqx.Module):
    x: jax.Array
    def __call__(self, y: jax.Array):
        # Stop gradients from flowing into the stored array.
        x = jax.lax.stop_gradient(self.x)
        return x + y
Take this example, which I would naively expect to work:
import equinox as eqx
import jax
import jax.numpy as jnp
import wadler_lindig as wl

class ArrayContainer(eqx.Module):
    x: jax.Array

def parabolic_loss(tree):
    return jnp.sum(tree.x**2)

some_vector = 2.0 * jax.numpy.ones(3)
a = ArrayContainer(some_vector)
wl.pprint(jax.grad(parabolic_loss)(a), short_arrays=False)  # Has gradient
b = ArrayContainer(jax.lax.stop_gradient(some_vector))
wl.pprint(jax.grad(parabolic_loss)(b), short_arrays=False)  # Also has gradient
@johannahaffner -- so this is a misunderstanding of how |
Thank you for clearing that up! FWIW, here are some related issues. These four are directly relevant: #909, #710, #31 (including this comment), and #214.
Hello, thank you for all the feedback and comments, @johannahaffner @patrick-kidger! My naive ideas after reviewing all the posts:
However, I'm not sure using Not sure whether my understanding is correct, but thanks again for all the replies!
I think option 1 sounds like a good choice! Although FWIW you probably won't need to use static fields at all -- using Note that JAX totally will compile properties though. JAX uses a tracing compiler, see point 2 here. |
Hello,
I'm going to build an ODE simulator for the Bloch equations, and I would like to define a module that has some array-typed attributes:
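import equinox as eqx
from jaxtyping import Array, Shaped

class BlochModule(eqx.Module):  # class name missing in the original post; placeholder
    array1: Shaped[Array, "x y z"]
    array2: Shaped[Array, "x y z"]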
What I would like to do is this: when I initialize a field with a np.array, it should be static and need no gradient; when I initialize it with a jnp.array, it should be marked as dynamic in JAX (requires gradient). I have read the whole documentation, and there seem to be two ways of doing this: eqx.static_field = True or jax.lax.stop_gradient. Which one should I choose, and how should I wrap it into the initialization?
I've also checked some issues, where it is said that using eqx.static_field = True this way is dangerous. What should I do to achieve this?