diff --git a/.gitmodules b/.gitmodules index 21492db5ef..7fc91b1f54 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "3rdparty/dlpack"] + path = 3rdparty/dlpack + url = git@github.com:dmlc/dlpack.git diff --git a/3rdparty/dlpack b/3rdparty/dlpack new file mode 160000 index 0000000000..bbd2f4d324 --- /dev/null +++ b/3rdparty/dlpack @@ -0,0 +1 @@ +Subproject commit bbd2f4d32427e548797929af08cfe2a9cbb3cf12 diff --git a/build_tools/jax.py b/build_tools/jax.py index f829230f50..bb4da4e5ed 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -5,6 +5,7 @@ """JAX related extensions.""" import os from pathlib import Path +from typing import Optional import setuptools from glob import glob @@ -36,6 +37,7 @@ def setup_jax_extension( csrc_source_files, csrc_header_files, common_header_files, + third_party_packages, ) -> setuptools.Extension: """Setup PyBind11 extension for JAX support""" # Source files @@ -55,12 +57,28 @@ def setup_jax_extension( common_header_files / "common" / "include", csrc_header_files, xla_home, + third_party_packages / "dlpack" / "include", ] # Compile flags cxx_flags = ["-O3"] nvcc_flags = ["-O3"] + # Userbuffers MPI dependence + libraries = [] + library_dirs = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + mpi_home = os.getenv("MPI_HOME") + assert mpi_home is not None, "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + mpi_home = Path(mpi_home) + libraries.append("mpi") + library_dirs.append(mpi_home / "lib") + + include_dirs.append(mpi_home / "include") + + cxx_flags.append("-DNVTE_UB_WITH_MPI") + nvcc_flags.append("-DNVTE_UB_WITH_MPI") + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension @@ -79,5 +97,7 @@ def _add_cflags(self, flags: List[str]) -> None: "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], + library_dirs=[str(path) for path in library_dirs], + libraries=libraries, extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags}, ) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py new file mode 100644 index 0000000000..8dc3035fbf --- /dev/null +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -0,0 +1,241 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" + +import argparse +import numpy as np +from mpi4py import MPI + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine import transformer_engine_jax as tex +from transformer_engine.jax.cpp_extensions import gemm_impl, copy_into_overlap_buffer +from transformer_engine.jax.gemm import ( + initialize_comm_gemm_overlaps, + destroy_comm_gemm_overlaps, + get_comm_overlap_config, +) +from transformer_engine.jax.sharding import get_padded_spec + +jax.clear_caches() + +# This script needs to be launched via `mpirun` with 1 process per GPU +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.distributed.initialize(cluster_detection_method="mpi4py") + +parser = argparse.ArgumentParser() +parser.add_argument("-dp", "--dp-size", type=int, default=1) +parser.add_argument("-zp", "--fsdp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=4) +parser.add_argument("-np", "--num-gpus", type=int, default=8) +parser.add_argument("--batch-size", type=int, default=2) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--hidden-size", type=int, default=16384) +parser.add_argument("--activation-size", type=int, default=53248) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument("--comm-type", type=str.upper, default="AG", choices=["AG", "RS"]) +parser.add_argument("--check-result", action="store_true") +args = parser.parse_args() + +# Operand shapes +dtype = jnp.bfloat16 +lhs_shape = ( + [args.seq_length, args.hidden_size] + if args.comm_type == "AG" + else [args.seq_length, args.activation_size] +) +rhs_shape = ( + [args.hidden_size, args.activation_size] + if args.comm_type == "AG" + else [args.activation_size, args.hidden_size] +) + +# Operand partitioning +batched = not args.no_batch +fsdp = not args.no_fsdp +if batched: + lhs_shape = [args.batch_size] + lhs_shape + if fsdp: + mesh_shape = {"dp": args.dp_size, "zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", tp_resource="tp", cp_resource="tp", fsdp_resource="zp" + ) + if args.comm_type == "AG": + input_specs = [("dp", "zp"), "tp", None] + weight_specs = ["zp", "tp"] + weight_no_fsdp = [None, "tp"] + elif args.comm_type == "RS": + input_specs = [("dp", "zp"), None, "tp"] + weight_specs = ["tp", "zp"] + weight_no_fsdp = ["tp", None] + else: + mesh_shape = {"dp": args.dp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", + tp_resource="tp", + cp_resource="tp", + ) + if args.comm_type == "AG": + input_specs = ["dp", "tp", None] + weight_specs = [None, "tp"] + elif args.comm_type == "RS": + input_specs = ["dp", None, "tp"] + weight_specs = ["tp", None] + weight_no_fsdp = weight_specs +else: + if fsdp: + mesh_shape = {"zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource(fsdp_resource="zp", tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = ["zp", "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", "zp"] + weight_no_fsdp = ["tp", None] + else: + mesh_shape = {"tp": args.tp_size} + mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = [None, "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", None] + weight_no_fsdp = weight_specs + +# Mesh setup and sharding definitions +devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) +mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) +no_sharding = NamedSharding(mesh, PartitionSpec(None)) +input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs)) +weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs)) +weight_no_fsdp_sharding = NamedSharding(mesh, PartitionSpec(*weight_no_fsdp)) + +# Operand initialization +key = jax.random.PRNGKey(0) +key1, key2 = jax.random.split(key, 2) +lhs_data = jax.random.normal(key1, lhs_shape, dtype=dtype) +rhs_data = jax.random.normal(key2, rhs_shape, dtype=dtype) +lhs = jax.device_put(lhs_data, input_sharding) +rhs = jax.device_put(rhs_data, weight_sharding) + +# Name of comm+GEMM overlap layer +overlap_name = "ag_gemm" if args.comm_type == "AG" else "gemm_rs" + +# Bootstrap Userbuffers communicators and communication buffers +initialize_comm_gemm_overlaps( + lhs_shape, + mesh, + myrank, + numranks, + tp_resource="tp", + overlap_configs={ + overlap_name: { + "method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv + "comm_type": ( + tex.CommOverlapType.AG if args.comm_type == "AG" else tex.CommOverlapType.RS + ), + "num_splits": args.tp_size, # independent of TP size for "pipeline" + "cga_size": 1, # default is 2 for "pipeline" + "num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline" + "set_sm_margin": False, # set to True for "pipeline" + "atomic_gemm": False, # more performant when not using CUDA Graphs + "use_ce": True, # ignored (always False) for "pipeline" method + }, + }, +) + +if myrank == 0: + print( + f"{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n" + + f"{myrank}: LHS sharding: {lhs.sharding.spec}\n" + + f"{myrank}: RHS sharding: {rhs.sharding.spec}\n", + flush=True, + ) + + +@jax.jit +def te_gemm(A, B): + # For AG overlap, LHS needs to be copied into the comm. buffer before GEMM. This can usually + # be circumvented by extracting the comm. buffer as a JAX array via + # `buffer = jax.dlpack.from_dlpack(tex.get_overlap_buffer(overlap_name: str, sharded: bool))` + # and directly writing the result of a preceding operation into it (e.g.. LayerNorm output + # written directly into the communication buffer before AG+GEMM in a QKV projection) + if args.comm_type == "AG": + copy_into_overlap_buffer(A, overlap_name, True) + return_idx = 0 + else: + # For RS overlap, the scattered output is in the `extra_out` array. + return_idx = -1 + + return gemm_impl( + A, + jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), # all-gather FSDP weights + batched_output=not args.no_batch, # internal option, will be hidden by the FWD/BWD wrapper + comm_overlap_config=get_comm_overlap_config(overlap_name), + )[return_idx] + + +with te.sharding.global_shard_guard(mesh_resource): + output = te_gemm(lhs, rhs) + +if myrank == 0: + print( + f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT " + + f"{output.shape}\n" + + f"{myrank}: Sharding: {get_padded_spec(output.sharding.spec, output.ndim)}\n", + flush=True, + ) + +if args.check_result: + ref_global = jnp.matmul( + jax.device_put(lhs_data, no_sharding), jax.device_put(rhs_data, no_sharding) + ) + if myrank == 0: + print(f"{myrank}: Global reference: {ref_global}\n", flush=True) + + output_global = jax.lax.with_sharding_constraint(output, no_sharding) + if myrank == 0: + print(f"{myrank}: Global output: {output_global}\n", flush=True) + + diff = jnp.abs(ref_global - output_global).flatten() + if myrank == 0: + print(f"{myrank}: Global difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(ref_global.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + numerics_failed = False + if rel_err > rtol and abs_err > atol: + numerics_failed = True + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"Outputs not close enough at index {m} " + + f"with {output.flatten()[m].item()} vs {ref_global.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + +destroy_comm_gemm_overlaps() diff --git a/setup.py b/setup.py index 3bb2fe6b95..a702399bc9 100644 --- a/setup.py +++ b/setup.py @@ -164,6 +164,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: "transformer_engine/jax/csrc", current_file_path / "transformer_engine" / "jax" / "csrc", current_file_path / "transformer_engine", + current_file_path / "3rdparty", ) ) if "paddle" in frameworks: diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py new file mode 100644 index 0000000000..b246999d8a --- /dev/null +++ b/tests/jax/test_distributed_gemm.py @@ -0,0 +1,302 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pytest +from functools import partial +from collections.abc import Iterable + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine.jax.gemm import gemm + +from utils import assert_allclose + + +jax.config.update("jax_enable_compilation_cache", False) + + +# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) +# - DGRAD: (4, 32, 256/P) x (128, 256/P)^T --(AR)--> (4, 32, 128) +# - WGRAD: (4, 32/P, 128)^T --(AG)--> (4, 32, 128)^T x (4, 32, 256/P) --------> (128, 256/P) + +# GEMM+AR: (4, 32, 256/P) x (256/P, 128) --(AR)--> (4, 32, 128) +# - DGRAD: (4, 32, 128) x (256/P, 128)^T ------> (4, 32, 256/P) +# - WGRAD: (4, 32, 256/P)^T --(AG)--> (4, 32, 256)^T x (4, 32, 128) --------> (256, 128) + +BATCH = 4 +BASE_SIZE = 16 +SEQ_LEN = BASE_SIZE * 8 +HIDDEN_SIZE = BASE_SIZE * 6 +FFN_HIDDEN_SIZE = BASE_SIZE * 16 + +COMM_TYPES = ["ALL_GATHER", "ALL_REDUCE"] +MESH_TYPES = ["FSDP_TP", "DP_TP", "TP"] +NUM_DEVICES = 4 + +is_fp8_supported, no_fp8_reason = te.fp8.is_fp8_available() + + +def _get_mesh(parallel_dist): + jax.clear_caches() + + batched = False + fsdp = False + mesh_shape = dict(tp=NUM_DEVICES) + resources = dict(cp_resource="tp", tp_resource="tp") + if parallel_dist in ["DP_TP", "FSDP_TP"]: + batched = True + mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=NUM_DEVICES // 2)) + resources.update(dict(dp_resource="dp")) + if parallel_dist == "FSDP_TP": + fsdp = True + mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=1, zp=NUM_DEVICES // 2)) + resources.update(dict(fsdp_resource="zp")) + mesh_resource = te.MeshResource(**resources) + + devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES]) + + mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) + + return mesh, mesh_resource, batched, fsdp + + +def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bwd=False): + fp8_gemm = dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + # Operand and output shapes + lhs_shape = ( + [SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE] + ) + rhs_shape = ( + [HIDDEN_SIZE, FFN_HIDDEN_SIZE] + if fwd_comm_type == "ALL_GATHER" + else [FFN_HIDDEN_SIZE, HIDDEN_SIZE] + ) + out_shape = [lhs_shape[0], rhs_shape[1]] + + if batched: + lhs_shape = [BATCH] + lhs_shape + out_shape = [BATCH] + out_shape + + # Operand and output partition specs + lhs_spec = ( + [mesh_resource.tp_resource, None] + if fwd_comm_type == "ALL_GATHER" + else [None, mesh_resource.tp_resource] + ) + rhs_spec = ( + [None, mesh_resource.tp_resource] + if fwd_comm_type == "ALL_GATHER" + else [mesh_resource.tp_resource, None] + ) + out_spec = [None, rhs_spec[-1]] + + # Modify RHS operand for FP8 + fsdp_gathered_rhs_spec = rhs_spec.copy() + if fp8_gemm: + rhs_shape = list(reversed(rhs_shape)) + rhs_spec = list(reversed(rhs_spec)) + fsdp_gathered_rhs_spec = list(reversed(fsdp_gathered_rhs_spec)) + + # Add batch dimensions and specs + if batched: + if fsdp: + lhs_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + lhs_spec + rhs_spec = [mesh_resource.fsdp_resource if spec is None else spec for spec in rhs_spec] + out_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + out_spec + else: + lhs_spec = [mesh_resource.dp_resource] + lhs_spec + out_spec = [mesh_resource.dp_resource] + out_spec + + # Allocate global operands on device + key = jax.random.PRNGKey(42) + split_keys = jax.random.split(key, 3 if fwd_bwd else 2) + mu = 0.0 + sigma = 0.023 + shapes = (lhs_shape, rhs_shape) + if fwd_bwd: + shapes += (out_shape,) + global_operands = list( + map( + lambda key, shape: jax.device_put( + mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), + NamedSharding(mesh, PartitionSpec(None)), + ), + split_keys, + shapes, + ) + ) + + # Allocate sharded operands on device + partition_axes = (lhs_spec, rhs_spec) + if fwd_bwd: + partition_axes += (out_spec,) + local_operands = list( + map( + lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), + global_operands, + partition_axes, + ) + ) + + # Tranpose global RHS back to non-transpoosed orientation if it was originally allocated + # for FP8 GEMM + if fp8_gemm: + rhs_global = jnp.matrix_transpose(global_operands[1]) + global_operands = (global_operands[0], rhs_global, *global_operands[2:]) + + return ( + local_operands, + global_operands, + (out_shape, out_spec), + fsdp_gathered_rhs_spec, + ) + + +def _check_output(mesh, expected_out_shape, expected_out_specs, *tensors, fwd_bwd=False): + num_operands = 3 if fwd_bwd else 2 + ref_operands = tensors[:num_operands] + test_outputs = tensors[num_operands:] + + # Check number of dimensions + assert test_outputs[0].ndim == len(expected_out_shape), ( + f"Output has different number of dimensions ({test_outputs[0].ndim}) than expected " + + f"({len(expected_out_shape)})" + ) + + # Pad test output spec for unsharded dimensions + test_spec = te.sharding.get_padded_spec(test_outputs[0].sharding.spec, test_outputs[0].ndim) + + for i in range(test_outputs[0].ndim): + # Check shape + assert test_outputs[0].shape[i] == expected_out_shape[i], ( + f"Output with shape {test_outputs[0].shape} does not match expected shape " + + f"{expected_out_shape} in dimension index {i}." + ) + + # Check shardings (with padded output spec) + spec_mismatch = False + if isinstance(expected_out_specs[i], str): + if test_spec[i] != expected_out_specs[i]: + spec_mismatch = True + elif isinstance(expected_out_specs[i], Iterable): + if not isinstance(test_spec[i], type(expected_out_specs[i])): + if test_spec[i] not in expected_out_specs[i]: + spec_mismatch = True + elif len(test_spec[i]) != len(expected_out_specs[i]): + spec_mismatch = True + else: + for j in range(len(expected_out_specs[i])): + if test_spec[i][j] != expected_out_specs[i][j]: + spec_mismatch = True + break + elif expected_out_specs[i] == None: + if test_spec[i] != None: + spec_mismatch = True + else: + raise RuntimeError("Internal TE error: Unrecognized reference partition spec type.") + if spec_mismatch: + raise AssertionError( + f"Output sharding {test_spec} does not match expected sharding " + + f"{expected_out_specs} in dimension index {i}." + ) + + def _native_gemm_fwd_bwd(lhs, rhs, grad): + fwd_out, vjp_fn = jax.vjp(jnp.dot, lhs, rhs) + lhs_grad, rhs_grad = vjp_fn(grad) + return fwd_out, lhs_grad, rhs_grad + + ref_fn = jax.jit(_native_gemm_fwd_bwd if fwd_bwd else jnp.dot) + + out_names = ["output"] + ref_outputs = ref_fn(*ref_operands) + if not fwd_bwd: + ref_outputs = [ref_outputs] + else: + out_names += ["dgrad", "wgrad"] + + for i, (test_out, ref_out) in enumerate(zip(test_outputs, ref_outputs)): + test_out_global = jax.lax.with_sharding_constraint( + test_out, NamedSharding(mesh, PartitionSpec(None)) + ) + try: + assert_allclose(ref_out, test_out_global) + except AssertionError as err: + raise AssertionError(f"Numerical mismatch in {out_names[i]}:\n" + str(err)) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_impl(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp) + + @jax.jit + def _test_fn(lhs, rhs): + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + return te.cpp_extensions.gemm_impl(lhs, rhs_no_fsdp, batched_output=batched) + + with te.sharding.global_shard_guard(mesh_resource): + output, *_ = _test_fn(*local_operands) + + _check_output(mesh, *output_info, *global_operands, output) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_fwd_bwd(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True) + + @jax.jit + def _test_fn(lhs, rhs, grad): + # Gather weights in FSDP axis + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + + # FWD pass + fwd_out, vjp_fn = jax.vjp(gemm, lhs, rhs_no_fsdp) + + # BWD pass + lhs_grad, rhs_grad = vjp_fn(grad) + + return fwd_out, lhs_grad, rhs_grad + + print( + f"INPUTS: {local_operands[0].shape} x {local_operands[1].shape}\n" + + f" LHS sharding: {local_operands[0].sharding.spec}\n" + + f" RHS sharding: {local_operands[1].sharding.spec}\n" + ) + + with te.sharding.global_shard_guard(mesh_resource): + output, dgrad, wgrad = _test_fn(*local_operands) + + print( + f"{'AG + GEMM' if comm_type == 'AG' else 'GEMM + AR'} output: " + + f"{output.shape} | {output.sharding.spec}\n" + + f"DGRAD: {dgrad.shape} | {dgrad.sharding.spec}\n" + + f"WGRAD: {wgrad.shape} | {wgrad.sharding.spec}\n" + ) + + _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index c6f0f870ff..810eeb2ebe 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -139,11 +139,12 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm) + bool atomic_gemm, bool overlap_first_gemm) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, false, atomic_gemm) { _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); + _overlap_first_gemm = overlap_first_gemm; NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", "or 2 (multi-atomic)."); @@ -164,6 +165,36 @@ CommOverlapBase::~CommOverlapBase() { cudaStreamDestroy(_stream_comm); } +TensorWrapper CommOverlapBase::get_ubuf_output(CommOverlapType comm_type) { + char *output_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == CommOverlapType::RS) + output_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + size_t output_c_dim0 = + (comm_type == CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + size_t output_c_dim1 = _ubuf.size(1); + return TensorWrapper(reinterpret_cast(output_ptr), {output_c_dim0, output_c_dim1}, + _ubuf.dtype()); +} + +void CommOverlapBase::copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, + CommOverlapType comm_type) { + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == CommOverlapType::AG) { + if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.dptr(), input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, stream)); +} + /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf @@ -225,8 +256,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -325,8 +355,7 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -358,7 +387,7 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap assert(pre_gelu_out.numel() == 0); - if (gemm_overlap) { + if (_overlap_first_gemm) { auto input_a_chunk = TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); auto output_chunk = @@ -565,6 +594,37 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaStreamDestroy(_stream_send); } +TensorWrapper CommOverlapP2PBase::get_ubuf_output(CommOverlapType comm_type) { + char *output_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == CommOverlapType::RS) + output_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); + size_t output_c_dim0 = + (comm_type == CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + size_t output_c_dim1 = _ubuf.size(1); + return TensorWrapper(reinterpret_cast(output_ptr), {output_c_dim0, output_c_dim1}, + _ubuf.dtype()); +} + +void CommOverlapP2PBase::copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, + CommOverlapType comm_type) { + if (comm_type == CommOverlapType::RS) { + // Copy input to the target ubuf chunk by rank offset + if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.dptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + stream)); + } else { + if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.dptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + stream)); + } +} + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 1d5d192a39..0605825c82 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -17,6 +17,9 @@ #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +#define NOT_IMPLEMENTED_ERROR() NVTE_ERROR("Operation is not implemented.") + +#define NOT_SUPPORTED_ERROR() NVTE_ERROR("Operation not supported.") namespace transformer_engine { /* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. @@ -26,9 +29,9 @@ namespace transformer_engine { */ bool ubuf_built_with_mpi(); -enum class CommOverlapType { RS = 0, AG = 1 }; +enum class CommOverlapType : int { RS = 0, AG = 1 }; -enum class CommOverlapAlgo { +enum class CommOverlapAlgo : int { BULK_OVERLAP_AG = 0, BULK_OVERLAP_RS = 1, SPLIT_PIPELINED_AG_P2P = 2, @@ -77,16 +80,64 @@ class CommOverlapCore { _ubuf_scale_inv_initialized = true; } - bool is_atomic_gemm() { return _atomic_gemm; } + virtual TensorWrapper get_ubuf_output(CommOverlapType comm_type) { NOT_IMPLEMENTED_ERROR(); } + + virtual void copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, + CommOverlapType comm_type) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual bool is_atomic_gemm() { return _atomic_gemm; } + + virtual bool is_p2p_overlap() { return _is_p2p; } + + virtual bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } - bool is_p2p_overlap() { return _is_p2p; } + virtual void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } - bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + virtual void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { protected: int _rs_kernel_type; + bool _overlap_first_gemm; cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; @@ -95,36 +146,47 @@ class CommOverlapBase : public CommOverlapCore { int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, + bool overlap_first_gemm = false); virtual ~CommOverlapBase(); - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ + TensorWrapper get_ubuf_output(CommOverlapType comm_type); + + void copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, CommOverlapType comm_type); + void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); - /* - ** Split FPROP GEMM + ReduceScatter - */ void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main); + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); - /* - ** Split FPROP GEMM + ReduceScatter - */ void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main); + + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NOT_SUPPORTED_ERROR(); + } + + void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NOT_SUPPORTED_ERROR(); + } }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -155,44 +217,39 @@ class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG - ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. - */ - void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper get_ubuf_output(CommOverlapType comm_type); + + void copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, CommOverlapType comm_type); + + void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_SUPPORTED_ERROR(); + } + + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main); - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG - ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. - */ - void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main); - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, + bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main); - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, + bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main); }; // CommOverlapP2PBase diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index d302518235..6fdc93098f 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -393,7 +393,7 @@ class TensorWrapper { return nvte_tensor_scale_inv(tensor_); } - private: + protected: /*! \brief Wrapped NVTETensor. */ NVTETensor tensor_ = nullptr; }; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a3b49f9fa..b92a993d49 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -93,7 +93,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { return ret; } -size_t nvte_tensor_ndim(const NVTETensor tensor) { +size_t nvte_tensor_ndims(const NVTETensor tensor) { const auto &t = *reinterpret_cast(tensor); return t.data.shape.size(); } diff --git a/transformer_engine/common/util/dlpack_helper.h b/transformer_engine/common/util/dlpack_helper.h new file mode 100644 index 0000000000..cd8210e37a --- /dev/null +++ b/transformer_engine/common/util/dlpack_helper.h @@ -0,0 +1,188 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_DLPACK_HELPER_H +#define TRANSFORMER_ENGINE_COMMON_UTIL_DLPACK_HELPER_H + +#include +#include +#include + +#include "cuda_runtime.h" +#include "logging.h" + +namespace transformer_engine { + +DLDataType nvte_dtype_to_dldtype(DType dtype) { + DLDataType dldtype; + dldtype.lanes = 1; + switch (dtype) { + case DType::kInt64: + dldtype.bits = 64; + dldtype.code = DLDataTypeCode::kDLInt; + break; + + case DType::kInt32: + dldtype.bits = 32; + dldtype.code = DLDataTypeCode::kDLInt; + break; + + case DType::kByte: + dldtype.bits = 8; + dldtype.code = DLDataTypeCode::kDLUInt; + break; + + case DType::kFloat32: + dldtype.bits = 32; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + case DType::kFloat16: + dldtype.bits = 16; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + case DType::kBFloat16: + dldtype.bits = 16; + dldtype.code = DLDataTypeCode::kDLBfloat; + break; + + case DType::kFloat8E4M3: + dldtype.bits = 8; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + case DType::kFloat8E5M2: + dldtype.bits = 8; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + default: + NVTE_ERROR("Unrecognized transformer_engine::DType."); + } + return dldtype; +} + +DType dldtype_to_nvte_dtype(const DLDataType &dldtype, bool grad) { + NVTE_CHECK(dldtype.lanes == 1, "Unsupported number of lanes in DLDataType: ", dldtype.lanes); + + switch (dldtype.code) { + case DLDataTypeCode::kDLInt: + switch (dldtype.bits) { + case 64: + return DType::kInt64; + + case 32: + return DType::kInt32; + + default: + NVTE_ERROR("Unsupported bits in integer DLDataType: ", dldtype.bits); + } + + case DLDataTypeCode::kDLFloat: + switch (dldtype.bits) { + case 32: + return DType::kFloat32; + + case 16: + return DType::kFloat16; + + case 8: + if (grad) { + return DType::kFloat8E5M2; + } else { + return DType::kFloat8E4M3; + } + + default: + NVTE_ERROR("Unsupported bits in float DLDataType: ", dldtype.bits); + } + + case DLDataTypeCode::kDLBfloat: + if (dldtype.bits == 16) { + return DType::kBFloat16; + } else { + NVTE_ERROR("Unsupported bits in bfloat DLDataType: ", dldtype.bits); + } + + case DLDataTypeCode::kDLBool: + case DLDataTypeCode::kDLUInt: + if (dldtype.bits == 8) { + return DType::kByte; + } else { + NVTE_ERROR("Unsupported bits in unsigned int DLDataType: ", dldtype.bits); + } + + default: + NVTE_ERROR("Unsupported DLDataType."); + } +} + +class DLPackWrapper : public TensorWrapper { + protected: + DLManagedTensor managed_tensor; + + public: + // Inherit TensorWrapper constructors + using TensorWrapper::TensorWrapper; + + // Construct a new DLPackWrapper from existing TensorWrapper + DLPackWrapper(TensorWrapper &&other) : TensorWrapper(std::move(other)) {} + + // New constructor from PyObject + DLPackWrapper(pybind11::object obj, bool grad = false) { + NVTE_CHECK(PyCapsule_CheckExact(obj.ptr()), "Expected DLPack capsule"); + + DLManagedTensor *dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(obj.ptr(), "dltensor"); + NVTE_CHECK(dlMTensor, "Invalid DLPack capsule."); + + DLTensor *dlTensor = &dlMTensor->dl_tensor; + NVTE_CHECK(dlTensor->device.device_type == DLDeviceType::kDLCUDA, + "DLPack tensor is not on a CUDA device."); + NVTE_CHECK(dlTensor->device.device_id == cuda::current_device(), + "DLPack tensor resides on a different device."); + + if (dlTensor->strides) { + for (int idx = dlTensor->ndim - 1; idx >= 0; ++idx) { + NVTE_CHECK(dlTensor->strides[idx] == 1, + "DLPack tensors with non-standard strides are not supported."); + } + } + + NVTEShape shape; + shape.data = reinterpret_cast(dlTensor->shape); + shape.ndim = static_cast(dlTensor->ndim); + this->tensor_ = nvte_create_tensor( + dlTensor->data, shape, static_cast(dldtype_to_nvte_dtype(dlTensor->dtype, grad)), + nullptr, nullptr, nullptr); + } + + pybind11::object capsule() { + DLDevice tensor_context; + tensor_context.device_type = DLDeviceType::kDLCUDA; + tensor_context.device_id = cuda::current_device(); + + DLTensor dlTensor; + dlTensor.data = dptr(); + dlTensor.device = tensor_context; + dlTensor.ndim = ndim(); + dlTensor.dtype = nvte_dtype_to_dldtype(dtype()); + dlTensor.shape = reinterpret_cast(const_cast(shape().data)); + dlTensor.strides = nullptr; + dlTensor.byte_offset = 0; + + managed_tensor.dl_tensor = dlTensor; + managed_tensor.manager_ctx = nullptr; + managed_tensor.deleter = [](DLManagedTensor *) {}; + + return pybind11::reinterpret_steal( + PyCapsule_New(&managed_tensor, "dltensor", nullptr)); + } +}; + +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 432ac815ec..9091e7e364 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,72 +8,95 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include +#include #include #include #include #include "cuda_runtime.h" -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType") \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kInt64", transformer_engine::DType::kInt64) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI) \ + .export_values(); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) \ + .export_values(); \ + pybind11::enum_(m, "NVTE_QKV_Format") \ + .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ + .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \ + .export_values(); \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ + .export_values(); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) \ + .export_values(); \ + pybind11::enum_(m, "NVTE_Activation_Type") \ + .value("GELU", NVTE_Activation_Type::GELU) \ + .value("GEGLU", NVTE_Activation_Type::GEGLU) \ + .value("SILU", NVTE_Activation_Type::SILU) \ + .value("SWIGLU", NVTE_Activation_Type::SWIGLU) \ + .value("RELU", NVTE_Activation_Type::RELU) \ + .value("REGLU", NVTE_Activation_Type::REGLU) \ + .value("QGELU", NVTE_Activation_Type::QGELU) \ + .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \ + .value("SRELU", NVTE_Activation_Type::SRELU) \ + .value("SREGLU", NVTE_Activation_Type::SREGLU) \ + .export_values(); \ + pybind11::enum_(m, "CommOverlapType") \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo") \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 579daa8e41..1e5cc4c07e 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -4,6 +4,7 @@ """Python interface for c++ extensions""" from .activation import * from .attention import * +from .gemm import * from .normalization import * from .quantization import * from .softmax import * diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py new file mode 100644 index 0000000000..b60fd1c74f --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -0,0 +1,1370 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import warnings +import operator +from functools import reduce, partial +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes +from jax.interpreters import mlir +from jax.interpreters.mlir import ir +from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi +from jax.typing import ArrayLike + +from transformer_engine import transformer_engine_jax as tex +from .base import BasePrimitive, register_primitive +from .custom_call import custom_caller, CustomCallArgsWrapper +from .misc import ( + jax_dtype_to_te_dtype, + jax_dtype_is_fp8, + get_padded_spec, + is_ffi_enabled, + check_valid_batch_dims, +) +from ..sharding import ( + global_mesh_resource, + all_reduce_max_along_all_axes_except_PP, +) + + +__all__ = [ + "fp8_gemm_impl", + "gemm_impl", + "copy_into_overlap_buffer", + "bootstrap_comm_gemm_overlap", + "get_num_max_compute_streams", + "set_num_max_compute_streams", +] + + +_NUM_MAX_COMPUTE_STREAMS = 3 +_COMM_GEMM_OVERLAP_LAYERS = ["qkv", "proj", "fc1", "fc2"] +_COMM_GEMM_OVERLAP_NAMES = ( + [layer + "_fprop" for layer in _COMM_GEMM_OVERLAP_LAYERS] + + [layer + "_dgrad" for layer in _COMM_GEMM_OVERLAP_LAYERS] + + [layer + "_wgrad" for layer in _COMM_GEMM_OVERLAP_LAYERS if layer != "fc2"] + + ["ag_gemm", "gemm_rs"] +) + + +def sanitize_dims(dim, ndims): + return (ndims + dim) if dim < 0 else dim + + +def mirror_dim(dim, ndims): + return ndims - 2 if dim == ndims - 1 else ndims - 1 + + +def get_cublas_workspace_size_bytes() -> None: + """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" + if tex.get_device_compute_capability() >= 90: + return 33_554_432 + return 4_194_304 + + +def get_num_max_compute_streams() -> int: + """Return the maximum number of compute streams that Comm+GEMM overlap can utilize.""" + return _NUM_MAX_COMPUTE_STREAMS + + +def set_num_max_compute_streams(new_max: int) -> None: + """Change the maximum number of compute streams that Comm+GEMM overlap can utilize.""" + global _NUM_MAX_COMPUTE_STREAMS + _NUM_MAX_COMPUTE_STREAMS = new_max + + +class CollectiveGemmPrimitive(BasePrimitive): + """ + cuBlasLt GEMM Primitive w/ support for distributed inputs + """ + + name = "te_gemm" + impl_static_args = (10, 11, 12, 13, 14, 15, 16, 17, 18) + multiple_results = True + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs_aval, + lhs_scale_inv_aval, + rhs_aval, + rhs_scale_inv_aval, + bias_aval, + gelu_input_aval, + out_aval, + out_amax_aval, + out_scale_aval, + extra_out_aval, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + comm_overlap_config, + sharded_abstract, + ): + """ + cuBlasLt GEMM abstract + """ + if comm_overlap_config is not None: + assert tex.ubuf_built_with_mpi(), ( + "Comm+GEMM overlap in TE/JAX requires Transformer Engine to be compiled with " + + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` options." + ) + assert is_ffi_enabled(), "Comm+GEMM overlap is supported only via XLA FFI." + + del grad, accumulate, use_split_accumulator + + # Validate operand dtypes + lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) + rhs_dtype = dtypes.canonicalize_dtype(rhs_aval.dtype) + assert lhs_dtype == rhs_dtype, "Mismatched matrix dtypes for GEMM." + is_fp8 = False + if jax_dtype_is_fp8(lhs_dtype): + assert ( + lhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(lhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing LHS operand scale inverse in FP8 GEMM." + is_fp8 = True + if jax_dtype_is_fp8(rhs_dtype): + assert ( + rhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing RHS operand scale inverse in FP8 GEMM." + + # Validate operand layouts + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) + assert lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim], ( + "Incompatible operand sizes: " + + f"{lhs_aval.shape} @ idx {lhs_inner_dim} X {rhs_aval.shape} @ idx {rhs_inner_dim}." + ) + + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 + assert not ( + lhs_trans and rhs_trans + ), "GEMM does not support transposed LHS and transposed RHS at the same time." + if is_fp8: + assert not lhs_trans, "FP8 GEMM does not support transposed LHS." + assert rhs_trans, "FP8 GEMM requires transposed RHS." + + # Make sure leading dimensions of RHS is broadcast-compatible with LHS + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs_aval.ndim, rhs_aval.ndim), + ) + if lhs_aval.ndim > 2 and rhs_aval.ndim > 2: + assert ( + not batched_output + ), "Batched output requires batched LHS and non-batched RHS operands." + lhs_bdims = [ + dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] + ] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + assert lhs_batch_size == rhs_batch_size, ( + "Leading dimensions of LHS and RHS are not broadcast-compatible: " + + f"{lhs_aval.shape} @ idx {lhs_inner_dim} X {rhs_aval.shape} @ idx {rhs_inner_dim}" + ) + + # Validate output dtypes + out_dtype = dtypes.canonicalize_dtype(out_aval.dtype) + if jax_dtype_is_fp8(out_dtype): + assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8( + rhs_dtype + ), "FP8 GEMM output requires FP8 inputs." + assert ( + out_amax_aval.size == out_scale_aval.size == 1 + ), "Invalid/missing output amax and scale." + out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) + out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) + assert ( + out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 + ), "Invalid output amax or scale dtype." + else: + if not jax_dtype_is_fp8(lhs_dtype): + assert out_dtype == lhs_dtype, ( + "Output buffer has incorrect dtype: " + + f"expected {lhs_dtype} but found {out_dtype}" + ) + out_amax_updated_dtype = jnp.float32 + out_scale_updated_dtype = jnp.float32 + + # Validate output buffers + out_shape = out_aval.shape + expected_out_shape = [ + *lhs_aval.shape[:-2], + lhs_aval.shape[lhs_outer_dim], + rhs_aval.shape[rhs_outer_dim], + ] + if batched_output: + assert out_aval.ndim > 2, "Batched output buffer is missing batch dimensions." + else: + expected_out_shape = [ + reduce(operator.mul, expected_out_shape[:-1], 1), + expected_out_shape[-1], + ] + + expected_extra_out_shape = [0] + expected_extra_out_dtype = jnp.bfloat16 + workspace_size = get_cublas_workspace_size_bytes() + if comm_overlap_config is not None: + comm_type = comm_overlap_config.get("comm_type", None) + assert comm_type is not None, "Missing comm type for comm+GEMM overlap." + + tp_size = comm_overlap_config.get("tp_size", 1) + assert ( + tp_size > 1 + ), "Comm+GEMM overlap requires tensor-parallel mesh axis size greater than 1." + + if comm_overlap_config["method"] != "bulk": + # Increase workspace size to ensure every GEMM chunk has an independent workspace + # of the appropriate size + workspace_size *= _NUM_MAX_COMPUTE_STREAMS + + if comm_type == tex.CommOverlapType.AG and extra_out_aval.size > 0: + expected_extra_out_shape = list(lhs_aval.shape).copy() + expected_extra_out_dtype = lhs_dtype + elif comm_type == tex.CommOverlapType.RS: + assert extra_out_aval.size > 0, "GEMM+RS overlap requires extra output buffer." + expected_extra_out_shape = list(expected_out_shape).copy() + + if sharded_abstract: + if comm_type == tex.CommOverlapType.AG: + expected_out_shape[-2] *= tp_size + if extra_out_aval.size > 0: + expected_extra_out_shape[-2] *= tp_size + elif comm_type == tex.CommOverlapType.RS: + expected_extra_out_shape[-2] = expected_extra_out_shape[-2] // tp_size + + assert out_aval.ndim == len(expected_out_shape), ( + "Output buffer has incorrect number of dimensions: " + + f"expected {len(expected_out_shape)} but found {out_aval.ndim}" + ) + assert all([out_aval.shape[i] == expected_out_shape[i] for i in range(out_aval.ndim)]), ( + "Output buffer has incorrect shape: " + + f"expected {expected_out_shape=} but found {out_aval.shape=}" + ) + + if extra_out_aval.size > 0: + extra_out_dtype = dtypes.canonicalize_dtype(extra_out_aval.dtype) + assert extra_out_dtype == expected_extra_out_dtype, ( + "Extra output has incorrect dtype: " + + f"expected {expected_extra_out_dtype} but found {extra_out_dtype}" + ) + assert extra_out_aval.ndim == len(expected_extra_out_shape), ( + "Extra output buffer has incorrect number of dimensions: " + + f"expected {len(expected_extra_out_shape)} but found {extra_out_aval.ndim}" + ) + assert all( + [ + extra_out_aval.shape[i] == expected_extra_out_shape[i] + for i in range(extra_out_aval.ndim) + ] + ), ( + "Extra output buffer has incorrect shape: " + + f"expected {expected_extra_out_shape=} but found {extra_out_aval.shape=}" + ) + + # Validate bias/bias_grad shape against output bufer + bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype + if fuse_bias: + assert ( + bias_aval.size > 0 and bias_aval.ndim == 1 and bias_aval.shape[0] == out_shape[-1] + ), ( + "Incorrect bias shape: " + + f"expected ({out_shape[-1]}, ) but found ({bias_aval.shape[0]}, )" + ) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + else: + assert bias_aval.size == 0, "Internal TE error." + + # Validate GELU input/output + gelu_shape = (0,) + if fuse_gelu: + gelu_shape = ( + (reduce(operator.mul, out_shape[:-1], 1), out_shape[-1]) + if len(out_shape) > 2 + else out_shape + ) + assert gelu_input_aval.ndim == 2 and all( + [gelu_input_aval.shape[i] == gelu_shape[i] for i in len(gelu_shape)] + ), "Invalid GELU input shape." + assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." + else: + assert gelu_input_aval.size == 0, "Internal TE error." + + # Create abstract arrays for all outputs + out_updated_aval = out_aval.update(shape=out_shape, dtype=out_dtype) + out_amax_updated_aval = out_amax_aval.update( + shape=out_amax_aval.shape, dtype=out_amax_updated_dtype + ) + out_scale_updated_aval = out_scale_aval.update( + shape=out_scale_aval.shape, dtype=out_scale_updated_dtype + ) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) + bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + extra_out_updated_aval = extra_out_aval.update( + shape=expected_extra_out_shape, dtype=expected_extra_out_dtype + ) + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + + return ( + out_updated_aval, + out_amax_updated_aval, + out_scale_updated_aval, + pre_gelu_out_aval, + bias_grad_aval, + extra_out_updated_aval, # global LHS for AG overlap, or sharded output for RS overlap + workspace_aval, + ) + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + cuBlasLt GEMM outer abstract + """ + ( + out_aval, + out_amax_aval, + out_scale_aval, + pre_gelu_out_aval, + bias_grad_aval, + extra_out_aval, + *_, + ) = CollectiveGemmPrimitive.abstract(*args, **kwargs) + return ( + out_aval, + out_amax_aval, + out_scale_aval, + pre_gelu_out_aval, + bias_grad_aval, + extra_out_aval, + ) + + @staticmethod + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + *, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + comm_overlap_config, + sharded_abstract, + ): + """ + Fused attention fwd lowering rules + """ + del batched_output, sharded_abstract + lhs_aval, _, rhs_aval, _, bias_aval, *_, extra_out_aval = ctx.avals_in + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 + + operands = [ + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + ] + + operand_output_aliases = { + 4: 4, # bias <--> bias_grad + 5: 3, # gelu_input <--> pre_gelu_out + 6: 0, # out <--> out_updated + 7: 1, # out_amax <--> out_amax_updated + 8: 2, # out_scale <--> out_scale_updated + } + if extra_out_aval.size > 0: + operand_output_aliases[9] = 5 # extra_out <--> extra_out_updated + + if is_ffi_enabled(): + name = "te_gemm_ffi" + ffi_args = (ctx, *operands) + ffi_kwargs = dict( + lhs_trans=lhs_trans, + rhs_trans=rhs_trans, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if comm_overlap_config is not None: + name = "te_comm_gemm_overlap_ffi" + ffi_kwargs["comm_type_flag"] = int(comm_overlap_config["comm_type"]) + ffi_kwargs["name"] = comm_overlap_config["name"] + + return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + *ffi_args, **ffi_kwargs + ) + + else: + operand_shapes = map(lambda x: ir.RankedTensorType(x.type).shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_dtype(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + m = lhs_aval.shape[lhs_outer_dim] + k = rhs_aval.shape[rhs_inner_dim] + n = rhs_aval.shape[rhs_outer_dim] + operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) + bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) + + workspace_size = get_cublas_workspace_size_bytes() + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": + workspace_size *= get_num_max_compute_streams() + + descriptor_packer_fn = tex.pack_gemm_decriptor + descriptor_args = ( + m, + n, + k, + workspace_size, + operand_dtype, + jax_dtype_to_te_dtype(dtypes.canonicalize_dtype(ctx.avals_out[0].dtype)), + bias_dtype, + lhs_trans, + rhs_trans, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ) + + opaque = descriptor_packer_fn(*descriptor_args) + + return custom_caller( + CollectiveGemmPrimitive.name, + args, + opaque, + has_side_effect=False, + operand_output_aliases=operand_output_aliases, + ) + + @staticmethod + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + comm_overlap_config, + sharded_abstract, + ): + assert CollectiveGemmPrimitive.inner_primitive is not None + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim) + ) + + # Infer output shape and collapse batch dimensions + lhs_2d_shape = rhs_2d_shape = None + lhs_layout = rhs_layout = None + lhs_batch_dims = [ + dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim] + ] + lhs_batch_shape = [lhs.shape[dim] for dim in lhs_batch_dims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + contracting_dims_2d = list(contracting_dims).copy() + if lhs.ndim > 2 and rhs.ndim > 2: + # If both LHS and RHS are batched, the batch dimensions collapse into the + # contracting dimensions for both operands + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_inner_dim], lhs.shape[lhs_outer_dim]) + lhs_layout = (*lhs_batch_dims, lhs_inner_dim, lhs_outer_dim) + contracting_dims_2d[0] = 0 + + rhs_batch_dims = [ + dim for dim in range(rhs.ndim) if dim not in [rhs_inner_dim, rhs_outer_dim] + ] + rhs_batch_shape = [rhs.shape[dim] for dim in rhs_batch_dims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + rhs_2d_shape = (rhs_batch_size * rhs.shape[rhs_inner_dim], rhs.shape[rhs_outer_dim]) + rhs_layout = (*rhs_batch_dims, rhs_inner_dim, rhs_outer_dim) + contracting_dims_2d[1] = 0 + elif lhs.ndim > 2: + # If only the LHS is batched,the batch dimension collapses into the outer dimension + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) + lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) + contracting_dims_2d[0] = 1 + + # Reshape LHS and RHS into 2D and fix layouts for FP8 GEMM + if lhs_2d_shape is not None and lhs.ndim > 2: + lhs = jax.lax.reshape(lhs, lhs_2d_shape, dimensions=lhs_layout) + if jax_dtype_is_fp8(lhs.dtype): + lhs = jax.lax.transpose(lhs, (1, 0)) + contracting_dims_2d[0] = 1 + else: + contracting_dims_2d[0] = contracting_dims[0] + + if rhs_2d_shape is not None and rhs.ndim > 2: + rhs = jax.lax.reshape(rhs, rhs_2d_shape, dimensions=rhs_layout) + if jax_dtype_is_fp8(rhs.dtype): + rhs = jax.lax.transpose(rhs, (1, 0)) + contracting_dims_2d[1] = 1 + else: + contracting_dims_2d[1] = contracting_dims[1] + + # Reshape output and extra output buffers into 2D as well + if out.ndim > 2: + out = jax.lax.reshape(out, (reduce(operator.mul, out.shape[:-1], 1), out.shape[-1])) + if extra_out.size > 0 and extra_out.ndim > 2: + extra_out = jax.lax.reshape( + extra_out, (reduce(operator.mul, extra_out.shape[:-1], 1), extra_out.shape[-1]) + ) + + batched_extra_out = False + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": + comm_type = comm_overlap_config["comm_type"] + if comm_type == tex.CommOverlapType.AG and extra_out.size > 0: + # Extra output is global LHS, we can collapse but need to recover batches later + batched_extra_out = len(lhs_batch_dims) > 0 + elif comm_type == tex.CommOverlapType.RS: + # Extra output is scattered GEMM output, so we recover batches only if the output is + # batched + batched_extra_out = batched_output + + # Invoke GEMM with guaranteed 2D inputs, so batched_output=False + ( + out_updated, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + extra_out_updated, + _, + ) = CollectiveGemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + batched_output=False, + contracting_dims=contracting_dims_2d, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + sharded_abstract=sharded_abstract, + ) + + # Recover batched dimensions in the output + if batched_output: + out_shape = ( + *lhs_batch_shape, + out_updated.shape[-2] // lhs_batch_size, + out_updated.shape[-1], + ) + out_updated = jax.lax.reshape(out_updated, out_shape) + + if batched_extra_out: + extra_out_shape = ( + *lhs_batch_shape, + extra_out_updated.shape[-2] // lhs_batch_size, + extra_out_updated.shape[-1], + ) + extra_out_updated = jax.lax.reshape(extra_out_updated, extra_out_shape) + + return ( + out_updated, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + extra_out_updated, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + comm_overlap_config, + sharded_abstract, + ): + assert CollectiveGemmPrimitive.outer_primitive is not None + check_valid_batch_dims(batch_dims) + ( + *_, + bias_bdims, + gelu_input_bdims, + out_bdims, + out_amax_bdims, + out_scale_bdims, + extra_out_bdims, + ) = batch_dims + + return ( + CollectiveGemmPrimitive.outer_primitive.bind( + *batched_args, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + sharded_abstract=sharded_abstract, + ), + ( + out_bdims, + out_amax_bdims, + out_scale_bdims, + gelu_input_bdims, + bias_bdims, + extra_out_bdims, + ), + ) + + @staticmethod + def infer_sharding_from_operands( + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + comm_overlap_config, + sharded_abstract, + mesh, + arg_infos, + result_infos, + ): + del accumulate, use_split_accumulator, sharded_abstract, result_infos + lhs, _, rhs, *_, extra_out = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] + if comm_overlap_config is None: + # When comm overlap is not enabled: + # - Always all-gather the outer dimension of LHS. + # - If contracting dims of both operands are sharded, all-gather RHS outer dim. + # - If contracting dim of only one operand is sharded, all-gather the sharded operand. + # - Never scatter any operand. + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + assert ( + lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim] + ), "Contracting dimensions of LHS and RHS operands must have the same sharding." + if lhs_spec_new[lhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the LHS operand must be all-gathered when both " + + "contracting dimensions are sharded. This will cause additional " + + "communication overhead." + ) + + if rhs_spec_new[rhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the RHS operand must be all-gathered when both " + + "contracting dimensions are sharded. This will cause additional " + + "communication overhead." + ) + rhs_spec_new[rhs_outer_dim] = None + else: + if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: + warnings.warn( + "Contracting dimension of the RHS operand must be all-gathered when the " + + "contracting dimension of the LHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is None: + if not grad: + # This is expected for sequence/context-parallel gradient in BWD (DGRAD) GEMM. + warnings.warn( + "Contracting dimension of the LHS operand must be all-gathered when " + + "the contracting dimension of the RHS operand is unsharded. This " + + "will cause additional communication overhead." + ) + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + out_spec = [None, out_col_spec] + if batched_output: + out_spec = batch_spec + out_spec + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + # Bias gradient spec matches outer dimension of output if bias fusion is turned on + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) + + # Validate operand sharding for comm+GEMM overlap and adust extra output sharding + extra_out_spec = [None] + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": + comm_type = comm_overlap_config.get("comm_type", None) + tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) + if comm_type == tex.CommOverlapType.AG: + # AG overlap requires the outer dimension of LHS to be sharded + # over the TP resource + assert lhs_spec[lhs_outer_dim] == tp_resource, ( + "AG+GEMM overlap requires the outer (sequence) dimension of the LHS " + + f"operand to be sharded over the TP resource '{tp_resource=}'." + ) + assert lhs_spec[lhs_inner_dim] is None, ( + "AG+GEMM overlap requires the contracting dimension of the LHS operand " + + "to be unsharded." + ) + assert rhs_spec[rhs_inner_dim] is None, ( + "AG+GEMM overlap requires the contracting dimension of the RHS operand " + + "to be unsharded." + ) + if extra_out.size > 0: + extra_out_spec = list(lhs_spec).copy() + extra_out_spec[lhs_outer_dim] = None + + elif comm_type == tex.CommOverlapType.RS: + # RS overlap requires the contracting dimensions of both LHS and RHS to be + # sharded over the TP resource, and the outer dimensions of LHS and RHS to be + # unsharded. + assert lhs_spec[lhs_outer_dim] is None, ( + "GEMM+RS overlap requires the outer (sequence) dimension of the LHS " + + "operand to be unsharded." + ) + assert lhs_spec[lhs_inner_dim] == tp_resource, ( + "GEMM+RS overlap requires the contracting dimension of the LHS operand " + + f"to be sharded over the TP resource '{tp_resource=}'." + ) + assert rhs_spec[rhs_inner_dim] == tp_resource, ( + "GEMM+RS overlap requires the contracting dimension of the RHS operand " + + f"to be sharded over the TP resource '{tp_resource=}'." + ) + assert rhs_spec[rhs_outer_dim] is None, ( + "GEMM+RS overlap requires the outer dimension of the RHS operand to be " + + "unsharded." + ) + extra_out_spec = list(out_spec).copy() + extra_out_spec[-2] = tp_resource + + extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) + + return ( + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + gelu_sharding, + bias_sharding, + extra_out_sharding, + ) + + @staticmethod + def partition( + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + comm_overlap_config, + sharded_abstract, + mesh, + arg_infos, + result_infos, + ): + del sharded_abstract, result_infos + lhs, _, rhs, *_, extra_out = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] + reduce_output = False + if comm_overlap_config is None: + # When comm overlap is not enabled: + # - Always all-gather the outer dimension of LHS. + # - If contracting dims of both operands are sharded, all-gather RHS outer dim. + # - If contracting dim of only one operand is sharded, all-gather the sharded operand. + # - Never scatter any operand. + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + rhs_spec_new[rhs_outer_dim] = None + reduce_output = True + else: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] + + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) + + # Bias is sharded to match outer dimension spec of the RHS operand (also the output) + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + out_spec = [None, out_col_spec] + if batched_output: + out_spec = batch_spec + out_spec + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + # Extra output sharding for comm+GEMM overlap + extra_out_spec = [None] + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": + comm_type = comm_overlap_config.get("comm_type", None) + if comm_type == tex.CommOverlapType.AG and extra_out.size > 0: + extra_out_spec = list(lhs_spec).copy() + extra_out_spec[lhs_outer_dim] = None + elif comm_type == tex.CommOverlapType.RS: + extra_out_spec = list(out_spec).copy() + extra_out_spec[-2] = comm_overlap_config.get( + "tp_resource", global_mesh_resource().tp_resource + ) + extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) + + arg_shardings = ( + lhs_sharding, + fp8_meta_sharding, + rhs_sharding, + fp8_meta_sharding, + bias_sharding, + gelu_sharding, + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + extra_out_sharding, + ) + out_shardings = ( + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + gelu_sharding, + bias_sharding, + extra_out_sharding, + ) + + def sharded_impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + ): + ( + out_updated, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + extra_out_updated, + ) = CollectiveGemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + sharded_abstract=True, + ) + + # FP8 amax reduction + if jax_dtype_is_fp8(lhs.dtype): + out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) + + # All-reduce sum GEMM output when contracting dimensions are sharded + if comm_overlap_config is None and reduce_output: + out_updated = jax.lax.psum(out_updated, global_mesh_resource().tp_resource) + if fuse_gelu: + pre_gelu_out = jax.lax.psum(pre_gelu_out, global_mesh_resource().tp_resource) + + return ( + out_updated, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + extra_out_updated, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CollectiveGemmPrimitive) + + +def gemm_impl( + lhs: ArrayLike, + rhs: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, + extra_out: Optional[ArrayLike] = None, + batched_output: bool = False, + contracting_dims: Tuple[int, int] = (-1, -2), + fuse_gelu: bool = False, + fuse_bias: bool = False, + grad: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, + comm_overlap_config: Optional[dict] = None, +) -> Tuple[ArrayLike, ...]: + """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim != lhs.ndim - 1 else lhs.ndim - 2 + rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 + + out_shape_batched = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + out_shape_2d = (reduce(operator.mul, out_shape_batched[:-1], 1), out_shape_batched[-1]) + out_shape = out_shape_batched if batched_output else out_shape_2d + + if out is None: + out = jnp.zeros(out_shape, dtype=lhs.dtype) + + if extra_out is None: + extra_out_shape = (0,) + if ( + comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + extra_out_shape = list(out_shape).copy() + extra_out = jnp.zeros(extra_out_shape, dtype=jnp.bfloat16) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) + else: + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + assert ( + gelu_input is not None + ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + elif gelu_input is None: + gelu_input = jnp.zeros(out_shape_2d, dtype=lhs.dtype) + + ( + out, + _, # out_amax in FP8 GEMM + _, # out_scale in FP8 GEMM + pre_gelu_out, + bias_grad, + extra_out, + ) = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + dummy_fp8_meta, + rhs, + dummy_fp8_meta, + bias, + gelu_input, + out, + dummy_fp8_meta, + dummy_fp8_meta, + extra_out, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + sharded_abstract=False, + ) + + if grad: + return out, pre_gelu_out, bias_grad, extra_out + else: + return out, pre_gelu_out, extra_out + + +def fp8_gemm_impl( + lhs: ArrayLike, + lhs_scale_inv: ArrayLike, + rhs_t: ArrayLike, + rhs_scale_inv: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, + extra_out: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + batched_output: bool = False, + fuse_gelu: bool = False, + fuse_bias: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, + comm_overlap_config: Optional[dict] = None, +) -> Tuple[ArrayLike, ...]: + """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + out_shape_batched = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]) + out_shape_2d = (reduce(operator.mul, out_shape_batched[:-1], 1), out_shape_batched[-1]) + out_shape = out_shape_batched if batched_output else out_shape_2d + + if out is None: + out = jnp.zeros(out_shape, dtype=out_dtype) + else: + out_dtype = out.dtype + + if extra_out is None: + extra_out_shape = (0,) + if ( + comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + extra_out_shape = list(out_shape).copy() + extra_out = jnp.zeros(extra_out_shape, dtype=jnp.bfloat16) + + if jax_dtype_is_fp8(out_dtype): + assert out_amax is not None and out_scale is not None, "Missing output amax and scale." + else: + out_amax = jnp.zeros(0, dtype=jnp.float32) + out_scale = jnp.zeros(0, dtype=jnp.float32) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=jnp.bfloat16) + else: + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=bias.dtype) + elif gelu_input is None: + gelu_input = jnp.zeros(out_shape_2d, dtype=bias.dtype) + + (out, out_amax, out_scale, pre_gelu_out, _, extra_out) = ( # bias_grad in non-FP8 GEMM + CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + lhs_scale_inv, + rhs_t, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + batched_output=batched_output, + contracting_dims=(-1, -1), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=False, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + sharded_abstract=False, + ) + ) + + return out, out_amax, out_scale, pre_gelu_out, extra_out + + +class BootstrapCommGemmOverlapPrimitive(BasePrimitive): + """ + Initialize Comm+GEMM overlap communicators and buffers + """ + + name = "te_bootstrap_comm_gemm_overlap_ffi" + impl_static_args = (1,) + multiple_results = False + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(buffer_aval, myrank, numranks, comm_overlap_config): + del myrank, numranks + assert is_ffi_enabled(), "Comm+GEMM overlap is supported only via XLA FFI." + overlap_name = comm_overlap_config.get("name", None) + assert ( + overlap_name in _COMM_GEMM_OVERLAP_NAMES + ), f"Unrecognized comm+GEMM overlap name: {overlap_name=}" + assert buffer_aval.size > 0, "Cannot initialize a zero-size communication buffer." + return jax.core.ShapedArray(shape=(0,), dtype=dtypes.canonicalize_dtype(buffer_aval.dtype)) + + @staticmethod + def lowering(ctx, buffer, *, myrank, numranks, comm_overlap_config): + return ffi.ffi_lowering(BootstrapCommGemmOverlapPrimitive.name)( + ctx, + buffer, + name=comm_overlap_config["name"], + method=comm_overlap_config["method"], + myrank=myrank, + numranks=numranks, + tp_size=comm_overlap_config["tp_size"], + num_splits=comm_overlap_config["num_splits"], + num_max_streams=comm_overlap_config["num_max_streams"], + cga_size=comm_overlap_config["cga_size"], + num_comm_sm=comm_overlap_config["num_sm"], + set_sm_margin=comm_overlap_config["set_sm_margin"], + use_ce=comm_overlap_config["use_ce"], + atomic_gemm=comm_overlap_config["atomic_gemm"], + aggregate=comm_overlap_config["aggregate"], + pipeline_rs_overlap_first_gemm=comm_overlap_config["pipeline_rs_overlap_first_gemm"], + ) + + @staticmethod + def impl(buffer, myrank, numranks, comm_overlap_config): + assert BootstrapCommGemmOverlapPrimitive.inner_primitive is not None + buffer = jax.lax.reshape( + buffer, (reduce(operator.mul, buffer.shape[:-1], 1), buffer.shape[-1]) + ) + return BootstrapCommGemmOverlapPrimitive.inner_primitive.bind( + buffer, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, myrank, numranks, comm_overlap_config): + assert BootstrapCommGemmOverlapPrimitive.inner_primitive is not None + check_valid_batch_dims(batch_dims) + return ( + BootstrapCommGemmOverlapPrimitive.inner_primitive.bind( + *batched_args, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config, + ), + None, + ) + + @staticmethod + def infer_sharding_from_operands( + myrank, numranks, comm_overlap_config, mesh, arg_infos, result_infos + ): + del myrank, numranks, comm_overlap_config, result_infos + buffer_spec = get_padded_spec(arg_infos[0]) + assert all([spec is None for spec in buffer_spec]), "Sample buffer must be unsharded." + return NamedSharding(mesh, PartitionSpec(None)) + + @staticmethod + def partition(myrank, numranks, comm_overlap_config, mesh, arg_infos, result_infos): + del arg_infos, result_infos + arg_shardings = (NamedSharding(mesh, PartitionSpec(None)),) + out_sharding = NamedSharding(mesh, PartitionSpec(None)) + return ( + mesh, + partial( + BootstrapCommGemmOverlapPrimitive.impl, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config, + ), + out_sharding, + arg_shardings, + ) + + +register_primitive(BootstrapCommGemmOverlapPrimitive) + + +def bootstrap_comm_gemm_overlap( + buffer: ArrayLike, myrank: int, numranks: int, comm_overlap_config: dict +): + _ = BootstrapCommGemmOverlapPrimitive.outer_primitive.bind( + buffer, myrank=myrank, numranks=numranks, comm_overlap_config=comm_overlap_config + ) + + +class CopyIntoOverlapBufferPrimitive(BasePrimitive): + """ + Copy JAX array data into comm+GEMM overlap buffer + """ + + name = "te_copy_into_overlap_buffer_ffi" + impl_static_args = (1, 2) + multiple_results = False + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(inp_aval, name, sharded): + del sharded + assert is_ffi_enabled(), "Comm+GEMM overlap is supported only via XLA FFI." + assert name in _COMM_GEMM_OVERLAP_NAMES, f"Unrecognized comm+GEMM overlap name: {name=}" + assert inp_aval.size > 0, "Cannot copy a zero-size array into overlap buffer." + return jax.core.ShapedArray(shape=(0,), dtype=dtypes.canonicalize_dtype(inp_aval.dtype)) + + @staticmethod + def lowering(ctx, inp, *, name, sharded): + return ffi.ffi_lowering(name)( + ctx, + inp, + name=name, + sharded=sharded, + ) + + @staticmethod + def impl(inp, name, sharded): + assert CopyIntoOverlapBufferPrimitive.inner_primitive is not None + inp_2d = jax.lax.reshape(inp, (reduce(operator.mul, inp.shape[:-1], 1), inp.shape[-1])) + return CopyIntoOverlapBufferPrimitive.inner_primitive.bind( + inp_2d, name=name, sharded=sharded + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, name, sharded): + assert CopyIntoOverlapBufferPrimitive.inner_primitive is not None + check_valid_batch_dims(batch_dims) + return ( + CopyIntoOverlapBufferPrimitive.inner_primitive.bind( + *batched_args, name=name, sharded=sharded + ), + None, + ) + + @staticmethod + def infer_sharding_from_operands(name, sharded, mesh, arg_infos, result_infos): + del name, result_infos + inp_spec = get_padded_spec(arg_infos[0]) + if sharded: + assert inp_spec[-2] is not None, ( + "Leading dimension of input tensor must be sharded in order to copy into a " + + "sharded communication tensor (e.g. preparing for bulk all-gather overlap)." + ) + else: + assert inp_spec[-2] is None, ( + "Leading dimension of input tensor cannot be sharded when copying into an " + + "unsharded communication tensor (e.g. preparing for bulk reduce-scatter overlap)." + ) + return NamedSharding(mesh, PartitionSpec(None)) + + @staticmethod + def partition(name, sharded, mesh, arg_infos, result_infos): + del name, sharded, result_infos + inp_spec = get_padded_spec(arg_infos[0]) + arg_shardings = (NamedSharding(mesh, PartitionSpec(*inp_spec)),) + out_sharding = NamedSharding(mesh, PartitionSpec(None)) + return ( + mesh, + partial(CopyIntoOverlapBufferPrimitive.impl, name=name, sharded=sharded), + out_sharding, + arg_shardings, + ) + + +register_primitive(CopyIntoOverlapBufferPrimitive) + + +def copy_into_overlap_buffer(inp: ArrayLike, name: str, sharded: bool) -> None: + _ = CopyIntoOverlapBufferPrimitive.outer_primitive.bind(inp, name=name, sharded=sharded) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 1f13484b98..15d7537fbd 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -81,6 +81,13 @@ def jax_dtype_to_te_dtype(jax_dtype): return converter.get(jax_dtype) +def jax_dtype_is_fp8(dtype): + """ + Check if the given jax.numpy.dtype is an FP8 dtype. + """ + return dtypes.canonicalize_dtype(dtype) in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + def get_padded_spec(arg_info): """ Get padded spec for partitioning from arguments' information diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 02e6aaf9d5..6bc6d02173 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -4,8 +4,8 @@ * See LICENSE for license information. ************************************************************************/ -#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ -#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ +#ifndef TRANSFORMER_ENGINE_JAX_CSRC_EXTENSIONS_H_ +#define TRANSFORMER_ENGINE_JAX_CSRC_EXTENSIONS_H_ #include #include @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -147,6 +148,29 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right); +struct CustomCallGemmDescriptor { + size_t m; + size_t k; + size_t n; + size_t workspace_size; + DType operand_dtype; + DType bias_dtype; + DType out_dtype; + bool lhs_trans; + bool rhs_trans; + bool fuse_gelu; + bool fuse_bias; + bool grad; + bool accumulate; + bool use_split_accumulator; +}; + +pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, size_t workspace_size, + DType operand_dtype, DType out_dtype, DType bias_dtype, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator); + // Transpose void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -308,7 +332,74 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); +// GEMM + +XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasltHandleInitHandler); + +void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out, Buffer_Type out_amax, Buffer_Type out_scale, + Buffer_Type dummy_in, Result_Type out_updated, Result_Type out_amax_updated, + Result_Type out_scale_updated, Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type dummy_out, Result_Type workspace, bool lhs_trans, bool rhs_trans, + bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + +// Comm+GEMM Overlap + +bool OverlapBufferIsFp8(const std::string &name); + +pybind11::object GetOverlapBuffer(const std::string &name, bool sharded); + +void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad); + +void BootstrapCommGemmOverlap(const std::vector &buffer_shape, DType buffer_dtype, + const std::string &name, const std::string &method, + CommOverlapType comm_type, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm); + +Error_Type BootstrapCommGemmOverlapFFI(cudaStream_t, Buffer_Type sample_buffer, + std::string_view name, std::string_view method, + int64_t comm_type_flag, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(BootstrapCommGemmOverlapHandler); + +void DestroyCommGemmOverlap(const std::string &name); + +Error_Type DestroyCommGemmOverlapFFI(cudaStream_t stream, std::string_view name); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(DestroyCommGemmOverlapHandler); + +Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std::string_view name, + bool sharded); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler); + +Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type out, Buffer_Type out_amax, + Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, int64_t comm_type_flag, + std::string_view name); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(CommGemmOverlapHandler); + } // namespace jax } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ +#endif // TRANSFORMER_ENGINE_JAX_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..02b415b321 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp @@ -0,0 +1,309 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common/util/dlpack_helper.h" +#include "extensions.h" + +void _dummy_allgather(void *global, size_t globalbytes, void *local, size_t localbytes, + ExtComm comm) {}; + +void _dummy_barrier(ExtComm comm) {}; + +namespace transformer_engine { + +namespace jax { + +Error_Type CublasltHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets, + Dictionary attrs) { + cublasLtHandle_t handle; + NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasltHandleInitHandler, CublasltHandleInitFFI, + FFI::Bind().RemainingArgs().RemainingRets().Attrs()); + +static std::unordered_map _overlaps; + +void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad) { + auto scale_inv_tensor = DLPackWrapper(scale_inv, grad); + _overlaps[name]->set_ubuf_scale_inv(reinterpret_cast(scale_inv_tensor.dptr())); +} + +bool OverlapBufferIsFp8(const std::string &name) { return _overlaps[name]->is_fp8_ubuf(); } + +pybind11::object GetOverlapBuffer(const std::string &name, bool sharded) { + auto comm_type = (sharded) ? CommOverlapType::RS : CommOverlapType::AG; + DLPackWrapper output = std::move(_overlaps[name]->get_ubuf_output(comm_type)); + auto capsule = output.capsule(); + return capsule; +}; + +void BootstrapCommGemmOverlap(const std::vector &buffer_shape, DType buffer_dtype, + const std::string &name, const std::string &method, + CommOverlapType comm_type, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t comm_cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm) { +#ifndef NVTE_UB_WITH_MPI + NVTE_ERROR( + std::string("Comm+GEMM overlap in TE/JAX requires bootstrapping Userbuffers with MPI. ") + + std::string("Please compile TE with `NVTE_UB_WITH_MPI=1`.")); +#endif + + // Initialize overlap object -- this allocates the comm buffer + NVTE_CHECK(_overlaps.find(name) == _overlaps.end(), name, " is already initialized!"); + if (method == "ring_exchange") { + _overlaps[name] = new CommOverlapP2PBase(buffer_shape, buffer_dtype, myrank, numranks, -1, -1, + -1, -1, tp_size, &_dummy_allgather, &_dummy_barrier, + comm_type, num_max_streams, comm_cga_size, num_comm_sm, + set_sm_margin, use_ce, atomic_gemm, aggregate); + } else { + _overlaps[name] = new CommOverlapBase( + buffer_shape, buffer_dtype, myrank, numranks, -1, -1, -1, -1, tp_size, &_dummy_allgather, + &_dummy_barrier, num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, + atomic_gemm, pipeline_rs_overlap_first_gemm); + } +}; + +Error_Type BootstrapCommGemmOverlapFFI(cudaStream_t, Buffer_Type sample_buffer, + std::string_view name, std::string_view method, + int64_t comm_type_flag, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm) { + auto buffer_shape = + std::vector(sample_buffer.dimensions().begin(), sample_buffer.dimensions().end()); + auto buffer_dtype = convert_ffi_datatype_to_te_dtype(sample_buffer.element_type()); + BootstrapCommGemmOverlap(buffer_shape, buffer_dtype, static_cast(name), + static_cast(method), + static_cast(comm_type_flag), myrank, numranks, tp_size, + num_splits, num_max_streams, cga_size, num_comm_sm, set_sm_margin, + use_ce, atomic_gemm, aggregate, pipeline_rs_overlap_first_gemm); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(BootstrapCommGemmOverlapHandler, BootstrapCommGemmOverlapFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // sample_buffer + .Attr("name") + .Attr("method") + .Attr("comm_type_flag") + .Attr("myrank") + .Attr("numranks") + .Attr("tp_size") + .Attr("num_splits") + .Attr("num_max_streams") + .Attr("cga_size") + .Attr("num_comm_sm") + .Attr("set_sm_margin") + .Attr("use_ce") + .Attr("atomic_gemm") + .Attr("aggregate") + .Attr("pipeline_rs_overlap_first_gemm"), + FFI_CudaGraph_Traits); + +void DestroyCommGemmOverlap(const std::string &name) { + auto overlap = _overlaps.find(name); + if (overlap != _overlaps.end()) { + delete overlap->second; + _overlaps.erase(overlap); + } +}; + +Error_Type DestroyCommGemmOverlapFFI(cudaStream_t stream, std::string_view name) { + DestroyCommGemmOverlap(static_cast(name)); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DestroyComMGemmOverlapHandler, DestroyCommGemmOverlapFFI, + FFI::Bind().Ctx().Attr("name"), + FFI_CudaGraph_Traits); + +void CopyIntoOverlapBufferImpl(cudaStream_t stream, void *input_ptr, + const std::vector &shape, DType dtype, + const std::string &name, bool sharded) { + auto input = TensorWrapper(input_ptr, shape, dtype); + auto comm_type = (sharded) ? CommOverlapType::RS : CommOverlapType::AG; + _overlaps[name]->copy_into_ubuf(stream, input, comm_type); +} + +Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std::string_view name, + bool sharded) { + auto input_ptr = input.untyped_data(); + auto shape = std::vector(input.dimensions().begin(), input.dimensions().end()); + auto dtype = convert_ffi_datatype_to_te_dtype(input.element_type()); + + CopyIntoOverlapBufferImpl(stream, input_ptr, shape, dtype, static_cast(name), + sharded); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler, CopyIntoOverlapBufferFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Attr("name") + .Attr("sharded"), + FFI_CudaGraph_Traits); + +void CommGemmOverlapImpl(void *lhs, const std::vector &lhs_shape, DType lhs_dtype, + float *lhs_scale_inv, bool lhs_trans, void *rhs, + const std::vector &rhs_shape, DType rhs_dtype, + float *rhs_scale_inv, bool rhs_trans, void *out, + const std::vector &out_shape, DType out_dtype, float *out_amax, + float *out_scale, void *bias, DType bias_dtype, void *pre_gelu_out, + void *extra_out, const std::vector &extra_out_shape, + void *workspace, size_t workspace_size, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, const std::string &name, cudaStream_t stream) { + auto lhs_ = TensorWrapper(lhs, lhs_shape, lhs_dtype, nullptr, nullptr, lhs_scale_inv); + auto rhs_ = TensorWrapper(rhs, rhs_shape, rhs_dtype, nullptr, nullptr, rhs_scale_inv); + auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); + + auto bias_ptr = (fuse_bias) ? bias : nullptr; + auto bias_shape = (fuse_bias) ? std::vector(out_shape.back()) : std::vector{0}; + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + auto pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; + auto pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector{0}; + auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype); + + auto workspace_ = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + + auto extra_out_ = + TensorWrapper(extra_out, extra_out_shape, lhs_dtype, nullptr, nullptr, lhs_scale_inv); + + auto overlap = _overlaps[name]; + if (comm_type == CommOverlapType::AG) { + // AG overlap is only ring-exchange + if (overlap->is_atomic_gemm()) { + overlap->atomic_gemm_overlap_ag(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + extra_out_, stream); + } else { + overlap->split_overlap_ag(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, extra_out_, + stream); + } + } else if (comm_type == CommOverlapType::RS) { + if (overlap->is_atomic_gemm()) { + overlap->atomic_gemm_overlap_rs(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + extra_out_, stream); + } else { + overlap->split_overlap_rs(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, extra_out_, + stream); + } + } +} + +Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type out, Buffer_Type out_amax, + Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, int64_t comm_type_flag, + std::string_view name) { + // Inputs + auto lhs_ptr = lhs.untyped_data(); + auto lhs_shape = std::vector(lhs.dimensions().begin(), lhs.dimensions().end()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); + auto rhs_ptr = rhs.untyped_data(); + auto rhs_shape = std::vector(rhs.dimensions().begin(), rhs.dimensions().end()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs.element_type()); + auto rhs_scale_inv_ptr = reinterpret_cast(rhs_scale_inv.untyped_data()); + auto bias_ptr = bias.untyped_data(); + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_ptr = out.untyped_data(); + auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); + auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + auto extra_out_ptr = extra_out.untyped_data(); + + // Outputs + auto out_updated_ptr = out_updated->untyped_data(); + auto out_shape = + std::vector(out_updated->dimensions().begin(), out_updated->dimensions().end()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out_updated->element_type()); + auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); + auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); + auto pre_gelu_ptr = pre_gelu_out->untyped_data(); + auto bias_grad_ptr = bias_grad->untyped_data(); + auto extra_out_updated_ptr = extra_out_updated->untyped_data(); + auto extra_out_shape = std::vector(extra_out_updated->dimensions().begin(), + extra_out_updated->dimensions().end()); + auto workspace_ptr = workspace->untyped_data(); + auto workspace_size = workspace->element_count(); + + // Check operand-output aliases + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(gelu_input_ptr == pre_gelu_ptr, + "gelu_input not bound to pre_gelu_out in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(out_ptr == out_updated_ptr, + "out not bound to out_updated in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, + "out_amax not bound to out_amax_updated in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, + "out_scale not bound to out_scale_updated in TE/JAX comm+GEMM overlap."); + if (extra_out.element_count() > 0) { + NVTE_CHECK(extra_out_ptr == extra_out_updated_ptr, + "extra_out not bound to extra_out_updated in TE/JAX comm+GEMM overlap."); + } + + CommGemmOverlapImpl( + lhs_ptr, lhs_shape, lhs_dtype, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, rhs_dtype, + rhs_scale_inv_ptr, rhs_trans, out_ptr, out_shape, out_dtype, out_amax_ptr, out_scale_ptr, + bias_ptr, bias_dtype, pre_gelu_ptr, extra_out_ptr, extra_out_shape, workspace_ptr, + workspace_size, fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator, + static_cast(comm_type_flag), static_cast(name), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CommGemmOverlapHandler, CommGemmOverlapFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // out + .Arg() // out_amax + .Arg() // out_scale + .Arg() // extra_out + .Ret() // out_updated + .Ret() // out_amax_updated + .Ret() // out_scale_updated + .Ret() // pre_gelu_out + .Ret() // bias_grad + .Ret() // extra_out_updated + .Ret() // workspace + .Attr("lhs_trans") + .Attr("rhs_trans") + .Attr("fuse_gelu") + .Attr("fuse_bias") + .Attr("grad") + .Attr("accumulate") + .Attr("use_split_accumulator") + .Attr("comm_type_flag") + .Attr("name"), + FFI_CudaGraph_Traits); + +} // namespace jax + +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp new file mode 100644 index 0000000000..44a2d55f8e --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -0,0 +1,178 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/gemm.h" + +#include "common/util/cuda_runtime.h" +#include "common/util/system.h" +#include "extensions.h" + +namespace transformer_engine { + +namespace jax { + +void GemmImpl(cudaStream_t stream, void *lhs, const std::vector &lhs_shape, + float *lhs_scale_inv, bool lhs_trans, void *rhs, const std::vector &rhs_shape, + float *rhs_scale_inv, bool rhs_trans, DType operand_dtype, void *bias, + DType bias_dtype, void *out, float *out_amax, float *out_scale, DType out_dtype, + void *pre_gelu_out, void *workspace, size_t workspace_size, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { + auto lhs_ = TensorWrapper(lhs, lhs_shape, operand_dtype, nullptr, nullptr, lhs_scale_inv); + auto rhs_ = TensorWrapper(rhs, rhs_shape, operand_dtype, nullptr, nullptr, rhs_scale_inv); + + std::vector out_shape(2, 0); + out_shape[0] = (lhs_trans) ? lhs_shape[1] : lhs_shape[0]; + out_shape[1] = (rhs_trans) ? rhs_shape[0] : rhs_shape[1]; + auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); + + void *bias_ptr = (fuse_bias) ? bias : nullptr; + std::vector bias_shape = + (fuse_bias) ? std::vector{out_shape[1]} : std::vector{0}; + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; + std::vector pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector{0}; + auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype); + auto workspace_ = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + + // cuBLAS is column-major, so we swap LHS and RHS in the arguments + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_out_.data(), + (rhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, (lhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, + grad, workspace_.data(), accumulate, use_split_accumulator, num_math_sm, stream); +} + +void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + // Inputs + auto *lhs = buffers[0]; + auto *lhs_scale_inv = reinterpret_cast(buffers[1]); + auto *rhs = buffers[2]; + auto *rhs_scale_inv = reinterpret_cast(buffers[3]); + auto *bias = buffers[4]; + auto *gelu_input = buffers[5]; + auto *out = buffers[6]; + auto *out_amax = reinterpret_cast(buffers[7]); + auto *out_scale = reinterpret_cast(buffers[8]); + // buffers[9] is the extra output bufer for comm+GEMM overlap, not used here + + // Outputs + auto *out_updated = buffers[10]; + auto *out_amax_updated = reinterpret_cast(buffers[11]); + auto *out_scale_updated = reinterpret_cast(buffers[12]); + auto *pre_gelu_out = buffers[13]; + auto *bias_grad = buffers[14]; + // buffers[15] is the updated extra output for comm+GEMM overlap, not used here + auto *workspace = buffers[16]; + + // Operand aliasing + NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input == pre_gelu_out, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out == out_updated, "out not bound to out_updated in TE/JAX GEMM"); + NVTE_CHECK(out_amax == out_amax_updated, "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale == out_scale_updated, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + const auto &desc = *UnpackOpaque(opaque, opaque_len); + std::vector lhs_shape = {(desc.lhs_trans) ? desc.k : desc.m, + (desc.lhs_trans) ? desc.m : desc.k}; + std::vector rhs_shape = {(desc.rhs_trans) ? desc.n : desc.k, + (desc.rhs_trans) ? desc.k : desc.n}; + + GemmImpl(stream, lhs, lhs_shape, lhs_scale_inv, desc.lhs_trans, rhs, rhs_shape, rhs_scale_inv, + desc.rhs_trans, desc.operand_dtype, bias, desc.bias_dtype, out, out_amax, out_scale, + desc.out_dtype, pre_gelu_out, workspace, desc.workspace_size, desc.fuse_gelu, + desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator); +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out, Buffer_Type out_amax, Buffer_Type out_scale, + Buffer_Type dummy_in, Result_Type out_updated, Result_Type out_amax_updated, + Result_Type out_scale_updated, Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type dummy_out, Result_Type workspace, bool lhs_trans, bool rhs_trans, + bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator) { + // Inputs + auto lhs_ptr = lhs.untyped_data(); + auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); + auto rhs_ptr = rhs.untyped_data(); + auto rhs_scale_inv_ptr = reinterpret_cast(rhs_scale_inv.untyped_data()); + auto operand_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + auto bias_ptr = bias.untyped_data(); + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_ptr = out.untyped_data(); + auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); + auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + // dummy_in is the extra output buffer for comm+GEMM overlap, not used here + + // Outputs + auto out_updated_ptr = out_updated->untyped_data(); + auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); + auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out_updated->element_type()); + auto pre_gelu_out_ptr = pre_gelu_out->untyped_data(); + auto bias_grad_ptr = bias_grad->untyped_data(); + // dummy_out is the updated extra output for comm+GEMM overlap, not used here + auto workspace_ptr = workspace->untyped_data(); + auto workspace_size = workspace->dimensions().back(); + + // Operand aliasing + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, + "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_ptr == out_updated_ptr, "out not bound to out_updated in TE/JAX GEMM"); + NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, + "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + std::vector lhs_shape(lhs.dimensions().begin(), lhs.dimensions().end()); + std::vector rhs_shape(rhs.dimensions().begin(), rhs.dimensions().end()); + + // Swap A and B argument locations to match what the TE/common kernel expects + GemmImpl(stream, lhs_ptr, lhs_shape, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, + rhs_scale_inv_ptr, rhs_trans, operand_dtype, bias_ptr, bias_dtype, out_ptr, out_amax_ptr, + out_scale_ptr, out_dtype, pre_gelu_out_ptr, workspace_ptr, workspace_size, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // out + .Arg() // out_amax + .Arg() // out_scale + .Arg() // dummy_in + .Ret() // out_updated + .Ret() // out_amax_updated + .Ret() // out_scale_updated + .Ret() // pre_gelu_out + .Ret() // bias_grad + .Ret() // dummy_out + .Ret() // workspace + .Attr("lhs_trans") + .Attr("rhs_trans") + .Attr("fuse_gelu") + .Attr("fuse_bias") + .Attr("grad") + .Attr("accumulate") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + +} // namespace jax + +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 298478603b..dd4070af41 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -80,5 +80,15 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( deterministic, window_size_left, window_size_right}); } +pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, size_t workspace_size, + DType operand_dtype, DType bias_dtype, DType out_dtype, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator) { + return PackOpaque(CustomCallGemmDescriptor{m, n, k, workspace_size, operand_dtype, bias_dtype, + out_dtype, lhs_trans, rhs_trans, fuse_gelu, fuse_bias, + grad, accumulate, use_split_accumulator}); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..c61e9c8127 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -4,9 +4,11 @@ * See LICENSE for license information. ************************************************************************/ +#include "common/util/pybind_helper.h" #include "extensions.h" namespace transformer_engine { + namespace jax { template @@ -51,6 +53,7 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + dict["te_gemm"] = EncapsulateFunction(Gemm); // Transpose dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); @@ -101,10 +104,24 @@ pybind11::dict Registrations() { fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler); dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; + pybind11::dict gemm_ffi; + gemm_ffi["prepare"] = EncapsulateFFI(CublasltHandleInitHandler); + gemm_ffi["execute"] = EncapsulateFFI(GemmHandler); + dict["te_gemm_ffi"] = gemm_ffi; + + dict["te_bootstrap_comm_gemm_overlap_ffi"] = EncapsulateFFI(BootstrapCommGemmOverlapHandler); + dict["te_copy_into_overlap_buffer_ffi"] = EncapsulateFFI(CopyIntoOverlapBufferHandler); + + pybind11::dict comm_gemm_overlap_ffi; + comm_gemm_overlap_ffi["prepare"] = EncapsulateFFI(CublasltHandleInitHandler); + comm_gemm_overlap_ffi["execute"] = EncapsulateFFI(CommGemmOverlapHandler); + dict["te_comm_gemm_overlap_ffi"] = comm_gemm_overlap_ffi; return dict; } PYBIND11_MODULE(transformer_engine_jax, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("registrations", &Registrations); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); @@ -114,10 +131,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); + m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_device_compute_capability", &GetDeviceComputeCapability, pybind11::arg("gpu_id") = -1); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); @@ -126,63 +144,14 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); - - pybind11::enum_(m, "DType", pybind11::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kInt64", DType::kInt64) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) - .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) - .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); - - pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) - .value("GELU", NVTE_Activation_Type::GELU) - .value("GEGLU", NVTE_Activation_Type::GEGLU) - .value("SILU", NVTE_Activation_Type::SILU) - .value("SWIGLU", NVTE_Activation_Type::SWIGLU) - .value("RELU", NVTE_Activation_Type::RELU) - .value("REGLU", NVTE_Activation_Type::REGLU) - .value("QGELU", NVTE_Activation_Type::QGELU) - .value("QGEGLU", NVTE_Activation_Type::QGEGLU) - .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU) - .export_values(); - - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); + m.def("bootstrap_comm_gemm_overlap", &BootstrapCommGemmOverlap); + m.def("destroy_comm_gemm_overlap", &DestroyCommGemmOverlap); + m.def("set_buffer_scale_inv", &SetOverlapBufferScaleInverse, pybind11::arg(), pybind11::arg(), + pybind11::arg("grad") = false); + m.def("get_overlap_buffer", &GetOverlapBuffer); + m.def("overlap_buffer_is_fp8", &OverlapBufferIsFp8); } } // namespace jax + } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 32de33bac9..b328c6e278 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -23,7 +23,7 @@ namespace jax { int GetCudaRuntimeVersion(); size_t GetCudnnRuntimeVersion(); -int GetDeviceComputeCapability(int gpu_id); +int GetDeviceComputeCapability(int gpu_id = -1); void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8b13c47cd4..abe23fdf8b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -334,6 +334,7 @@ def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage: input_name_post_fix = f"_i_{postfix}" weight_name_post_fix = f"_w_{postfix}" grad_name_post_fix = f"_g_{postfix}" + output_name_post_fix = f"_o_{postfix}" def generate_a_set(target_postfix): amax = nn_partitioning.variable_with_axes( @@ -359,9 +360,17 @@ def generate_a_set(target_postfix): input_amax, input_scale = generate_a_set(input_name_post_fix) weight_amax, weight_scale = generate_a_set(weight_name_post_fix) grad_amax, grad_scale = generate_a_set(grad_name_post_fix) + output_amax, output_scale = generate_a_set(output_name_post_fix) return FP8MetaPackage( - input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale + input_amax, + input_scale, + weight_amax, + weight_scale, + grad_amax, + grad_scale, + output_amax, + output_scale, ) diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 5df8ce4386..3d58c86e3e 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -86,10 +86,11 @@ class FP8MetaPackage: A container that contains all required meta data for FP8 """ - NUM_OF_META: int = 3 + NUM_OF_META: int = 4 INPUT_IDX: int = 0 WEIGHT_IDX: int = 1 GRAD_IDX: int = 2 + OUTPUT_IDX: int = 3 def __init__( self, @@ -99,6 +100,8 @@ def __init__( weight_scale: jnp.ndarray, grad_amax: jnp.ndarray, grad_scale: jnp.ndarray, + output_amax: jnp.ndarray, + output_scale: jnp.ndarray, ) -> None: self._amax_list = [None] * FP8MetaPackage.NUM_OF_META @@ -110,6 +113,8 @@ def __init__( self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale + self._amax_list[FP8MetaPackage.OUTPUT_IDX] = output_amax + self._scale_list[FP8MetaPackage.OUTPUT_IDX] = output_scale @property def amax_list(self) -> List[jnp.ndarray]: diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py new file mode 100644 index 0000000000..06cd52e97f --- /dev/null +++ b/transformer_engine/jax/gemm.py @@ -0,0 +1,1051 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import os +import warnings +import operator +from functools import partial, reduce +from typing import Optional, Tuple, Union, Sequence + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from transformer_engine import transformer_engine_jax as tex +from .fp8 import FP8Helper, FP8MetaPackage +from .cpp_extensions import ( + gemm_impl, + fp8_gemm_impl, + cast_transpose, + dact_lu, + dbias_cast_transpose, + dact_lu_dbias_cast_transpose, + get_num_max_compute_streams, +) + +from .cpp_extensions.gemm import sanitize_dims, mirror_dim, copy_into_overlap_buffer +from .cpp_extensions.misc import jax_dtype_is_fp8, jax_dtype_to_te_dtype +from .sharding import get_mesh_axis_size, global_mesh_resource + + +__all__ = [ + "gemm", + "fp8_gemm", + "type_safe_gemm", + "initialize_comm_gemm_overlaps", + "destroy_comm_gemm_overlap", + "get_comm_gemm_overlap_config", +] + + +_ACTIVE_COMM_GEMM_OVERLAPS = dict() + + +def gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (-1, -2), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, + comm_overlap_name: Optional[str] = None, + ag_overlap_skip_copy: bool = False, +) -> ArrayLike: + """ + Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions. + + Parameters + ---------- + x : ArrayLike + LHS operand, sized ([B], M, K) when not transposed. + kernel : ArrayLike + RHS operand, sized (K, N) when not transposed. + bias : Optional[ArrayLike], default = `None` + Optional bias term to add onto the (LHS x RHS) result. + contracting_dims : Tuple[int, int], default = `(-1, 0)` + Contracting dimensions of LHS and RHS, respectively, in the matrix-multiplication. + The default (-1, 0) describes the fully non-transposed 'NN' layout where LHS contracts in + the last dimension, and RHS contracts in the first dimension. + fuse_gelu : bool, default = `False` + Enable the GELU epilogue for GEMM. This applies GELU after the bias-addition if the bias + term is not `None`. + accumulate : bool, default = `False` + use_split_accumulator : bool, default = `False` + comm_overlap_name : Optional[str], default = `None` + Name of the comm+GEMM overlap layer that this GEMM is associated with. Comm+GEMM overlap + must be initialized with `te.jax.gemm.initialize_comm_gemm_overlaps()` before this + GEMM call, and the configuration dictionary used in the initialization must include + the name passed into this function. + ag_overlap_skip_copy: bool = `False` + All-gather overlap requires the LHS operand to be copied into the communication buffer. + If the communication buffer already has the necessary data, setting this flag will + avoid an unnecessary memcpy operation. + """ + comm_overlap_config = None + if comm_overlap_name is not None: + global _ACTIVE_COMM_GEMM_OVERLAPS + comm_overlap_layer = ( + comm_overlap_name + "_fprop" + if comm_overlap_name not in ["ag_gemm", "gemm_rs"] + else comm_overlap_name + ) + comm_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(comm_overlap_layer, None) + if comm_overlap_config is None: + warnings.warn( + f"Comm+GEMM overlap for {comm_overlap_name} has not been initialized! " + + "Sharded operands will trigger XLA collectives instead." + ) + + elif ( + not ag_overlap_skip_copy + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.AG + ): + if sanitize_dims(contracting_dims[0], x.ndim) != x.ndim - 1: + x = jnp.matrix_transpose(x) + copy_into_overlap_buffer(x, comm_overlap_name, True) + + return _gemm( + x, + kernel, + bias, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, + ) + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) +def _gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Union[ArrayLike, None], + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap_config: dict, +) -> ArrayLike: + out, _ = _gemm_fwd_rule( + x, + kernel, + bias, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, + ) + return out + + +def _gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap_config: dict, +) -> Tuple[ArrayLike, ...]: + assert ( + kernel.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + + fuse_bias = bias is not None + + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --> ([B], M, N/P) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # + # GEMM+RS: ([B], M, K/P) x (K/P, N) --(RS)--> ([B], M/P, N) + out, pre_gelu_out, extra_out = gemm_impl( + x, + kernel, + bias=bias, + batched_output=(x.ndim > 2), + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + ) + + final_out = out + if ( + comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + # Non-bulk RS overlap output is in extra output, not usual output + final_out = extra_out + + ctx = ( + x, + kernel, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + ) + + return final_out, ctx + + +def _gemm_bwd_rule( + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, + ctx, + grad, +): + x, kernel, pre_gelu_out, fuse_bias = ctx + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + x_outer_dim, kernel_outer_dim = map( + mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) + ) + + # Recover DGRAD and WGRAD comm+GEMM overlap configs + dgrad_overlap_name = None + dgrad_overlap_config = None + wgrad_overlap_name = None + wgrad_overlap_config = None + if comm_overlap_config is not None: + dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" + dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + + dgrad_pre_rs = None + if dgrad_overlap_config is not None: + if dgrad_overlap_config["method"] == "bulk": + # Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the + # bulk RS overlap without an extra memcpy. + assert ( + wgrad_overlap_config is not None + ), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" + dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False) + + # Copy transposed input into the DGRAD overlap buffer for bulk AG. + copy_into_overlap_buffer(jnp.matrix_transpose(x), dgrad_overlap_name, True) + + # FWD MODE: + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # + # GEMM+RS: ([B], M, K/P) x (K/P, N) --(RS)--> ([B], M/P, N) + + # DGRAD w/o Overlap: + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ---(AR)---> ([B], M, K) + # + # GEMM+AR: ([B], M, N) x (K/P, N)^T ----> ([B], M, K/P) + # + # DGRAD w/ Overlap: + # AG+GEMM w/ DGRAD+RS Overlap: ([B], M, N/P) x (K, N/P)^T ---(RS)---> ([B], M/P, K) + # + # AG+GEMM w/ Bulk AG Overlap: ([B], M, N/P) x (K, N/P)^T -----> ([B], M, K) (deferred RS) + # ([B], M, K/P)^T --(Bulk AG)--> ([B], M, K)^T (needed in WGRAD) + # + # GEMM+RS: ([B], M/P, N) --(AG)--> ([B], M, N) x (K/P, N)^T ----> ([B], M, K/P) + dgrad, dgelu, _, dgrad_extra_out = gemm_impl( + grad, + kernel, + out=dgrad_pre_rs, + gelu_input=pre_gelu_out, + batched_output=(x.ndim > 2), + contracting_dims=(-1, kernel_outer_dim), + fuse_gelu=fuse_gelu, + fuse_bias=False, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=dgrad_overlap_config, + ) + + if ( + dgrad_overlap_config is not None + and dgrad_overlap_config["method"] != "bulk" + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + # Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor + dgrad = dgrad_extra_out + + # WGRAD w/o Overlap: + # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # + # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K, N) + # + # WGRAD w/ Overlap: + # AG+GEMM w/ DGRAD+RS Overlap: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # + # AG+GEMM w/ Bulk Overlaps: ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # ([B], M, K) --(Bulk RS)--> ([B], M/P, K) (finalize DGRAD) + # + # GEMM+RS: ([B], M, K/P)^T x ([B], M, N) --> (K/P, N) (re-use all-gathered GRAD from DGRAD) + wgrad_rhs = dgelu if fuse_gelu else (grad if comm_overlap_config is None else dgrad_extra_out) + wgrad, _, bgrad, wgrad_extra_out = gemm_impl( + x, + wgrad_rhs, + gelu_input=pre_gelu_out, + batched_output=False, + contracting_dims=(x_outer_dim, wgrad_rhs.ndim - 2), + fuse_gelu=False, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=wgrad_overlap_config, + ) + + if ( + wgrad_overlap_config is not None + and wgrad_overlap_config["method"] == "bulk" + and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + # DGRAD was reduce-scattered during WGRAD GEMM, so set DGRAD to WGRAD extra output here + dgrad = wgrad_extra_out + + if not fuse_bias: + bgrad = None + + return dgrad, wgrad, bgrad + + +_gemm.defvjp(_gemm_fwd_rule, _gemm_bwd_rule) + + +def fp8_gemm( + x: ArrayLike, + kernel_t: ArrayLike, + fp8_meta: FP8MetaPackage, + bias: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, + comm_overlap_name: Optional[str] = None, + ag_overlap_skip_copy: bool = False, +) -> ArrayLike: + """ + FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions. + + FP8 GEMM requires the LHS operand to be non-transposed, and the RHS operand to be transposed, + such that the contracting dimensions are always the last dimension for both operands. + + Parameters + ---------- + x : ArrayLike + Non-transposed LHS operand, sized ([B], M, K). + kernel_t : ArrayLike + Transposed RHS operand, sized (N, K). + fp8_meta : transformer_engine.jax.fp8.FP8MetaPackage + FP8MetaPackage object carrying amax, scale and scale_inv information for the GEMM operands. + bias : Optional[ArrayLike], default = `None` + Optional bias term to add onto the (LHS x RHS) result. + out: Optional[ArrayLike], default = `None` + Optional empty buffer for FP8 GEMM output. + out_dtype : jnp.dtype, default = `jnp.bfloat16` + Data type of the FP8 GEMM output. If chosen as an FP8 dtype (i.e. `jnp.float8_e4m3fn` or + `jnp.float8_e5m2`), the `fp8_meta` must also contain amax and scale information for the + GEMM output. This option is overridden by the data type of the `out` buffer, if given. + fuse_gelu : bool, default = `False` + Enable the GELU epilogue for GEMM. This applies GELU after the bias-addition if the bias + term is not `None`. + accumulate : bool, default = `False` + use_split_accumulator : bool, default = `False` + comm_overlap_name : Optional[str], default = `None` + Name of the comm+GEMM overlap layer that this GEMM is associated with. Comm+GEMM overlap + must be initialized with `te.jax.gemm.initialize_comm_gemm_overlaps()` before this + GEMM call, and the configuration dictionary used in the initialization must include + the name passed into this function. + ag_overlap_skip_copy: bool = `False` + All-gather overlap requires the LHS operand to be copied into the communication buffer. + If the communication buffer already has the necessary data, setting this flag will + avoid an unnecessary memcpy operation. + """ + comm_overlap_config = None + if comm_overlap_name is not None: + comm_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(comm_overlap_name, None) + if comm_overlap_config is None: + warnings.warn( + f"Comm+GEMM overlap for {comm_overlap_name} has not been initialized! " + + "Sharded operands will trigger XLA collectives instead." + ) + + elif ( + not ag_overlap_skip_copy + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.AG + ): + copy_into_overlap_buffer(x, comm_overlap_name, True) + + return _fp8_gemm( + x, + kernel_t, + bias, + fp8_meta.amax_list, + fp8_meta.scale_list, + out_dtype, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, + ) + + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10)) +def _fp8_gemm( + x: ArrayLike, + kernel_t: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out: ArrayLike, + out_dtype: jnp.dtype, + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap_config: dict, +) -> ArrayLike: + out, _ = _fp8_gemm_fwd_rule( + x, + kernel_t, + bias, + amax_list, + scale_list, + out_dtype, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, + ) + return out + + +def _fp8_gemm_fwd_rule( + x: ArrayLike, + kernel_t: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap_config: dict, +) -> Tuple[ArrayLike, ...]: + assert ( + kernel_t.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + + fuse_bias = bias is not None + + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list, + *scale_list, + ) + amax_list = maybe_fm32_to_fp32(*amax_list) + scale_list = maybe_fm32_to_fp32(*scale_list) + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype, fwd_dtype] + scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( + amax_list, scale_list, fp8_dtype_list + ) + amax_list = FP8MetaPackage.update_amax_list(amax_list) + + x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1] + x_scale = scale_list[FP8MetaPackage.INPUT_IDX] + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + casted_x, casted_x_t, updated_x_amax = cast_transpose( + x, + x_amax, + x_scale, + x_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_x = x + casted_x_t = jnp.matrix_transpose(x) + updated_x_amax = x_amax + + kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] + kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + if kernel_t.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + casted_kernel_t, casted_kernel, updated_kernel_amax = cast_transpose( + kernel_t, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_kernel = jnp.matrix_transpose(kernel_t) + casted_kernel_t = kernel_t + updated_kernel_amax = kernel_amax + + out_amax = ( + amax_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out_scale = ( + scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + + # Set scale_inv for comm overlap buffer + buffer_scale_inv = None + if comm_overlap_config is not None: + overlap_name = comm_overlap_config["name"] + if comm_overlap_config["method"] != "bulk" and tex.overlap_buffer_is_fp8(overlap_name): + if comm_overlap_config["comm_type"] == tex.CommOverlapType.AG: + buffer_scale_inv = x_scale_inv + + elif comm_overlap_config["comm_type"] == tex.CommOverlapType.RS: + out_dtype = fwd_dtype + out_scale = scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + buffer_scale_inv = jnp.reciprocal(out_scale) + + tex.set_overlap_buffer_scale_inverse( + overlap_name, + jax.dlpack.to_dlpack(buffer_scale_inv), + ) + + out, updated_out_amax, updated_out_scale, pre_gelu_out, extra_out = fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_kernel_t, + kernel_scale_inv, + bias=bias, + out_amax=out_amax, + out_scale=out_scale, + out_dtype=out_dtype, + batched_output=(x.ndim > 2), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + ) + + # Update returned and saved arrays based on comm+GEMM overlap config + final_out = out + if comm_overlap_config is not None: + if comm_overlap_config["comm_type"] == tex.CommOverlapType.RS: + # RS overlap puts the reduce-scattered sharded output into extra_out + final_out = extra_out + + if not jax_dtype_is_fp8(final_out): + updated_out_amax = None + updated_out_scale = None + + ctx = ( + casted_x_t, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + updated_out_amax, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + maybe_fp32_to_fm32, + (x.ndim > 2), + ) + + return (final_out, updated_out_amax, updated_out_scale), ctx + + +def _fp8_gemm_bwd_rule( + out_dtype, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, + ctx, + grad, +): + ( + casted_x_t, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + updated_out_amax, + pre_gelu_out, + fuse_bias, + maybe_fp32_to_fm32, + batched_input, + ) = ctx + del out_dtype + bwd_dtype = FP8Helper.BWD_DTYPE + + # Recover DGRAD and WGRAD comm+GEMM overlap configs + dgrad_overlap_name = None + dgrad_overlap_config = None + wgrad_overlap_name = None + wgrad_overlap_config = None + if comm_overlap_config is not None: + dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" + dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + + # Cast-transpose grad with potential fusions + grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] + grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] + grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] + if fuse_gelu: + if fuse_bias: + # Fuse dbias into this dGELU. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + activation_type=("gelu",), + ) + else: + # No bias to fuse so we just do dGELU. + casted_grad, casted_grad_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) + bgrad = None + else: + if fuse_bias: + # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + # If both bias and GELU is fused into the forward pass, we will fuse dbias later with + # dGELU. No need to do it here. + casted_grad, casted_grad_t, updated_grad_amax = cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + bgrad = None + + # Set scale_inv for comm overlap buffer + dgrad_amax = None + dgrad_scale = None + if dgrad_overlap_config is not None: + if dgrad_overlap_config["method"] == "bulk": + assert ( + wgrad_overlap_config is not None + ), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" + # Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap + dgrad_pre_rs = jax.dlpack.from_dlpack(tex.get_overlap_buffer(wgrad_overlap_name, False)) + # Copy input into overlap buffer for all-gather + copy_into_overlap_buffer(casted_x_t, dgrad_overlap_name, True) + + elif tex.overlap_buffer_is_fp8(dgrad_overlap_name): + # Non-bulk RS DGRAD overlap needs output amax and scale if buffer type is FP8 + dgrad_amax = grad_amax + dgrad_scale = grad_scale + tex.set_overlap_buffer_scale_inverse( + dgrad_overlap_name, + jax.dlpack.to_dlpack(grad_scale_inv), + ) + + # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + dgrad, *_, dgrad_extra_out = fp8_gemm_impl( + casted_grad, + grad_scale_inv, + casted_kernel, + kernel_scale_inv, + out=dgrad_pre_rs, + out_amax=dgrad_amax, + out_scale=dgrad_scale, + batched_output=batched_input, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=dgrad_overlap_config, + ) + + # If dgrad overlapped reduce-scatter, set it to the RS output + if ( + dgrad_overlap_config is not None + and dgrad_overlap_config["method"] != "bulk" + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + dgrad = dgrad_extra_out + + # Prepare comm+GEMM overlap for WGRAD + if wgrad_overlap_config is not None: + if wgrad_overlap_config["method"] == "bulk": + # Get all-gathered input from DGRAD bulk overlap + casted_x_t = jax.dlpack.from_dlpack(tex.get_overlap_buffer(dgrad_overlap_name, False)) + + elif tex.overlap_buffer_is_fp8(wgrad_overlap_name): + # Set FP8 scale inverse for non-bulk AG overlap + tex.set_overlap_buffer_scale_inverse( + wgrad_overlap_name, jax.dlpack.to_dlpack(x_scale_inv) + ) + + # WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K) + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + wgrad, *_, wgrad_extra_out = fp8_gemm_impl( + casted_x_t, + x_scale_inv, + casted_grad_t, + grad_scale_inv, + out_dtype=jnp.bfloat16, + batched_output=False, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=wgrad_overlap_config, + ) + + # If wgrad overlapped reduce-scatter, set it to the RS output + if ( + wgrad_overlap_config is not None + and wgrad_overlap_config["method"] != "bulk" + and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + dgrad = wgrad_extra_out + + amax_list[FP8MetaPackage.INPUT_IDX] = ( + amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) + ) + amax_list[FP8MetaPackage.WEIGHT_IDX] = ( + amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) + ) + amax_list[FP8MetaPackage.GRAD_IDX] = ( + amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) + ) + if updated_out_amax is not None: + amax_list[FP8MetaPackage.OUTPUT_IDX] = ( + amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0]) + ) + + amax_list = maybe_fp32_to_fm32(*amax_list) + scale_list = maybe_fp32_to_fm32(*scale_list) + + return dgrad, wgrad, bgrad, amax_list, scale_list + + +_fp8_gemm.defvjp(_fp8_gemm_fwd_rule, _fp8_gemm_bwd_rule) + + +def type_safe_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, + out_dtype: Optional[jnp.dtype] = None, + fp8_meta: Optional[FP8MetaPackage] = None, + contracting_dims: Tuple[int, int] = (-1, -2), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, + comm_overlap_name: Optional[str] = None, +) -> ArrayLike: + if jax_dtype_is_fp8(x.dtype) or jax_dtype_is_fp8(kernel.dtype): + assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." + + if fp8_meta is not None: + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 1, ( + "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + + "i.e. contracting_dims=(-1, -1)." + ) + return fp8_gemm( + x, + kernel, + fp8_meta, + bias=bias, + out=out, + out_dtype=out_dtype, + fuse_gelu=fuse_gelu, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_name=comm_overlap_name, + ) + else: + return gemm( + x, + kernel, + bias=bias, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_name=comm_overlap_name, + ) + + +def initialize_comm_gemm_overlaps( + buffer_shape: Sequence[int], + mesh: jax.sharding.Mesh, + myrank: int, + numranks: int, + **kwargs: Optional[dict], +) -> None: + """ + Initialize Comm+GEMM overlap communicators and buffers. + + .. warning:: + Communication buffer allocations for this functionality are outside the XLA memory pool + and can cause OOM errors if XLA's memory margin is not reduced. + + Parameters + ---------- + buffer_shape : Sequence[int] + Shape of the communication buffer. This should be sized to match the global shape of the + input/activation tensor. + mesh : jax.sharding.Mesh + JAX Mesh with a `tp_resource` axis. + myrank: int + Global rank of the calling process. + numranks: int + Global number of processes. + tp_resource : Optional[str] = None + Tensor-parallel mesh axis name. If not given, defaults to the TP resource in the global + te.sharding.MeshResource context. + tp_size : Optional[int] = None + Size of the tensor-parallel axis in the mesh. If not given, defaults to the size of the + tensor-parallel axis in `jax.interpreters.pxla.thread_resources`. + use_fp8 : bool = False + Flag for allocating an FP8 communication buffer. This is not supported for reduce-scatter + overlaps with the `pipeline` method. + overlap_configs: Optional[dict] = None, + Dictionary of configs for comm+GEMM overlaps by layer name. + """ + assert tex.ubuf_built_with_mpi(), ( + "Comm+GEMM overlap in TE/JAX requires Transformer Engine to be compiled with " + + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` variables." + ) + if not tex.device_supports_multicast(): + assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( + "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " + + "CUDA Multicast. Launch with UB_SKIPMC=1 to try CUDA IPC instead." + ) + # Extract kwargs + tp_resource = kwargs.get("tp_resource", global_mesh_resource().tp_resource) + tp_size = kwargs.get("tp_size", get_mesh_axis_size(tp_resource, mesh=mesh)) + use_fp8 = kwargs.get("use_fp8", False) + overlap_configs = kwargs.get("overlap_configs", None) + + # Layers that support comm+GEMM overlap + layers_all_gather_overlap = [ + "ag_gemm", + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", + ] + layers_reduce_scatter_overlap = [ + "gemm_rs", + "proj_fprop", + "fc2_fprop", + "qkv_wgrad", + "fc1_wgrad", + ] + dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] + + # Default overlap methods for layers + methods = { + "ring_exchange": [ + "ag_gemm", + "gemm_rs", + "qkv_fprop", + "fc1_fprop", + "proj_dgrad", + "fc2_dgrad", + ], + "pipeline": ["proj_fprop", "fc2_fprop"], + "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + } + + # AG-RS overlap pairs of layers forming a tensor-parallel block + ag_rs_pairs = { + "qkv_fprop": "proj_fprop", + "fc1_fprop": "fc2_fprop", + } + rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} + global layers_atomic_ring_exchange + layers_atomic_ring_exchange = [] + + def get_method(name): + for method, names in methods.items(): + if name in names: + return method + raise KeyError(f"Given layer name {name} does not exist.") + + def get_default_config(name): + method = get_method(name) + default_cfg = { + "mesh": mesh, + "tp_resource": tp_resource, + "tp_size": tp_size, + "name": name, + "method": method, + "comm_type": ( + tex.CommOverlapType.AG + if name in layers_all_gather_overlap + else tex.CommOverlapType.RS + ), + "num_sm": 1 if method == "ring_exchange" else 16, + "num_max_streams": get_num_max_compute_streams(), + "cga_size": 1 if method == "ring_exchange" else 2, + "set_sm_margin": False, + "num_splits": 4 if method == "pipeline" else tp_size, + "aggregate": False, + "atomic_gemm": False, + "pipeline_rs_overlap_first_gemm": False, + "use_ce": True, + "fp8_buf": name in layers_all_gather_overlap, + } + return default_cfg + + def add_new_comm_gemm_overlap( + shape: Sequence[int], + kwargs: dict, + ) -> None: + overlap_name = kwargs["name"] + assert ( + overlap_name not in _ACTIVE_COMM_GEMM_OVERLAPS + ), f"Duplicate initialization for `{overlap_name}` overlap!" + + overlap_method = kwargs["method"] + overlap_atomic_gemm = kwargs["atomic_gemm"] + if overlap_atomic_gemm: + warnings.warn( + "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." + ) + assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." + if overlap_method == "bulk": + warnings.warn( + f"At {overlap_name}, atoimic GEMM not is supported for a bulk overlap." + "Defaulting to `atomic_gemm=False`." + ) + overlap_atomic_gemm = False + kwargs["atomic_gemm"] = overlap_atomic_gemm + if overlap_method == "pipeline" and kwargs["comm_type"] == tex.CommOverlapType.AG: + raise ValueError( + f"At {overlap_name}, `pipeline` overlap method is not supported for AllGather." + ) + # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`. + # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality. + global layers_atomic_ring_exchange + if ( + overlap_atomic_gemm + and overlap_method == "ring_exchange" + and overlap_name in ag_rs_pairs + ): + layers_atomic_ring_exchange += [overlap_name, ag_rs_pairs[overlap_name]] + if overlap_name in rs_ag_pairs: + assert_message = ( + f"At {overlap_name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM " + "chunk outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and " + "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses " + "`atomic gemm` and `ring_exhcnage`, its pair must use the same overlap config " + "for functionality." + ) + if overlap_name in layers_atomic_ring_exchange: + assert overlap_atomic_gemm and overlap_method == "ring_exchange", assert_message + else: + if overlap_atomic_gemm and overlap_method == "ring_exchange": + assert rs_ag_pairs[overlap_name] in layers_atomic_ring_exchange, assert_message + + # Reduce buffer shape to 2D here in case the user initialized with batch dims + buffer_shape = (reduce(operator.mul, shape[:-1], 1), shape[-1]) + tex.bootstrap_comm_gemm_overlap( + buffer_shape, + jax_dtype_to_te_dtype(jnp.uint8 if (use_fp8 and fp8_buf) else jnp.bfloat16), + overlap_name, + overlap_method, + kwargs["comm_type"], + myrank, + numranks, + tp_size, + kwargs["num_splits"], + get_num_max_compute_streams(), + kwargs["cga_size"], + kwargs["num_sm"], + kwargs["set_sm_margin"], + kwargs["use_ce"], + overlap_atomic_gemm, + kwargs["aggregate"], + kwargs["pipeline_rs_overlap_first_gemm"], + ) + + if overlap_configs is not None: + for name in dgrad_reduce_scatter_overlap: + if ( + name in overlap_configs + and "method" in overlap_configs[name] + and overlap_configs[name]["method"] != "bulk" + ): + wgrad_name = name.replace("dgrad", "wgrad") + assert wgrad_name not in overlap_configs + layers_reduce_scatter_overlap.remove(wgrad_name) + layers_all_gather_overlap.remove(name) + layers_reduce_scatter_overlap.append(name) + methods["bulk"].remove(name) + methods["bulk"].remove(wgrad_name) + new_method = overlap_configs[name]["method"] + methods[new_method].append(name) + + global _ACTIVE_COMM_GEMM_OVERLAPS + for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: + if overlap_configs is not None and name in overlap_configs: + fp8_buf = (name in layers_all_gather_overlap) or ( + overlap_configs[name].get("fp8_buf", False) and name not in methods["pipeline"] + ) + final_config = get_default_config(name) + final_config.update(overlap_configs[name]) + final_config["fp8_buf"] = fp8_buf + add_new_comm_gemm_overlap(buffer_shape, final_config) + _ACTIVE_COMM_GEMM_OVERLAPS[name] = final_config + + +def destroy_comm_gemm_overlaps(): + global _ACTIVE_COMM_GEMM_OVERLAPS + for name in _ACTIVE_COMM_GEMM_OVERLAPS: + tex.destroy_comm_gemm_overlap(name) + _ACTIVE_COMM_GEMM_OVERLAPS = dict() + + +def get_comm_overlap_config(name): + global _ACTIVE_COMM_GEMM_OVERLAPS + assert ( + name in _ACTIVE_COMM_GEMM_OVERLAPS + ), f"Comm+GEMM overlap for '{name}' has not been initialized!" + return _ACTIVE_COMM_GEMM_OVERLAPS[name] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3b49ece4a3..d906bba98f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -553,7 +553,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, + bool overlap_first_gemm = false); void set_ubuf_scale_inv(torch::Tensor scale_inv) { assert(scale_inv.numel()); diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index d212d13516..587e3115b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -186,13 +186,13 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, - helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { + bool set_sm_margin, bool atomic_gemm, bool overlap_first_gemm) + : te::CommOverlapBase( + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm, overlap_first_gemm) { // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to // for PyTorch to factor externally allocated memory into its memory pool and garbage collection // threshold calculation. diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8856553c54..9841b5d640 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -263,12 +263,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_(m, "CommOverlap") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, bool, bool>(), + int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, - py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) + py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, + py::arg("overlap_first_gemm") = false) .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) .def("split_overlap_rs", &CommOverlap::split_overlap_rs, py::call_guard()) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d115efedaa..164d371985 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -306,6 +306,7 @@ def get_default_config(name): "num_splits": 4 if method == "pipeline" else tp_size, "aggregate": False, "atomic_gemm": False, + "pipeline_rs_overlap_first_gemm": False, "use_ce": True, "fp8_buf": name in layers_all_gather_overlap, } @@ -314,13 +315,14 @@ def get_default_config(name): def add_ub( name: str, method: str, - is_reduce_scatter: int, + is_reduce_scatter: bool, num_sm: int = 16, cga_size: int = 2, - set_sm_margin: int = 0, - num_splits: int = 0, - aggregate: int = 0, - atomic_gemm: int = 0, + set_sm_margin: bool = False, + num_splits: int = 4, + aggregate: bool = False, + atomic_gemm: bool = False, + pipeline_rs_overlap_first_gemm: bool = False, use_ce: bool = True, fp8_buf: bool = False, ) -> None: @@ -386,6 +388,7 @@ def add_ub( num_comm_sm=num_sm, set_sm_margin=set_sm_margin, atomic_gemm=atomic_gemm, + overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) _ub_communicators[name] = ub_obj