Commit
Silence pylint: all classes inherited from the Praxis base layer are already dataclasses

Signed-off-by: Reese Wang <rewang@nvidia.com>
zlsh80826 committed Dec 10, 2024
1 parent: 0e912ad · commit: f37d8b0
Showing 2 changed files with 24 additions and 8 deletions.
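Background on the warning: pylint raises invalid-field-call when dataclasses.field() appears in a class it cannot statically recognize as a dataclass. Praxis converts every BaseLayer subclass into a dataclass implicitly, so the field() calls in these modules are valid at runtime; pylint simply cannot see the implicit conversion, hence the targeted disable comments. Below is a minimal, self-contained sketch of the pattern, with the Praxis conversion imitated by an explicit decorator; WeightInitStub and LayerStub are hypothetical stand-ins, not names from this repo.

from dataclasses import dataclass, field
from functools import partial


@dataclass
class WeightInitStub:
    """Stand-in for praxis WeightInit; Constant(scale=...) mirrors its factory."""

    scale: float = 0.0

    @classmethod
    def Constant(cls, scale):
        return cls(scale=scale)


# In praxis, BaseLayer subclasses become dataclasses implicitly; the explicit
# decorator here stands in for that conversion.
@dataclass
class LayerStub:
    # field() is legal here because the class really is a dataclass, but
    # pylint cannot infer the implicit conversion and flags the call.
    bias_init: WeightInitStub = field(
        default_factory=partial(WeightInitStub.Constant, scale=0.0)
    )  # pylint: disable=invalid-field-call


# The factory builds a fresh default per instance.
assert LayerStub().bias_init.scale == 0.0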
24 changes: 18 additions & 6 deletions transformer_engine/jax/praxis/module.py
@@ -75,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes: Tuple[str, ...] = ()
     transpose_batch_sequence: bool = False

@@ -130,7 +132,9 @@ class Linear(TransformerEngineBaseLayer):
     out_features: int = 512
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = True
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -175,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=1.0))
+    ln_bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )  # pylint: disable=invalid-field-call
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -238,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=1.0))
+    ln_bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )  # pylint: disable=invalid-field-call
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes_1: Tuple[str, ...] = ()
     kernel_axes_2: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     bias_axes_1: Tuple[str, ...] = ()
     bias_axes_2: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
8 changes: 6 additions & 2 deletions transformer_engine/jax/praxis/transformer.py
@@ -139,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     return_layernorm_output: bool = False
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     attn_mask_type: str = "causal"
     attn_bias_type: Optional[str] = None
     enable_rotary_pos_emb: bool = False
@@ -276,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
     dropout_rng_name: str = "dropout"
     mlp_activations: Sequence[str] = ("relu",)
     use_bias: bool = False
-    bias_init: WeightInit = field(default_factory=partial(WeightInit.Constant, scale=0.0))
+    bias_init: WeightInit = field(
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )  # pylint: disable=invalid-field-call
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
     float32_attention_logits: bool = False
