diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py
index 600841378a..abfa13c880 100644
--- a/transformer_engine/jax/flax/module.py
+++ b/transformer_engine/jax/flax/module.py
@@ -23,7 +23,6 @@
 from ..layernorm import canonicalize_layernorm_type
 from ..layernorm import layernorm, layernorm_fp8_dot
 from ..mlp import fused_layernorm_fp8_mlp, activation_lu
-""" from ..mlp import layernorm_gelu_fp8_mlp, gelu """
 from ..softmax import is_softmax_kernel_available
 from ..softmax import softmax, SoftmaxType
 from ..sharding import with_sharding_constraint_by_logical_axes
@@ -937,7 +936,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
 
         bias_2_shape = (hidden_size,) if self.use_bias else (0,)
         bias_2 = nn_partitioning.param_with_axes('wo_bias',
-                                                 self.bias_init, 
+                                                 self.bias_init,
                                                  bias_2_shape,
                                                  jnp.float32,
                                                  axes=self.bias_axes_2)
@@ -958,7 +957,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                                           activation_type = normalize_acts,
                                           use_bias = self.use_bias)
         else:  # not use_fused_ln_geglu_mlp
-            print("HERE " + str(self.use_bias) + " " + str(fuse_layernorm) + " " + str(is_act_implemented))
             # DenseGeneral 1
             gemm1_fp8_meta_package = None if fp8_meta_package is None \
                 else fp8_meta_package.get_package_by_gemm_idx(0)
diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py
index 2e852dd946..0eb9d15f5b 100644
--- a/transformer_engine/jax/mlp.py
+++ b/transformer_engine/jax/mlp.py
@@ -3,7 +3,7 @@
 # See LICENSE for license information.
 """JAX MLP modules"""
 
-from typing import List, Tuple
+from typing import List, Tuple, Sequence, Union, Callable
 from functools import partial
 
 import jax
@@ -129,7 +129,7 @@ def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarr
                              bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
                              epsilon: float, layernorm_input_axes: Tuple[str, ...],
                              dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
-                             ffn1_ckpt_name: str, ffn2_ckpt_name: str, 
+                             ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                              activation_type: Sequence[Union[str, Callable]],
                              use_bias: bool):
    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
@@ -184,7 +184,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
    # Squeeze act axis
    # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
    if not is_gated:
-        kernel_1 = jnp.squeeze(kernel_1, axis=-2) 
+        kernel_1 = jnp.squeeze(kernel_1, axis=-2)
 
    amax = FP8Helper.update_amax_history(amax)
 
@@ -254,7 +254,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
                                                           activation_lu_out_amax,
                                                           activation_lu_out_scale,
                                                           activation_lu_out_scale_inv, fwd_dtype)
-    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out, dot_2_input_axes)
+    casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out,
+                                                                        dot_2_input_axes)
 
    kernel_2_scale = scale[gemm2_kernel_idx]
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
@@ -263,7 +264,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
 
    # (batch..., hidden_in) x (hidden_out, hidden_in)
-    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2, activation_lu_out_scale_inv,
+    dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
+                                activation_lu_out_scale_inv,
                                 kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
                                 get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
 
@@ -274,9 +276,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
    dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
 
    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
-           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_activation_lu_amax,
-           updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims,
-           bias_1.shape, bias_2.shape)
+           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
+           updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
+           x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape)
 
    return dot_2_output, ctx
 
@@ -325,8 +327,8 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
 
    # (hidden, batch...,) x (hidden, batch...)
    gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
-    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
-                           grad.dtype, (xt_batch_dims, xt_batch_dims),
+    wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
+                           grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
                            get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
 
    # (batch..., hidden_out) x (hidden_in, hidden_out)
@@ -367,7 +369,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
                                                          dactivation_lu_scale_inv,
                                                          bwd_dtype,
                                                          static_axis_boundary=-1)
-    else: 
+    else:
        raise NotImplementedError
 
    ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
@@ -377,7 +379,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
    # Check if not gated
    xt_batch_dims_2 = xt_batch_dims if not is_gated \
        else tuple(i + 1 for i in xt_batch_dims)
-    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, 
+    wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
                            dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
                            get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
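
Note on the non-gated kernel handling touched in the @@ -184,7 +184,7 @@ hunk: below is a minimal, self-contained sketch of the shape bookkeeping behind jnp.squeeze(kernel_1, axis=-2). The sizes and the is_gated check here are illustrative assumptions for this example only (the module derives gating from activation_type); it is not the library's code, just a picture of the (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out) step.

import jax.numpy as jnp

# Illustrative sizes only (assumed for this sketch).
hidden_in, hidden_out = 16, 64

# Gated activations (e.g. ('gelu', 'linear')) stack two kernels along axis -2;
# non-gated ones (e.g. ('gelu',)) carry a length-1 activation axis.
activation_type = ('gelu',)
is_gated = len(activation_type) > 1  # stand-in for the real gating check

kernel_1 = jnp.zeros((hidden_in, len(activation_type), hidden_out))

# Squeeze the length-1 activation axis before the FP8 dot:
# (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
if not is_gated:
    kernel_1 = jnp.squeeze(kernel_1, axis=-2)

assert kernel_1.shape == (hidden_in, hidden_out)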