Skip to content

Commit

Permalink
Add DataClassModule
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Nov 30, 2023
1 parent 32a23f2 commit eec2db3
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
52 changes: 50 additions & 2 deletions tjax/_src/dataclasses/flax.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,60 @@
from __future__ import annotations

from typing import Any
from dataclasses import KW_ONLY, InitVar, dataclass
from typing import Any, Self, TypeAlias

from flax.experimental import nnx
from jax.tree_util import register_pytree_node
from typing_extensions import dataclass_transform, override

from .helpers import field

__all__ = ['module_field']
__all__ = ['module_field', 'DataClassModule']


def module_field(*, init: bool = False) -> Any:
"""A field that contains submodules."""
return field(init=init, default=None, kw_only=True) # pylint: disable=invalid-field-call


Children: TypeAlias = tuple[nnx.State, nnx.GraphDef[Any]]


@dataclass_transform(field_specifiers=(module_field,
nnx.field,
nnx.treenode_field,
nnx.variable_field,
nnx.param_field))
class _DataClassModule(nnx.Module):
@override
def __init_subclass__(cls,
*,
init: bool = True,
repr: bool = True, # noqa: A002
eq: bool = True,
order: bool = False,
kw_only: bool = False,
**kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
dataclass(init=init, repr=repr, eq=eq, order=order, kw_only=kw_only)(cls)

def flatten_func(x: Self, /) -> tuple[Children, None]:
state, graph = x.split()
return ((state, graph), None)

def unflatten_func(aux_data: None, children: Children, /) -> Self:
assert aux_data is None
state, graph = children
assert isinstance(state, nnx.State)
assert isinstance(graph, nnx.GraphDef)
return graph.merge(state)

register_pytree_node(cls, flatten_func, unflatten_func)


class DataClassModule(_DataClassModule):
_: KW_ONLY
rngs: InitVar[nnx.Rngs] = nnx.field()

def __post_init__(self, rngs: nnx.Rngs) -> None:
pass
4 changes: 2 additions & 2 deletions tjax/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._src.dataclasses.dataclass import DataclassInstance, TDataclassInstance, dataclass
from ._src.dataclasses.flax import module_field
from ._src.dataclasses.flax import DataClassModule, module_field
from ._src.dataclasses.helpers import as_shallow_dict, field

__all__ = ['dataclass', 'DataclassInstance', 'TDataclassInstance', 'as_shallow_dict', 'field',
'module_field']
'module_field', 'DataClassModule']

0 comments on commit eec2db3

Please sign in to comment.