-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
32a23f2
commit eec2db3
Showing
2 changed files
with
52 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,60 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
from dataclasses import KW_ONLY, InitVar, dataclass | ||
from typing import Any, Self, TypeAlias | ||
|
||
from flax.experimental import nnx | ||
from jax.tree_util import register_pytree_node | ||
from typing_extensions import dataclass_transform, override | ||
|
||
from .helpers import field | ||
|
||
__all__ = ['module_field'] | ||
__all__ = ['module_field', 'DataClassModule'] | ||
|
||
|
||
def module_field(*, init: bool = False) -> Any: | ||
"""A field that contains submodules.""" | ||
return field(init=init, default=None, kw_only=True) # pylint: disable=invalid-field-call | ||
|
||
|
||
Children: TypeAlias = tuple[nnx.State, nnx.GraphDef[Any]] | ||
|
||
|
||
@dataclass_transform(field_specifiers=(module_field, | ||
nnx.field, | ||
nnx.treenode_field, | ||
nnx.variable_field, | ||
nnx.param_field)) | ||
class _DataClassModule(nnx.Module): | ||
@override | ||
def __init_subclass__(cls, | ||
*, | ||
init: bool = True, | ||
repr: bool = True, # noqa: A002 | ||
eq: bool = True, | ||
order: bool = False, | ||
kw_only: bool = False, | ||
**kwargs: Any) -> None: | ||
super().__init_subclass__(**kwargs) | ||
dataclass(init=init, repr=repr, eq=eq, order=order, kw_only=kw_only)(cls) | ||
|
||
def flatten_func(x: Self, /) -> tuple[Children, None]: | ||
state, graph = x.split() | ||
return ((state, graph), None) | ||
|
||
def unflatten_func(aux_data: None, children: Children, /) -> Self: | ||
assert aux_data is None | ||
state, graph = children | ||
assert isinstance(state, nnx.State) | ||
assert isinstance(graph, nnx.GraphDef) | ||
return graph.merge(state) | ||
|
||
register_pytree_node(cls, flatten_func, unflatten_func) | ||
|
||
|
||
class DataClassModule(_DataClassModule): | ||
_: KW_ONLY | ||
rngs: InitVar[nnx.Rngs] = nnx.field() | ||
|
||
def __post_init__(self, rngs: nnx.Rngs) -> None: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from ._src.dataclasses.dataclass import DataclassInstance, TDataclassInstance, dataclass | ||
from ._src.dataclasses.flax import module_field | ||
from ._src.dataclasses.flax import DataClassModule, module_field | ||
from ._src.dataclasses.helpers import as_shallow_dict, field | ||
|
||
__all__ = ['dataclass', 'DataclassInstance', 'TDataclassInstance', 'as_shallow_dict', 'field', | ||
'module_field'] | ||
'module_field', 'DataClassModule'] |