From f37d8b0a96c55e8d9d4353c91ea8fa5c0d19e12a Mon Sep 17 00:00:00 2001
From: Reese Wang
Date: Tue, 10 Dec 2024 06:47:08 +0000
Subject: [PATCH] Silence pylint: all classes inherited from praxis base layer are already dataclasses

Signed-off-by: Reese Wang
---
 transformer_engine/jax/praxis/module.py      | 24 +++++++++++++++-----
 transformer_engine/jax/praxis/transformer.py |  8 +++++--
 2 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py
index 216fc743cc..005ab629df 100644
--- a/transformer_engine/jax/praxis/module.py
+++ b/transformer_engine/jax/praxis/module.py
@@ -75,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes: Tuple[str, ...] = ()
     transpose_batch_sequence: bool = False
 
@@ -130,7 +132,9 @@ class Linear(TransformerEngineBaseLayer):
     out_features: int = 512
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = True
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -175,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=1.0))
+    ln_bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )  # pylint: disable=invalid-field-call
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -238,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=1.0))
+    ln_bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )  # pylint: disable=invalid-field-call
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes_1: Tuple[str, ...] = ()
     kernel_axes_2: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes_1: Tuple[str, ...] = ()
     bias_axes_2: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py
index 088b79d183..7bb8776162 100644
--- a/transformer_engine/jax/praxis/transformer.py
+++ b/transformer_engine/jax/praxis/transformer.py
@@ -139,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     return_layernorm_output: bool = False
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     attn_mask_type: str = "causal"
     attn_bias_type: Optional[str] = None
     enable_rotary_pos_emb: bool = False
@@ -276,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
     dropout_rng_name: str = "dropout"
     mlp_activations: Sequence[str] = ("relu",)
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
     float32_attention_logits: bool = False