diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 1fba5c7404..fab2b838e4 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -4,7 +4,7 @@
 import functools
 import operator
 
-from typing import Callable, Sequence, Tuple, Union
+from typing import Callable, Sequence, Union
 
 import jax
 import jax.numpy as jnp
diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py
index d0ada5811f..600841378a 100644
--- a/transformer_engine/jax/flax/module.py
+++ b/transformer_engine/jax/flax/module.py
@@ -834,10 +834,11 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
         fuse_layernorm = FP8Helper.is_fp8_enabled(
         ) and not self.return_layernorm_output and self.enable_layernorm
 
-        gated_act_pool = [('gelu', 'linear'),
-                          ('linear', 'silu')]    # Make sure this is sorted in alphabet order
-        act_pool = [('gelu',),
-                    ('silu',)]
+        # Make sure each tuple is sorted in alphabet order
+        gated_act_pool = [('gelu', 'linear')]
+        #('linear', 'silu')] coming
+        act_pool = [('gelu',)]
+        #('silu',)] coming
         normalize_acts = []
         for act in self.activations:
             if not isinstance(act, str):
diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py
index 26fa111fe5..2e852dd946 100644
--- a/transformer_engine/jax/mlp.py
+++ b/transformer_engine/jax/mlp.py
@@ -5,7 +5,6 @@
 
 from typing import List, Tuple
 from functools import partial
-from typing import Callable, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -133,12 +132,13 @@ def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarr
                              ffn1_ckpt_name: str, ffn2_ckpt_name: str,
                              activation_type: Sequence[Union[str, Callable]], use_bias: bool):
-    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
-                                                  fp8_max, amax, scale, scale_inv, fwd_dtype,
-                                                  bwd_dtype, layernorm_type, zero_centered_gamma,
-                                                  epsilon, layernorm_input_axes, dot_1_input_axes,
-                                                  dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
-                                                  activation_type, use_bias)
+    output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1,
+                                                  bias_2, fp8_max, amax, scale, scale_inv,
+                                                  fwd_dtype, bwd_dtype, layernorm_type,
+                                                  zero_centered_gamma, epsilon,
+                                                  layernorm_input_axes, dot_1_input_axes,
+                                                  dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
+                                                  activation_type, use_bias)
 
     return output
 
@@ -392,7 +392,8 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
     else:
         x_contracting_dims = (x_contracting_dims, (1,))
     kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
-    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv, kernel_1_scale_inv,
+    dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1,
+                           dactivation_lu_scale_inv, kernel_1_scale_inv,
                            grad.dtype, x_contracting_dims,
                            get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))