[JAX] Fix failure on pattern matching of FP8 GEMM when enabling FSDP. (#547)

* Adding Cast custom call

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Applying cast to the kernel of layernorm_fp8_dot

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Applying native cast to the kernel of fp8_dot.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Apply Cast and native cast to layernorm_geglu_fp8_dot

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the bug to enable layernorm_geglu_fp8_dot in LayernormMlp

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modified code per review feedback.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding 2xACC control to FP8 GEMMs.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set precision as a static arg

Signed-off-by: Ming Huang <mingh@nvidia.com>

---------

Signed-off-by: Ming Huang <mingh@nvidia.com>
mingxu1067 authored Jan 12, 2024
1 parent e547f8e commit 2ae121d
Showing 6 changed files with 212 additions and 72 deletions.
2 changes: 1 addition & 1 deletion tests/jax/test_custom_call_compute.py
@@ -55,7 +55,7 @@ def test_qdq(self):
scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
scale_inv = (1 / scale).reshape(1)

y = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

assert_allclose(z, x, dtype=jnp.float8_e4m3fn)
127 changes: 127 additions & 0 deletions transformer_engine/jax/cpp_extensions.py
@@ -2884,6 +2884,133 @@ def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_
transpose_axis_boundary=transpose_axis_boundary)


class CastFP8Primitive(BasePrimitive):
"""
Cast Primitive
"""
name = "te_quantize"
multiple_results = True
impl_static_args = (4,)
inner_primitive = None
outer_primitive = None

@staticmethod
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
"""
te_cast abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32

casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

return casted_x_aval, updated_amax_aval

@staticmethod
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
"""
te_cast lowering rules
"""
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape

out_types = [
ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

opaque = transformer_engine_jax.pack_common_descriptor(ir_x_shape,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype))

out = custom_caller(CastFP8Primitive.name,
args,
opaque,
False,
operand_output_aliases={1: 1})

return out

@staticmethod
def impl(x, amax, scale, scale_inv, out_dtype):
"""
te_cast implementation
"""
assert CastFP8Primitive.inner_primitive is not None
casted_x, updated_amax = \
CastFP8Primitive.inner_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype)
return casted_x, updated_amax

@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype):
_check_valid_batch_dims(batch_dims)
assert CastFP8Primitive.outer_primitive is not None

x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, *_ = batch_dims

out_bdims = x_bdim, amax_bdim
return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
out_dtype=out_dtype), out_bdims

@staticmethod
def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (casted_x_sharding, amax_sharding)

@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, amax_sharding)

def sharded_impl(x, amax, scale, scale_inv):
local_cx, local_updated_amax = \
CastFP8Primitive.impl(x, amax, scale, scale_inv, out_dtype=out_dtype)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)

return local_cx, global_updated_amax

return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastFP8Primitive)


def cast_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Cast wrapper
Return the casted FP8 tensor and the updated amax
"""
return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
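For reference, a minimal usage sketch of the new cast_fp8 wrapper (illustrative only; the shapes and values below are assumptions, not part of this commit):

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import cast_fp8

x = jnp.ones((16, 64), jnp.bfloat16)
amax = jnp.zeros((1,), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
scale_inv = jnp.ones((1,), jnp.float32)
# Casts to FP8 via the te_quantize custom call and also returns the updated amax.
casted_x, updated_amax = cast_fp8(x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn)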


class TransposePrimitive(BasePrimitive):
"""
Transpose Primitive
64 changes: 36 additions & 28 deletions transformer_engine/jax/dot.py
@@ -11,6 +11,8 @@
from .cpp_extensions import cast_transpose
from .fp8 import FP8Helper, FP8MetaPackage

Precision = jax.lax.Precision


def type_safe_dot_general(
x,
@@ -40,10 +42,11 @@ def quantize(x, q_dtype, scale):
"""
Quantize with scale.
"""
updated_amax = jnp.max(jnp.abs(x)).astype(scale.dtype)
dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype)
scale = scale.astype(x.dtype)
clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max)
return clipped_scaled_x.astype(q_dtype)
return clipped_scaled_x.astype(q_dtype), updated_amax
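A quick round-trip sketch with the updated signature (values are illustrative; jnp is the module's existing jax.numpy import):

x = jnp.linspace(-1.0, 1.0, 16)
scale = jnp.ones((1,), jnp.float32)
q_x, new_amax = quantize(x, jnp.float8_e4m3fn, scale)  # quantized tensor plus the pre-quantization amax
x_approx = dequantize(q_x, jnp.float32, 1 / scale)     # the inverse scale recovers approximate values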


def dequantize(x, dq_dtype, scale_inv):
@@ -54,14 +57,15 @@


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5))
@partial(jax.jit, static_argnums=(4, 5, 6))
def fp8_dot_impl(
q_lhs: jnp.ndarray,
q_rhs: jnp.ndarray,
lhs_scale_inv: jnp.ndarray,
rhs_scale_inv: jnp.ndarray,
ctype: jnp.dtype, # computing type
contracting_dims: Tuple[Sequence[int], Sequence[int]]):
contracting_dims: Tuple[Sequence[int], Sequence[int]],
precision: Precision = None):
"""
FP8 GEMM for XLA pattern match
"""
@@ -70,7 +74,14 @@ def fp8_dot_impl(
lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
rhs = dequantize(q_rhs, ctype, rhs_scale_inv)

return jax.lax.dot_general(lhs, rhs, dim_nums)
return jax.lax.dot_general(lhs, rhs, dim_nums, precision=precision)
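This dequantize-then-dot chain is what the XLA FP8 GEMM pattern matcher is expected to rewrite into a single FP8 GEMM; a hedged sketch of the pattern (treating dequantize as a plain scale multiply is an assumption, and the fusion itself happens inside XLA):

lhs = q_lhs.astype(ctype) * lhs_scale_inv.astype(ctype)  # dequantize
rhs = q_rhs.astype(ctype) * rhs_scale_inv.astype(ctype)  # dequantize
out = jax.lax.dot_general(lhs, rhs, dim_nums, precision=precision)  # matched as one FP8 GEMM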


def get_precision_of_fp8_dot(enable_2xACC: bool):
"""
Get Precision of FP8 DOT.
"""
return jax.lax.Precision.HIGHEST if enable_2xACC else jax.lax.Precision.DEFAULT
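A self-contained sketch of how this helper feeds jax.lax.dot_general (shapes and names are illustrative assumptions):

import jax
import jax.numpy as jnp

lhs = jnp.ones((8, 16), jnp.bfloat16)
rhs = jnp.ones((16, 4), jnp.bfloat16)
dims = (((1,), (0,)), ((), ()))  # contract lhs axis 1 with rhs axis 0, no batch axes
# enable_2xACC=True maps to Precision.HIGHEST, otherwise Precision.DEFAULT (see helper above).
out = jax.lax.dot_general(lhs, rhs, dims, precision=get_precision_of_fp8_dot(enable_2xACC=True))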


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
@@ -102,36 +113,31 @@ def _fp8_dot_fwd_rule(

gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

x_amax = amax[gemm_x_idx, 0:1]
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
# Note (Ming Huang): Use native cast so XLA can handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_x, updated_x_amax = quantize(x, fwd_dtype, x_scale)

casted_x, casted_xt, updated_x_amax = \
cast_transpose(x, x_amax, x_scale, x_scale_inv, fwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims))

kernel_amax = amax[gemm_kernel_idx, 0:1]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
# Note (Ming Huang): Use native cast so XLA can handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)

casted_kerenl, casted_kerenl_t, updated_kernel_amax = \
cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv,
fwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=(max(rhs_contracting_dims) + 1))

rhs_t_contracting_dims = tuple(range(kernel.ndim - len(rhs_contracting_dims), kernel.ndim))
output = fp8_dot_impl(casted_x, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype,
(lhs_contracting_dims, rhs_t_contracting_dims))
output = fp8_dot_impl(casted_x, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
(lhs_contracting_dims, rhs_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

ctx = (casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape)
return output, ctx


def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims

casted_xt, casted_kerenl, fp8_max, amax, scale, scale_inv, \
casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx

gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
@@ -145,21 +151,23 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):  # pylint: disable=unused-argument
bwd_dtype, static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims))

xt_constracting_dim = tuple(range(len(lhs_contracting_dims), len(x_shape)))
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
x_constracting_dim = tuple(range(0, len(x_shape) - len(lhs_contracting_dims)))
gt_constracting_dim = tuple(range(grad.ndim - len(x_constracting_dim), grad.ndim))
x_scale_inv = scale_inv[gemm_x_idx]
wgrad = fp8_dot_impl(casted_xt, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(xt_constracting_dim, gt_constracting_dim))
wgrad = fp8_dot_impl(casted_x, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(x_constracting_dim, gt_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_constracting_dim, k_constracting_dim))
dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
amax = amax.at[gemm_x_idx, 0].set(updated_x_amax)
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax)
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])

scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
2 changes: 1 addition & 1 deletion transformer_engine/jax/flax/module.py
@@ -810,7 +810,7 @@ def is_geglu(acts):
if not isinstance(act, str):
return False
normalize_acts.append(act.lower())
return normalize_acts in geglu_act_pool
return tuple(normalize_acts) in geglu_act_pool
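A tiny sketch of why the tuple conversion matters (the pool contents below are assumed purely for illustration):

geglu_act_pool = (('gelu', 'linear'),)    # assumed contents
normalize_acts = ['gelu', 'linear']
normalize_acts in geglu_act_pool          # False: a list never compares equal to a tuple
tuple(normalize_acts) in geglu_act_pool   # True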

use_fused_ln_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \
31 changes: 17 additions & 14 deletions transformer_engine/jax/layernorm.py
@@ -7,10 +7,10 @@
import jax
import jax.numpy as jnp

from .cpp_extensions import cast_transpose, transpose
from .cpp_extensions import cast_fp8, cast_transpose, transpose
from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
from .dot import fp8_dot_impl
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage


@@ -186,16 +186,17 @@ def _layernorm_fp8_dot_fwd_rule(
kernel_scale_inv = scale_inv[gemm_kernel_idx]

# Kernel in (hidden_in, hidden_out...)
casted_kerenl, casted_kerenl_t, updated_kernel_amax = \
cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=1)
# Note (Ming Huang): Use cast only so XLA can handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = \
cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)

# (batch..., hidden_in) x (hidden_in, hidden_out...)
kt_contracting_dims = (kernel.ndim - 1,)
output = fp8_dot_impl(ln_out, casted_kerenl_t, x_scale_inv, kernel_scale_inv, x.dtype,
(x_contracting_dims, kt_contracting_dims))
output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
(x_contracting_dims, k_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

ctx = (ln_out, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
k_contracting_dims)

@@ -210,7 +211,7 @@ def _layernorm_fp8_dot_bwd_rule(
epsilon,
ctx,
grad):
ln_out_, casted_kerenl, fp8_max, amax, scale, scale_inv, \
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
updated_x_amax, updated_kernel_amax, \
x_shape, kernel_shape, mu, rsigma, x, gamma, \
x_contracting_dims, k_contracting_dims = ctx
@@ -231,14 +232,16 @@ def _layernorm_fp8_dot_bwd_rule(
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
x_scale_inv = scale_inv[gemm_x_idx]
wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(xt_constracting_dim, gt_constracting_dim))
(xt_constracting_dim, gt_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

g_constracting_dim = tuple(
g_for_dgrad_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_constracting_dim, k_constracting_dim))
dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

if layernorm_type == 'layernorm':
dx, dgamma, dbeta = layernorm_bwd(dgrad,