
Converter only applied when __post_init__ is defined for subclasses #969

Open
gorold opened this issue Mar 9, 2025 · 0 comments

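I want to use the following inheritance structure, in which subclasses only need to define a method that initialises the static fields. However, I ran into unintuitive behavior: the converter is only applied when the subclass defines its own `__post_init__`. In the example below, `A.__post_init__` sees the raw `int` for `B`, which does not override `__post_init__`, but a converted array for `C`, which overrides it only to call `super().__post_init__()`.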

Reproduce

```python
import abc

import equinox as eqx
import jax
import jax.numpy as jnp


class A(eqx.Module):
    a: jax.Array = eqx.field(converter=jnp.asarray)
    b: int = eqx.field(init=False, static=True)

    def __post_init__(self):
        # do something here for all subclasses
        print(type(self.a))
        self.init_b()

    @abc.abstractmethod
    def init_b(self): ...


class B(A):
    # Inherits A.__post_init__ without overriding it.
    def init_b(self):
        self.b = 1


class C(A):
    # Overrides __post_init__, but only to delegate to A.__post_init__.
    def __post_init__(self):
        super().__post_init__()

    def init_b(self):
        self.b = 1


b = B(0)
c = C(0)
```

Output

```
<class 'int'>
<class 'jaxlib.xla_extension.ArrayImpl'>
```

Expected Output

```
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
```
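For comparison, here is a minimal sketch of the ordering the expected output implies, written with the standard library's `dataclasses` rather than equinox. Assumptions: `float` stands in for `jnp.asarray`; because stdlib dataclasses have no `converter=` option, the conversion is written by hand at the top of the shared base `__post_init__`; and `seen_type` is a hypothetical attribute added only to record what `__post_init__` observed. Since the conversion always runs before `init_b`, `B` and `C` see the same type regardless of whether they override `__post_init__`.

```python
import abc
import dataclasses
from typing import Any


@dataclasses.dataclass
class A(abc.ABC):
    a: Any
    b: int = dataclasses.field(init=False)

    def __post_init__(self):
        # Hand-written stand-in for a field converter (assumption: stdlib
        # dataclasses have no `converter=`), applied before any other
        # initialisation logic. float plays the role of jnp.asarray here.
        self.a = float(self.a)
        # Hypothetical attribute: records the type that __post_init__ observed.
        self.seen_type = type(self.a)
        self.init_b()

    @abc.abstractmethod
    def init_b(self) -> None: ...


class B(A):
    # Does not override __post_init__.
    def init_b(self) -> None:
        self.b = 1


class C(A):
    # Overrides __post_init__ only to delegate to the base class.
    def __post_init__(self):
        super().__post_init__()

    def init_b(self) -> None:
        self.b = 1


b = B(0)
c = C(0)
print(b.seen_type, c.seen_type)  # <class 'float'> <class 'float'>
```

Both instances report the converted type, mirroring the expected output above.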