
Commit

formatting
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng committed Apr 9, 2024
1 parent a2e06a0 commit 1e0d0e4
Showing 2 changed files with 15 additions and 15 deletions.
4 changes: 1 addition & 3 deletions transformer_engine/jax/flax/module.py
@@ -23,7 +23,6 @@
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import fused_layernorm_fp8_mlp, activation_lu
""" from ..mlp import layernorm_gelu_fp8_mlp, gelu """
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
@@ -937,7 +936,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):

bias_2_shape = (hidden_size,) if self.use_bias else (0,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
-self.bias_init,
+self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
@@ -958,7 +957,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
activation_type = normalize_acts,
use_bias = self.use_bias)
else: # not use_fused_ln_geglu_mlp
print("HERE " + str(self.use_bias) + " " + str(fuse_layernorm) + " " + str(is_act_implemented))
# DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(0)
26 changes: 14 additions & 12 deletions transformer_engine/jax/mlp.py
@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""JAX MLP modules"""

-from typing import List, Tuple
+from typing import List, Tuple, Sequence, Union, Callable
from functools import partial

import jax
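
The widened typing import above supports the activation_type: Sequence[Union[str, Callable]] parameter visible in the next hunk. A minimal sketch, with assumed names and semantics rather than the library's implementation, of how such a parameter can be consumed:

# Hypothetical helper (assumption, not part of the diff): applies each requested
# activation, given by name or as a callable, and stacks the results.
from typing import Callable, Sequence, Union

import jax
import jax.numpy as jnp


def apply_activations(x: jnp.ndarray,
                      activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
    by_name = {'gelu': jax.nn.gelu, 'silu': jax.nn.silu, 'linear': lambda y: y}
    outs = [act(x) if callable(act) else by_name[act](x) for act in activation_type]
    return jnp.stack(outs, axis=-2)


# e.g. a GeGLU-style pair: a named activation plus the identity as a callable
y = apply_activations(jnp.ones((2, 4)), ('gelu', lambda u: u))
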
@@ -129,7 +129,7 @@ def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarr
bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
-ffn1_ckpt_name: str, ffn2_ckpt_name: str,
+ffn1_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool):
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
@@ -184,7 +184,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
# Squeeze act axis
# (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out)
if not is_gated:
-kernel_1 = jnp.squeeze(kernel_1, axis=-2)
+kernel_1 = jnp.squeeze(kernel_1, axis=-2)

amax = FP8Helper.update_amax_history(amax)

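For reference, a self-contained sketch of the squeeze in this hunk, with illustrative shapes only: in the non-gated case the stacked kernel carries a singleton activation axis that is dropped before the GEMM.

import jax.numpy as jnp

hidden_in, hidden_out = 8, 16
kernel_1 = jnp.zeros((hidden_in, 1, hidden_out))   # (hidden_in, 1, hidden_out)
kernel_1 = jnp.squeeze(kernel_1, axis=-2)          # -> (hidden_in, hidden_out)
assert kernel_1.shape == (hidden_in, hidden_out)
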
@@ -254,7 +254,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype)

-casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out, dot_2_input_axes)
+casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out,
+dot_2_input_axes)

kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
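
with_sharding_constraint_by_logical_axes, reflowed earlier in this hunk, is Transformer Engine's logical-axis helper; a wrapper of this kind typically resolves logical names to a mesh sharding and applies JAX's generic constraint primitive. A minimal sketch of that primitive, with an assumed mesh and axis name:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 1-D mesh over the available devices; the axis name is illustrative.
mesh = Mesh(np.array(jax.devices()), ('data',))


@jax.jit
def f(x):
    # Ask the compiler to keep this intermediate row-sharded along 'data'.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('data', None)))
    return x * 2.0


y = f(jnp.ones((8, 4)))
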
@@ -263,7 +264,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)

# (batch..., hidden_in) x (hidden_out, hidden_in)
-dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2, activation_lu_out_scale_inv,
+dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2,
+activation_lu_out_scale_inv,
kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

@@ -274,9 +276,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
-casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_activation_lu_amax,
-updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims,
-bias_1.shape, bias_2.shape)
+casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
+updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
+x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape)

return dot_2_output, ctx
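
The ctx tuple reflowed above is the residual package that the forward rule returns alongside its output so the matching backward rule can unpack it. A generic, toy-sized sketch of that jax.custom_vjp pattern (not the library's MLP rules):

import jax


@jax.custom_vjp
def scaled_square(x, s):
    return s * x * x


def scaled_square_fwd(x, s):
    out = s * x * x
    ctx = (x, s)                  # residuals saved for the backward rule
    return out, ctx


def scaled_square_bwd(ctx, grad):
    x, s = ctx
    return (grad * 2.0 * s * x,   # gradient w.r.t. x
            grad * x * x)         # gradient w.r.t. s


scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

print(jax.grad(scaled_square)(3.0, 2.0))   # 12.0
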

@@ -325,8 +327,8 @@ def _fused_layernorm_fp8_mlp_bwd_rule(

# (hidden, batch...,) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
-wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
-grad.dtype, (xt_batch_dims, xt_batch_dims),
+wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
+grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

# (batch..., hidden_out) x (hidden_in, hidden_out)
@@ -367,7 +369,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1)
-else:
+else:
raise NotImplementedError

ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
@@ -377,7 +379,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
# Check if not gated
xt_batch_dims_2 = xt_batch_dims if not is_gated \
else tuple(i + 1 for i in xt_batch_dims)
-wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
+wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype,
(xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
