diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index b82c0915e4..e5649bfe7c 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -4,6 +4,7 @@ """ Praxis Modules """ +from dataclasses import field from functools import partial from typing import Callable, Iterable, Sequence, Tuple, Union @@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () transpose_batch_sequence: bool = False @@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer): out_features: int = 512 kernel_axes: Tuple[str, ...] = () use_bias: bool = True - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes_1: Tuple[str, ...] = () kernel_axes_2: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes_1: Tuple[str, ...] = () bias_axes_2: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index f2ac802f10..2ae212afb9 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -4,6 +4,7 @@ """ Praxis Modules related Transformer """ +from dataclasses import field from functools import partial from typing import Optional, Sequence, Tuple import warnings @@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): zero_centered_gamma: bool = False return_layernorm_output: bool = False use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) attn_mask_type: str = "causal" attn_bias_type: Optional[str] = None enable_rotary_pos_emb: bool = False @@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer): dropout_rng_name: str = "dropout" mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False float32_attention_logits: bool = False