WeightNorm causes unexpected PyTrees inequality #965
This is expected. The values of the leaves are the same, but you now have two new trees, i.e. two distinct objects that JAX treats separately. Static leaves (such as bound methods and their jaxprs) generally do not permit equality checks, and the layer does hold a callable. If you partition the layers into their dynamic and static components, you will see that the equality check passes for the weights. You can verify this:

```python
import equinox as eqx
import jax.random as jr

layer = eqx.nn.Linear(2, 2, key=jr.key(0))
m1 = eqx.nn.WeightNorm(layer)
m2 = eqx.nn.WeightNorm(layer)

d1, s1 = eqx.partition(m1, eqx.is_array)
d2, s2 = eqx.partition(m2, eqx.is_array)
assert eqx.tree_equal(d1, d2)  # Dynamic components (arrays) are equal
assert not eqx.tree_equal(s1, s2)  # Static components are not equal
```
@johannahaffner Thank you.
However, what's the motivation to declare …
This is not specific to Equinox; it is a general feature of how Python compares objects for equality. Callables, such as functions and methods, do not support value-based equality checks, which would require implementing an `__eq__` method that compares them by content; by default, `==` falls back to identity. That means that equality checks for callables only pass if both sides point at the exact same object, which is a rare special case in practice. Because Equinox modules are immutable, we get a new layer back. In the second example below, we actually create the callable twice, once per call to `make_f`:

```python
def f(x):
    pass

g = f
h = f
assert g == h  # g and h point to the same object


def make_f(x):
    def f():
        return x
    return f

x = 42
g = make_f(x)
h = make_f(x)
assert g != h  # g and h point to different objects
```

The main takeaway is that you won't ever care about your methods being the exact same object, but you will care about your leaves being the same! You can check that by partitioning and then asserting equality on the arrays.
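The same identity rule applies to `functools.partial` objects, which, according to a later comment in this thread, is what `WeightNorm` creates for its norm. A small pure-Python sketch; `scaled_norm` is a made-up stand-in, not the Equinox function:

```python
import functools as ft
import math


def scaled_norm(xs, scale):
    # Euclidean norm of the values in `xs`, multiplied by `scale`.
    return scale * math.sqrt(sum(x * x for x in xs))


# Two partials built independently from identical ingredients.
p1 = ft.partial(scaled_norm, scale=2.0)
p2 = ft.partial(scaled_norm, scale=2.0)

# They wrap the same function with the same arguments...
same_config = p1.func is p2.func and p1.args == p2.args and p1.keywords == p2.keywords

# ...and compute the same values...
print(p1([3.0, 4.0]), p2([3.0, 4.0]))  # 10.0 10.0

# ...but functools.partial defines no value-based __eq__, so == checks identity.
print(p1 == p2)  # False
print(p1 == p1)  # True
```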
Maybe I didn't ask clearly.
Ah! That is what you mean. This is consistent with how things are done elsewhere in the library; see, for example, more complicated modules such as equinox/equinox/nn/_attention.py (line 122 at commit 8191b11).
Comparing any callable for equality is just not what is expected in Python, and I don't think there is a good reason to push things into methods in order to be able to do so.
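Equinox modules are dataclasses, so a plain frozen dataclass can sketch the trade-off under discussion: a callable stored as a field takes part in `==`, whereas a method lives on the class and does not. `FieldNorm` and `MethodNorm` are hypothetical illustrations, not Equinox code, and dataclass `==` is only an analogue of comparing PyTrees with `eqx.tree_equal`.

```python
import dataclasses
import functools as ft
import math
from typing import Callable


def scaled_norm(xs, scale):
    # Euclidean norm of the values in `xs`, multiplied by `scale`.
    return scale * math.sqrt(sum(x * x for x in xs))


@dataclasses.dataclass(frozen=True)
class FieldNorm:
    scale: float
    norm: Callable = dataclasses.field(init=False)

    def __post_init__(self):
        # A fresh partial per instance, stored as a (compared) dataclass field.
        object.__setattr__(self, "norm", ft.partial(scaled_norm, scale=self.scale))


@dataclasses.dataclass(frozen=True)
class MethodNorm:
    scale: float

    def norm(self, xs):
        # Defined once on the class, so it is not one of the instance's fields.
        return scaled_norm(xs, self.scale)


print(FieldNorm(2.0) == FieldNorm(2.0))  # False: the stored partials differ
print(MethodNorm(2.0) == MethodNorm(2.0))  # True: only `scale` is compared
```

The method-based version compares equal, but it also fixes the norm at class-definition time, which is the flexibility the field-based design keeps.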
That I understand, because …
Right. I can't speak for @patrick-kidger, but elsewhere in the Equinox ecosystem we do make frequent use of fields to specify a norm (e.g. here). The reason is that such a design makes it possible to expose the norm as an optional input argument if requested, and users might then like to use a norm other than the Frobenius norm, which …
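A sketch of that design, assuming the default norm is a module-level function: because every instance shares that one function object, equality keeps working, while callers can still pass a different norm. `Normalizer`, `frobenius_norm`, and `max_norm` are hypothetical names, and a plain frozen dataclass stands in for an Equinox module.

```python
import dataclasses
import math
from typing import Callable, Sequence


def frobenius_norm(xs: Sequence[float]) -> float:
    # Frobenius norm of a matrix whose entries are given as a flat sequence.
    return math.sqrt(sum(x * x for x in xs))


def max_norm(xs: Sequence[float]) -> float:
    # Largest absolute entry.
    return max(abs(x) for x in xs)


@dataclasses.dataclass(frozen=True)
class Normalizer:
    # Public, optional argument; the default is one shared function object.
    norm: Callable[[Sequence[float]], float] = frobenius_norm

    def __call__(self, xs: Sequence[float]) -> list:
        n = self.norm(xs)
        return [x / n for x in xs]


print(Normalizer() == Normalizer())  # True: both hold the same function object
print(Normalizer()([3.0, 4.0]))  # [0.6, 0.8]
print(Normalizer(max_norm)([3.0, 4.0]))  # [0.75, 1.0]
```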
Okay, so! I think we can change this. First of all, the norm is assigned here: equinox/equinox/nn/_weight_norm.py (line 91 at commit 8191b11), and I assume that because it dynamically creates new partials, these do not compare equal to each other later. The fix is probably to cache based on …

As for why this is a field, and indeed one that isn't marked static (as opposed to a method or a static field): this is because there have been a few use-cases in which people would like to dynamically patch this field using …

I'd be happy to take a PR on the above! I think equality here is a reasonable thing to want :)
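A minimal sketch of the caching idea, assuming the cache is keyed on the `axis` argument (the actual key is cut off in the comment as captured here). `make_norm`, `_norm_along`, `UncachedModule`, and `CachedModule` are hypothetical names, and plain classes stand in for `WeightNorm`. Because the cached factory hands every instance the same partial, identity-based `==` succeeds.

```python
import functools as ft
import math


def _norm_along(matrix, axis):
    # Euclidean norm of each column (axis=0) or each row (axis=1)
    # of a list-of-lists matrix.
    if axis == 0:
        matrix = list(zip(*matrix))
    return [math.sqrt(sum(v * v for v in vec)) for vec in matrix]


@ft.lru_cache(maxsize=None)
def make_norm(axis):
    # Build the partial once per distinct `axis`; later calls return that same object.
    return ft.partial(_norm_along, axis=axis)


class UncachedModule:
    def __init__(self, axis):
        # A new partial for every instance.
        self.norm = ft.partial(_norm_along, axis=axis)


class CachedModule:
    def __init__(self, axis):
        # The shared, cached partial for this axis.
        self.norm = make_norm(axis)


print(UncachedModule(0).norm == UncachedModule(0).norm)  # False
print(CachedModule(0).norm == CachedModule(0).norm)  # True
print(CachedModule(0).norm == CachedModule(1).norm)  # False: different axis
```

The cached partial is still an ordinary, non-static value, so patching the field on a particular instance remains possible.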
The following assertion fails.

The reason for the failure is that `WeightNorm._norm` is different (the `==` operation returns `False`) for every `WeightNorm` instance. This also causes `tree_equal` to fail when used to check deserialized models.