
Commit

Resolve conflicts
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng committed Apr 9, 2024
1 parent 0b20c83 commit a2e06a0
Showing 3 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion tests/jax/test_custom_call_compute.py
@@ -4,7 +4,7 @@
 
 import functools
 import operator
-from typing import Callable, Sequence, Tuple, Union
+from typing import Callable, Sequence, Union
 
 import jax
 import jax.numpy as jnp
9 changes: 5 additions & 4 deletions transformer_engine/jax/flax/module.py
@@ -834,10 +834,11 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
         fuse_layernorm = FP8Helper.is_fp8_enabled(
         ) and not self.return_layernorm_output and self.enable_layernorm
 
-        gated_act_pool = [('gelu', 'linear'),
-                          ('linear', 'silu')]    # Make sure this is sorted in alphabet order
-        act_pool = [('gelu',),
-                    ('silu',)]
+        # Make sure each tuple is sorted in alphabet order
+        gated_act_pool = [('gelu', 'linear')]
+        # ('linear', 'silu')  coming
+        act_pool = [('gelu',)]
+        # ('silu',)  coming
         normalize_acts = []
         for act in self.activations:
             if not isinstance(act, str):
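
The comment added in this hunk ("Make sure each tuple is sorted in alphabet order") matters because the code that follows (the normalize_acts loop visible in the context lines) lowercases the requested activation names before looking them up in these pools. A minimal standalone sketch of that lookup, assuming a tuple(sorted(...)) membership test that this hunk does not show:

    # Hypothetical standalone version of the pool lookup; the real logic
    # lives inside LayerNormMLP.__call__ in module.py.
    gated_act_pool = [('gelu', 'linear')]   # entries must stay sorted
    act_pool = [('gelu',)]

    def is_act_implemented(activations) -> bool:
        normalize_acts = []
        for act in activations:
            if not isinstance(act, str):
                return False                # callables cannot be matched by name
            normalize_acts.append(act.lower())
        # Sorting makes ('linear', 'gelu') equal to ('gelu', 'linear'), which
        # only works if the pool entries are themselves alphabetically sorted.
        return tuple(sorted(normalize_acts)) in gated_act_pool + act_pool

    assert is_act_implemented(('linear', 'gelu'))   # order-insensitive
    assert not is_act_implemented(('silu',))        # commented out as "coming"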
17 changes: 9 additions & 8 deletions transformer_engine/jax/mlp.py
@@ -5,7 +5,6 @@
 
 from typing import List, Tuple
 from functools import partial
-from typing import Callable, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -133,12 +132,13 @@ def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                              ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                              activation_type: Sequence[Union[str, Callable]],
                              use_bias: bool):
-    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
-                                                  fp8_max, amax, scale, scale_inv, fwd_dtype,
-                                                  bwd_dtype, layernorm_type, zero_centered_gamma,
-                                                  epsilon, layernorm_input_axes, dot_1_input_axes,
-                                                  dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
-                                                  activation_type, use_bias)
+    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
+                                                  bias_2, fp8_max, amax, scale, scale_inv,
+                                                  fwd_dtype, bwd_dtype, layernorm_type,
+                                                  zero_centered_gamma, epsilon,
+                                                  layernorm_input_axes, dot_1_input_axes,
+                                                  dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
+                                                  activation_type, use_bias)
     return output
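
The call above discards the second return value, the residuals saved for the backward pass; this is the usual jax.custom_vjp layout, where the undecorated primal reuses the forward rule and drops its context. A minimal self-contained sketch of that pattern; the f/f_fwd/f_bwd names are illustrative, not from this commit:

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x, w):
        # Primal reuses the fwd rule and drops the residuals, mirroring
        # output, _ = _fused_layernorm_fp8_mlp_fwd_rule(...) above.
        output, _ = f_fwd(x, w)
        return output

    def f_fwd(x, w):
        output = x @ w
        return output, (x, w)               # residuals for the backward pass

    def f_bwd(residuals, grad):
        x, w = residuals
        return grad @ w.T, x.T @ grad       # cotangents for (x, w)

    f.defvjp(f_fwd, f_bwd)

    grads = jax.grad(lambda x: f(x, jnp.ones((8, 2))).sum())(jnp.ones((4, 8)))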


@@ -392,7 +392,8 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
     else:
         x_contracting_dims = (x_contracting_dims, (1,))
     kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
-    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv, kernel_1_scale_inv,
+    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1,
+                           dactivation_lu_scale_inv, kernel_1_scale_inv,
                            grad.dtype, x_contracting_dims,
                            get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
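
For context on the rewrapped call: fp8_dot_impl contracts the two casted operands over x_contracting_dims and applies the scale_inv dequantization factors. A rough sketch of those semantics via jax.lax.dot_general; the body is an assumption about what the Transformer Engine helper computes, not its actual implementation:

    import jax.numpy as jnp
    from jax import lax

    def fp8_dot_sketch(lhs, rhs, lhs_scale_inv, rhs_scale_inv,
                       out_dtype, contracting_dims, precision=None):
        # Hypothetical stand-in for fp8_dot_impl: dequantize, contract, cast.
        # contracting_dims pairs the lhs and rhs contraction axes, e.g. the
        # (x_contracting_dims, (1,)) tuple built in the hunk above.
        lhs_dims, rhs_dims = contracting_dims
        dim_numbers = ((lhs_dims, rhs_dims), ((), ()))   # no batch dims: plain GEMM
        out = lax.dot_general(lhs.astype(jnp.float32) * lhs_scale_inv,
                              rhs.astype(jnp.float32) * rhs_scale_inv,
                              dim_numbers, precision=precision)
        return out.astype(out_dtype)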
