WeightNorm causes unexpected PyTrees inequality #965

Open
yuanz271 opened this issue Mar 1, 2025 · 8 comments
Labels
feature New feature

Comments

@yuanz271

yuanz271 commented Mar 1, 2025

The following assertion fails.

import equinox as eqx
import jax.random as jrandom

layer = eqx.nn.Linear(2, 2, key=jrandom.key(0))
m1 = eqx.nn.WeightNorm(layer)
m2 = eqx.nn.WeightNorm(layer)
assert eqx.tree_equal(m1, m2)  # unequal

The assertion fails because WeightNorm._norm is different for every WeightNorm instance (== returns False when comparing them).

This also makes tree_equal fail when it is used to check deserialized models against the originals.
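The root cause can be reproduced without Equinox. This is a minimal sketch that assumes, as a later comment in this thread confirms, that _norm is built with functools.partial inside WeightNorm.__init__. It also assumes that tree_equal falls back on == for non-array leaves such as this partial. functools.partial defines no __eq__, so two partials that wrap the same function with the same keywords compare equal only if they are the same object. stand_in_norm is a hypothetical placeholder for jnp.linalg.norm.

```python
import functools as ft


def stand_in_norm(x, axis=None, keepdims=False):
    # Hypothetical stand-in for jnp.linalg.norm; only its identity matters here.
    return x


# Each WeightNorm.__init__ call is assumed to build a fresh partial like these.
p1 = ft.partial(stand_in_norm, axis=0, keepdims=True)
p2 = ft.partial(stand_in_norm, axis=0, keepdims=True)

print(p1.func is p2.func)          # True: same underlying function
print(p1.keywords == p2.keywords)  # True: same bound keyword arguments
print(p1 == p2)                    # False: partial has no __eq__, so identity is used
```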

@yuanz271 changed the title from "Two WeightNorm causes unexpected PyTrees inequality" to "WeightNorm causes unexpected PyTrees inequality" on Mar 1, 2025
@johannahaffner

This is expected. The values of the leaves are the same, but you now have two new trees, that is, two objects that JAX treats separately. Static leaves (such as bound methods and their jaxprs) generally do not permit equality checks, and layer does have a bound method. If you partition the layers into their dynamic and static components, you will see that the equality check passes for the weights.

You can verify this:

import equinox as eqx
import jax.random as jr


layer = eqx.nn.Linear(2, 2, key=jr.key(0))
m1 = eqx.nn.WeightNorm(layer)
m2 = eqx.nn.WeightNorm(layer)
d1, s1 = eqx.partition(m1, eqx.is_array)
d2, s2 = eqx.partition(m2, eqx.is_array)

assert eqx.tree_equal(d1, d2)  # Dynamic components (arrays) are equal
assert not eqx.tree_equal(s1, s2)  # Static components are not equal

@yuanz271
Author

yuanz271 commented Mar 1, 2025

@johannahaffner thank you.

However, what's the motivation for declaring _norm as a field? The other (bound) instance methods are not compared.

@johannahaffner

This is not specific to equinox. It is a general feature of how Python compares objects for equality. Callables such as functions and methods do not support value-based equality checks, which would require the object to implement an __eq__ method. For these, Python instead checks whether they are the exact same instance.

That means equality checks on callables only pass if both sides point at the exact same object, which is a rare special case in practice. Because equinox modules are immutable, we get a new layer back. In the second example below, we create the callable f afresh each time, and you can see that the equality check fails even though both copies close over the same value.

def f(x):
    pass

g = f
h = f

assert g == h  # g and h point to the same object

def make_f(x):
    def f():
        return x
    return f

x = 42
g = make_f(x)
h = make_f(x)

assert g != h  # g and h point to different objects

The main takeaway is that you won't ever care about having your methods be the exact same thing, but you will care about having your leaves be the exact same thing! And you can check that by partitioning, and then asserting equality on the arrays.
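For reference, here is a pure-Python illustration of these default rules. One nuance is worth knowing: bound methods are not compared purely by identity. Two bound-method objects compare equal when they wrap the same function and are bound to the same instance. Layer and f are illustrative names only.

```python
import functools as ft


class Layer:
    def forward(self, x):
        return x


def f(x):
    return x


layer = Layer()

# Plain functions: == falls back to identity, so a function equals only itself.
function_equal = f == f                               # True

# Each attribute access creates a new bound-method object...
bound_is = layer.forward is layer.forward             # False
# ...but bound methods compare equal if they wrap the same function
# and are bound to the same instance.
bound_equal = layer.forward == layer.forward          # True
bound_other = layer.forward == Layer().forward        # False: different __self__

# functools.partial defines no __eq__, so equal contents are not enough.
partial_equal = ft.partial(f, 1) == ft.partial(f, 1)  # False
```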

@yuanz271
Author

yuanz271 commented Mar 2, 2025

Maybe I didn't ask clearly.
My point is that if you define _norm as an instance method instead of a Module field, it will not be compared (just as __call__ of a Module is not). So why implement it this particular way?
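The distinction being asked about can be made concrete. Equinox modules are dataclasses, and their equality looks at fields, not methods. This minimal sketch uses plain frozen dataclasses as a stand-in for eqx.Module, so the comparison below is the dataclass-generated __eq__ rather than tree_equal. NormAsField, NormAsMethod and stand_in_norm are hypothetical names.

```python
import dataclasses
import functools as ft


def stand_in_norm(x, axis=None):
    # Hypothetical stand-in for jnp.linalg.norm.
    return x


@dataclasses.dataclass(frozen=True)
class NormAsField:
    axis: int
    # A field: the generated __eq__ compares it along with `axis`.
    _norm: object = dataclasses.field(init=False, repr=False)

    def __post_init__(self):
        # Frozen dataclasses need object.__setattr__ to set fields here.
        object.__setattr__(self, "_norm", ft.partial(stand_in_norm, axis=self.axis))


@dataclasses.dataclass(frozen=True)
class NormAsMethod:
    axis: int

    # A method is not a field, so the generated __eq__ never looks at it.
    def _norm(self, x):
        return stand_in_norm(x, axis=self.axis)


print(NormAsField(0) == NormAsField(0))    # False: two distinct partials
print(NormAsMethod(0) == NormAsMethod(0))  # True: only `axis` is compared
```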

@johannahaffner

Ah! That is what you mean. This is consistent with how things are done elsewhere in the library. More complicated modules such as MultiheadAttention actually have several fields that are callable, e.g. here:

query_proj: Linear

Comparing any callable for equality is just not what is expected in Python, and I don't think that there is a good reason to push things into methods in order to be able to do so.

@yuanz271
Author

yuanz271 commented Mar 2, 2025

That I understand, because query_proj is a Module, so it is handled recursively. However, _norm is meant to be static and stateless inside WeightNorm. Since tree_equal compares arrays by value, its semantics are closer to distinguishing functions by what they compute than by reference, at least for a Module (Module overloads __eq__ using tree_equal, IIRC).

@johannahaffner

Right. I can't speak for @patrick-kidger, but elsewhere in Equinox we frequently use fields to specify a norm, e.g. here. The reason is that such a design allows the norm to be made public as an optional input argument if requested, and users might then want a norm other than the Frobenius norm, which jnp.linalg.norm defaults to. If this were implemented as an instance method, exposing it would require a breaking change. As a field, it only requires adding a keyword argument to WeightNorm.__init__, which can default to the current behaviour and would therefore not be breaking.
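The non-breaking extension described above can be sketched in plain Python. WeightNormSketch, its norm keyword, and the two norm functions are all hypothetical, and tuples of floats stand in for JAX arrays. Because the default is one module-level function rather than a freshly built partial, two default-constructed instances also compare equal.

```python
import dataclasses
import math
from typing import Callable, Sequence


def frobenius_norm(values: Sequence[float]) -> float:
    # Stand-in for the current default (the norm jnp.linalg.norm computes).
    return math.sqrt(sum(v * v for v in values))


def max_norm(values: Sequence[float]) -> float:
    # An alternative norm a user might want to plug in.
    return max(abs(v) for v in values)


@dataclasses.dataclass(frozen=True)
class WeightNormSketch:
    weight: tuple
    # Hypothetical new keyword: defaulting to the old behaviour keeps
    # existing calls like WeightNormSketch(weight) working unchanged.
    norm: Callable[[Sequence[float]], float] = frobenius_norm

    def normalized(self) -> tuple:
        scale = self.norm(self.weight)
        return tuple(w / scale for w in self.weight)


old_style = WeightNormSketch((3.0, 4.0))                 # existing call site
new_style = WeightNormSketch((3.0, 4.0), norm=max_norm)  # opt-in override

print(old_style.normalized())  # (0.6, 0.8)
print(new_style.normalized())  # (0.75, 1.0)
```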

@patrick-kidger
Owner

Okay, so! I think we can change this.

First of all, norm is assigned here:

self._norm = ft.partial(

and I assume the fact that it is dynamically creating new partials means that these are not comparing equal to each other later.

The fix is probably to cache based on axis, which is the only thing that changes. Or we could switch ft.partial to eqx.Partial, which I think has better equality semantics.
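Both suggested fixes can be sketched in pure Python. cached_norm, EqPartial and stand_in_norm are hypothetical names. EqPartial only imitates the value-based equality this comment attributes to eqx.Partial and is not Equinox's implementation. Caching makes instances with the same axis share one partial, so identity equality is enough. A value-equal partial instead makes independently created partials compare equal.

```python
import functools as ft


def stand_in_norm(x, axis=None, keepdims=False):
    # Hypothetical stand-in for jnp.linalg.norm.
    return x


# Option 1: cache on `axis`, the only thing that varies. `axis` must be hashable
# (an int, a tuple of ints, or None); equal axes then share one partial.
@ft.lru_cache(maxsize=None)
def cached_norm(axis):
    return ft.partial(stand_in_norm, axis=axis, keepdims=True)


# Option 2: a partial with value-based equality (hypothetical class).
class EqPartial(ft.partial):
    def __eq__(self, other):
        if not isinstance(other, EqPartial):
            return NotImplemented
        return (self.func, self.args, self.keywords) == (
            other.func,
            other.args,
            other.keywords,
        )

    def __hash__(self):
        # Keyword values must be hashable for this to work.
        return hash((self.func, self.args, tuple(sorted(self.keywords.items()))))


print(cached_norm(0) is cached_norm(0))  # True: served from the cache
print(cached_norm(0) == cached_norm(1))  # False: different axis, different object
print(EqPartial(stand_in_norm, axis=0) == EqPartial(stand_in_norm, axis=0))  # True
```

Caching keeps a plain ft.partial but relies on every instance going through the same cache. A value-equal partial also covers partials created elsewhere, for example when the field is patched by hand.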

As for why this is a field, and indeed one that isn't marked static (as opposed to a method or a static field): this is because there have been a few use-cases in which people would like to dynamically patch this field using eqx.tree_at. It's unusual but valid!

I'd be happy to take a PR on the above! I think equality here is a reasonable thing to want :)

@patrick-kidger added the feature (New feature) label on Mar 4, 2025