From 09ca03e7b4cc85f7828bc66a96e14bd5db0c26e4 Mon Sep 17 00:00:00 2001
From: Reese Wang
Date: Tue, 10 Dec 2024 05:57:24 +0000
Subject: [PATCH] Add dataclass decorator for praxis layers to make lint happy

Signed-off-by: Reese Wang
---
 transformer_engine/jax/praxis/module.py      | 8 +++++++-
 transformer_engine/jax/praxis/transformer.py | 6 +++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py
index 216fc743cc..2d98831a45 100644
--- a/transformer_engine/jax/praxis/module.py
+++ b/transformer_engine/jax/praxis/module.py
@@ -4,7 +4,7 @@
 """
 Praxis Modules
 """
-from dataclasses import field
+from dataclasses import dataclass, field
 from functools import partial
 from typing import Callable, Iterable, Sequence, Tuple, Union

@@ -28,6 +28,7 @@ def _generate_ln_scale_init(scale_init):
     return scale_init


+@dataclass
 class TransformerEngineBaseLayer(BaseLayer):
     """TransformerEngineBaseLayer"""

@@ -67,6 +68,7 @@ def create_layer(self, name, flax_module_cls):
         self.create_child(name, flax_module_p.clone())


+@dataclass
 class LayerNorm(TransformerEngineBaseLayer):
     """LayerNorm"""

@@ -103,6 +105,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.layer_norm(x)


+@dataclass
 class FusedSoftmax(TransformerEngineBaseLayer):
     """FusedSoftmax"""

@@ -124,6 +127,7 @@ def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JT
         return self.fused_softmax(x, mask, bias)


+@dataclass
 class Linear(TransformerEngineBaseLayer):
     """Linear"""

@@ -165,6 +169,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.linear(x)


+@dataclass
 class LayerNormLinear(TransformerEngineBaseLayer):
     """LayerNormLinear"""

@@ -228,6 +233,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.ln_linear(x)


+@dataclass
 class LayerNormMLP(TransformerEngineBaseLayer):
     """LayerNormMLP"""

diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py
index 088b79d183..a5b6defb75 100644
--- a/transformer_engine/jax/praxis/transformer.py
+++ b/transformer_engine/jax/praxis/transformer.py
@@ -4,7 +4,7 @@
 """
 Praxis Modules related Transformer
 """
-from dataclasses import field
+from dataclasses import dataclass, field
 from functools import partial
 from typing import Optional, Sequence, Tuple
 import warnings
@@ -22,6 +22,7 @@
 from ..attention import AttnBiasType, AttnMaskType


+@dataclass
 class RelativePositionBiases(TransformerEngineBaseLayer):
     """RelativePositionBiases"""

@@ -67,6 +68,7 @@ def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = T
         return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)


+@dataclass
 class DotProductAttention(TransformerEngineBaseLayer):
     """DotProductAttention"""

@@ -125,6 +127,7 @@ def __call__(
         )


+@dataclass
 class MultiHeadAttention(TransformerEngineBaseLayer):
     """MultiHeadAttention"""

@@ -258,6 +261,7 @@ def __call__(
         )


+@dataclass
 class TransformerLayer(TransformerEngineBaseLayer):
     """TransformerLayer"""
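
For context, a minimal sketch of the kind of lint complaint this patch addresses. It assumes the warning stems from dataclasses.field() being used in a class that static analysis does not recognize as a dataclass (e.g. pylint's invalid-field-call check); ExampleLayer below is a hypothetical stand-in, not one of the Praxis layers touched by the patch:

    from dataclasses import dataclass, field
    from typing import Sequence


    # Hypothetical stand-in for a Praxis layer that declares attribute
    # defaults with field(). Without @dataclass, a linter that cannot see
    # the base class's dataclass transform may flag the field() calls and
    # the generated __init__; the explicit decorator makes the dataclass
    # contract visible to static tooling.
    @dataclass
    class ExampleLayer:
        """Declares attribute defaults the same way the patched layers do."""

        axes: Sequence[int] = field(default_factory=tuple)
        epsilon: float = 1e-6


    layer = ExampleLayer(axes=(0, 1))
    print(layer)  # ExampleLayer(axes=(0, 1), epsilon=1e-06)

Per the commit message, the goal is to appease tooling rather than change behavior: Flax linen modules, which Praxis layers build on, are already converted to dataclasses automatically at class definition, so the explicit decorator mainly documents that contract for static checkers.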