diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 4f170e3f84..9d4d74b61a 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -19,8 +19,8 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.common.recipe import Format -from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) @@ -288,33 +288,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else tex.CommOverlapHelper(bootstrap_pg) ) - if opts.comm_type == tex.CommOverlapType.RS: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS - elif opts.p2p: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - ) - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ) - elif opts.comm_type == tex.CommOverlapType.AG: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - ) - else: - raise TypeError("Invalid comm+GEMM overlap type!") - # Initialize userbuffers with (M, N) buffer # M = sequence * batch # N = hidden size @@ -325,7 +298,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None if ( opts.fp8 and not opts.bulk_overlap - and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output) + and opts.comm_type == tex.CommOverlapType.AG ): buffer_dtype = torch.uint8 ub_obj = ( @@ -421,6 +394,10 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None std=opts.std, ) + # Allocate cuBLAS workspace + workspace_size = 3 * get_cublas_workspace_size_bytes() + workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda") + # Gather global tensors and calculate reference result (need these first for Fp8 scales) if opts.bulk_overlap: ker_g = torch.transpose(kernel_t, 0, 1) @@ -467,120 +444,126 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) + inp_quantizer = None + ker_quantizer = None + out_quantizer = None + bulk_inp_quantizer = None + inp2_quantizer = None + ker2_quantizer = None + out2_quantizer = None if opts.fp8: - fp8_formats = { - tex.DType.kFloat8E4M3: Format.E4M3, - tex.DType.kFloat8E5M2: Format.E5M2, - } - # Structure to maintain amax and scale/scale_inv information for the kernel and input - fp8_dtype = tex.DType.kFloat8E4M3 - fp8_meta = tex.FP8TensorMeta() num_gemms = 6 if ub_obj2 is not None else 3 - fp8_meta.amax_history = torch.zeros((2, num_gemms), dtype=torch.float, device="cuda") - fp8_meta.scale = torch.ones(num_gemms, dtype=torch.float, device="cuda") - fp8_meta.scale_inv = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_dtype = tex.DType.kFloat8E4M3 + fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, 
device="cuda") # Compute initial amaxes and scales inp_amax = torch.max(torch.abs(inp_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_INPUT].copy_(inp_amax) + fp8_amaxes[0].copy_(inp_amax) ker_amax = torch.max(torch.abs(ker_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) + fp8_amaxes[1].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) + fp8_amaxes[2].copy_(ref_amax) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_amax = torch.max(torch.abs(bulk_inp)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) + fp8_amaxes[5].copy_(bulk_amax) elif ub_obj2 is not None: inp2_amax = torch.max(torch.abs(inp2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) + fp8_amaxes[3].copy_(inp2_amax) ker2_amax = torch.max(torch.abs(ker2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_WEIGHT].copy_(ker2_amax) + fp8_amaxes[4].copy_(ker2_amax) ref2_amax = torch.max(torch.abs(ref2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(ref2_amax) - fp8_meta.scale = _default_sf_compute( - fp8_meta.amax_history[1], fp8_meta.scale, fp8_formats[fp8_dtype].value.max_fwd, 1 - ) - fp8_meta.scale_inv = torch.reciprocal(fp8_meta.scale) + fp8_amaxes[5].copy_(ref2_amax) + + inp_quantizer = Float8Quantizer(fp8_scales[0].clone(), fp8_amaxes[0].clone(), fp8_dtype) + ker_quantizer = Float8Quantizer(fp8_scales[1].clone(), fp8_amaxes[1].clone(), fp8_dtype) + if opts.fp8_output: + out_quantizer = Float8Quantizer(fp8_scales[2].clone(), fp8_amaxes[2].clone(), + fp8_dtype) + + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: + bulk_inp_quantizer = Float8Quantizer(fp8_scales[5].clone(), fp8_amaxes[5].clone(), + fp8_dtype) + elif ub_obj2 is not None: + inp2_quantizer = Float8Quantizer(fp8_scales[3].clone(), fp8_amaxes[3].clone(), + fp8_dtype) + ker2_quantizer = Float8Quantizer(fp8_scales[4].clone(), fp8_amaxes[4].clone(), + fp8_dtype) + if opts.fp8_output: + out2_quantizer = Float8Quantizer(fp8_scales[5].clone(), fp8_amaxes[5].clone(), + fp8_dtype) # Cast input to Float8Tensor - inp_fp8 = tex.cast_to_fp8(inp, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype) + inp_fp8 = inp_quantizer(inp) # Cast kernel to Float8Tensor - kernel_t_fp8 = tex.cast_to_fp8( - kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype - ) + kernel_t_fp8 = ker_quantizer(kernel_t) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: - bulk_inp_fp8 = tex.cast_to_fp8( - bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype - ) + bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp) elif ub_obj2 is not None: - kernel2_t_fp8 = tex.cast_to_fp8( - kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype - ) + kernel2_t_fp8 = ker2_quantizer(kernel2_t) # Make sure the inputs are cast correctly if opts.check_numerics: torch.allclose( inp.to(dtype=torch.float32), - inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) torch.allclose( kernel_t.to(dtype=torch.float32), - kernel_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + kernel_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: torch.allclose( bulk_inp.to(dtype=torch.float32), - bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], + bulk_inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, 
atol=0.0675, ) elif ub_obj2 is not None: torch.allclose( kernel2_t.to(dtype=torch.float32), - kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], + kernel2_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) - # Set Fp8 scales for userbuffers - if opts.comm_type == tex.CommOverlapType.AG: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) - if ub_obj2 is not None: - ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) - elif opts.bulk_overlap: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) - else: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) - # Set up comm/compute buffers - ubuf_out2 = None + rs_out = None rs_out2 = None if opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp, 1) + ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True) gemm_inp = inp else: - ub_obj.copy_input_to_ubuf(inp_fp8 if opts.fp8 else inp, 1) - gemm_inp = ub_obj.get_ubuf_output(1) - ubuf_out = None - rs_out = None + ub_obj.copy_into_buffer( + inp_fp8 if opts.fp8 else inp, + inp_quantizer, + True + ) + gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size()) if ub_obj2 is not None: - ubuf_out2 = ub_obj2.get_ubuf_output(1) + if opts.fp8 and opts.fp8_output: + ub_obj2.set_buffer_params(out_quantizer) rs_out2 = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" ) else: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp_fp8 if opts.fp8 else bulk_inp, 0) - ubuf_out = None - else: - ubuf_out = ub_obj.get_ubuf_output(1) + ub_obj.copy_into_buffer( + bulk_inp_fp8 if opts.fp8 else bulk_inp, + bulk_inp_quantizer, + False + ) + if opts.fp8: + ub_obj.set_buffer_params(bulk_inp_quantizer) + elif opts.fp8 and opts.fp8_output: + ub_obj.set_buffer_params(out_quantizer) gemm_inp = inp_fp8 if opts.fp8 else inp rs_out = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" @@ -588,88 +571,51 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use def _fp8_gemm(): - return tex.fp8_gemm( + return tex.general_gemm( kernel_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, gemm_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=ub_algo, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - D_dtype=fp8_dtype if opts.fp8_output else None, - fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM1_OUTPUT if opts.fp8_output else None, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap ) def _fp8_gemm2(gemm1_out): gemm2_inp = tex.gelu( ( - tex.cast_from_fp8( - gemm1_out, - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) + gemm1_out.dequantize() if opts.fp8_output else gemm1_out ), - fp8_meta, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, + inp2_quantizer, ) - return tex.fp8_gemm( + return tex.general_gemm( kernel2_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_WEIGHT, - fp8_dtype, gemm2_inp, - 
fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out2_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic_rs_p2p - else tex.CommOverlapAlgo.ATOMIC_GEMM_RS - ), ub=ub_obj2, - extra_output_tensor=rs_out2, - out=ubuf_out2, - D_dtype=fp8_dtype if opts.fp8_output else None, - fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM2_OUTPUT if opts.fp8_output else None, + ub_type=tex.CommOverlapType.AG, + extra_output=rs_out2, ) def _gemm(): - return tex.gemm( + return tex.general_gemm( kernel_t, gemm_inp, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - ub_algo=ub_algo, + workspace, + out_dtype=torch.bfloat16, + use_split_accumulator=te.module.base._2X_ACC_FPROP, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap ) # Trigger GEMM @@ -746,10 +692,10 @@ def _gemm(): output_info = "" if opts.comm_type == tex.CommOverlapType.AG: # Bulk overlap AG output is already gathered - test_out = ub_obj.get_ubuf_output(1) + test_out = ub_obj.get_buffer(bulk_inp_quantizer, False) else: # Bulk overlap RS output needs to be gathered - out_local = ub_obj.get_ubuf_output(0) + out_local = ub_obj.get_buffer(bulk_inp_quantizer, True) output_info += f"rs_output: {list(out_local.shape)} | " test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0] @@ -776,13 +722,7 @@ def _gemm(): else: # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) output = ( - tex.cast_from_fp8( - all_outputs[0], - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) + all_outputs[0].dequantize() if opts.fp8_output else all_outputs[0] ) @@ -798,24 +738,24 @@ def _gemm(): output = rs_out.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] - if opts.fp8: - dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n" - + f"scale = {fp8_meta.scale[:3].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}" - ) - dist_print(fp8_meta_info, src=0, group=tp_group) - if ub_obj2 is not None: - dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n" - + f"scale = {fp8_meta.scale[3:].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}" - ) - dist_print(fp8_meta_info, src=0, group=tp_group) + # if opts.fp8: + # dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) + # fp8_meta_info = ( + # f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n" + # + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n" + # + f"scale = {fp8_meta.scale[:3].tolist()}\n" + # + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}" + # ) + # dist_print(fp8_meta_info, src=0, group=tp_group) + # if ub_obj2 is not None: + # dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, 
section=True) + # fp8_meta_info = ( + # f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n" + # + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n" + # + f"scale = {fp8_meta.scale[3:].tolist()}\n" + # + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}" + # ) + # dist_print(fp8_meta_info, src=0, group=tp_group) ref_out = ref2_g if ub_obj2 is not None else ref_g test_nonzeros = torch.count_nonzero(test_out) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e49174c24f..418ce1a71b 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -9,6 +9,7 @@ import socket import argparse import warnings +import pprint import torch import torch.distributed as dist @@ -39,6 +40,8 @@ def _te_layer_argtype(name): def _get_layer_args(config, tp_group, tp_size, reference=False): hidden_size = config.num_heads * config.head_dim + ffn_hidden_size = 4 * hidden_size + qkv_size = 3 * hidden_size input_shape = [config.seq_length, config.batch_size, hidden_size] args = [hidden_size] kwargs = { @@ -47,38 +50,41 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): "tp_group": tp_group, "tp_size": tp_size, "sequence_parallel": True, + "ub_overlap_ag": not reference, + "ub_overlap_rs": not reference, } - kwargs["ub_overlap_ag"] = not reference - - if config.layer_type is te.Linear: - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["parallel_mode"] = "row" - kwargs["ub_overlap_rs"] = not reference - kwargs["ub_name"] = "proj" + + if config.layer_type in [te.Linear, te.LayerNormLinear]: + if config.linear_parallel_mode == "row": + input_shape[-1] = ffn_hidden_size // tp_size + args = [ffn_hidden_size, hidden_size] + kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2" + elif config.linear_parallel_mode == "column": + input_shape[0] = config.seq_length // tp_size + args.append(qkv_size) + kwargs["ub_name"] = "qkv" + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["parallel_mode"] = config.linear_parallel_mode else: input_shape[0] = config.seq_length // tp_size - kwargs["ub_bulk_wgrad"] = not reference - kwargs["ub_bulk_dgrad"] = not reference - if config.layer_type is te.LayerNormLinear: - args.append(3 * hidden_size) - kwargs["parallel_mode"] = "column" - kwargs["ub_name"] = "qkv" - else: - kwargs["set_parallel_mode"] = True - kwargs["ub_overlap_rs"] = not reference - if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: - args.append(4 * hidden_size) - kwargs["seq_length"] = config.seq_length - if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: - args.append(config.num_heads) - kwargs["attention_dropout"] = 0.0 - kwargs["fuse_qkv_params"] = True - if config.layer_type is te.MultiheadAttention: - kwargs["input_layernorm"] = True - else: - kwargs["ub_tp_comm_overlap"] = not reference - kwargs["hidden_dropout"] = 0.0 + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(ffn_hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + 
kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not reference + kwargs["hidden_dropout"] = 0.0 + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference return args, kwargs, input_shape @@ -125,6 +131,19 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." ) + parser.add_argument( + "--linear-parallel-mode", + type=str.lower, + default="row", + choices=["row", "column"], + help="Parallel mode for te.Linear.", + ) + parser.add_argument( + "--overlap-rs-dgrad", + action="store_true", + default=False, + help="Replace bulk DGRAD/WGRAD overlaps with DGRAD+RS in the backward pass for AG+GEMM." + ) parser.add_argument( "--debug", action="store_true", @@ -154,7 +173,7 @@ def _compare_tensors(name, test, ref, rtol, atol): ) return 1, numerics_info - diff = torch.abs(test - ref).flatten() + diff = torch.abs(test.flatten() - ref.flatten()) m = torch.argmax(diff) abs_err = diff[m].item() rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5) @@ -230,12 +249,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Intialize userbuffers + ub_cfgs = None + if opts.overlap_rs_dgrad: + ub_cfgs = { + "qkv_dgrad" : { "method" : "ring_exchange" }, + "fc1_dgrad" : { "method" : "ring_exchange" }, + } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], WORLD_SIZE, use_fp8=opts.fp8, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=ub_cfgs, ) # Initialize the Transformer Engine layer with overlap @@ -243,6 +269,10 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): with te.fp8_model_init(enabled=opts.fp8_init): test_model = opts.layer_type(*args, **kwargs) dist_print("Initialized test model...", debug=True) + if WORLD_RANK == 0: + pprint.pprint(kwargs) + sys.stdout.write("\n") + dist.barrier() # Initialize the reference model and copy all parameters ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True) @@ -277,8 +307,8 @@ def run_fwd_bwd(model, x): out, *_ = y else: out = y - loss = out.sum() - loss.backward() + loss = out.sum() + loss.backward() return out torch_rng_state = torch.get_rng_state() @@ -333,7 +363,7 @@ def run_fwd_bwd(model, x): dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - if bool(numerics_failed.item()): + if bool(numerics_failed.item()) and not opts.debug: break te.module.base.destroy_ub() diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index c872aa0bd0..f310b7544a 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -28,6 +28,7 @@ te.MultiheadAttention, te.TransformerLayer, ] +MAX_LAYER_NAME_LENGTH = max([ len(layer.__name__) for layer in TE_LAYERS ]) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = torch.cuda.device_count() @@ -46,7 +47,7 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): +def _run_gemm_with_overlap(comm_type, 
bulk, p2p, atomic, fp8): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -62,21 +63,15 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg if bulk: test_cmd.append("--bulk-overlap") else: - if fp8_in: + if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_out: - if torch.cuda.get_device_properties().major == 10: - pytest.skip("WIP: TE GEMM on Blackwell does not support FP8 output.") - test_cmd.append("--fp8-output") if p2p: test_cmd.append("--p2p") - if aggregate: - test_cmd.append("--aggregate") if atomic: if torch.cuda.get_device_properties(0).major != 9: - pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).") + pytest.skip("Atomic GEMM requires device compute capability 9.x (Hopper).") test_cmd.append("--atomic") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) @@ -88,7 +83,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, fp8, fp8_init): +def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -99,13 +94,16 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", ] + if layer_type in [te.Linear.__name__, te.LayerNormLinear.__name__]: + test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") + + if overlap_rs_dgrad: + test_cmd.append("--overlap-rs-dgrad") if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_init: - test_cmd.append("--fp8-init") os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" @@ -126,88 +124,64 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): @pytest.mark.parametrize( - "fp8,aggregate", - [ - (False, False), - (False, True), - (True, False), - (True, True), - ], - ids=[ - " BF16 IN - RING-EXCHANGE ", - " BF16 IN - RING-EXCHANGE - 2x AGGREGATED ", - " FP8 IN - RING-EXCHANGE ", - " FP8 IN - RING-EXCHANGE - 2x AGGREGATED ", - ], + "fp8", (False, True), ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "], ) -def test_split_all_gather_overlaps(fp8, aggregate): +def test_split_all_gather_overlaps(fp8): """ Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate) + _run_gemm_with_overlap("AG", False, True, False, fp8) @pytest.mark.parametrize( - "fp8_in,fp8_out,p2p", + "fp8,p2p", [ - (False, False, False), - (False, False, True), - (True, False, False), - (True, False, True), - (True, True, False), - (True, True, True), + (False, False), + (False, True), + (True, False), + (True, True), ], ids=[ - " BF16 IN - BF16 OUT - PIPELINE ", - " BF16 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - BF16 OUT - PIPELINE ", - " FP8 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - FP8 OUT - PIPELINE ", - " FP8 IN - FP8 OUT - RING-EXCHANGE ", + " BF16 - PIPELINE ", + " BF16 - RING-EXCHANGE ", + " FP8 - PIPELINE ", + " FP8 - RING-EXCHANGE ", ], ) -def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p): +def test_split_reduce_scatter_overlaps(fp8, p2p): """ Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
""" - _run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False) + _run_gemm_with_overlap("RS", False, p2p, False, fp8) @pytest.mark.parametrize( - "ag_type,rs_type,p2p,fp8_out", + "ag_type,rs_type,p2p", [ - (0, 0, False, False), - (0, 1, False, False), - (0, 1, False, True), - (0, 2, False, False), - (0, 2, False, True), - (0, 0, True, False), - (0, 0, True, True), - (1, 0, True, False), - (1, 0, True, True), + (0, 0, False), + (0, 1, False), + (0, 2, False), + (0, 0, True), + (1, 0, True), ], ids=[ - " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", + " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE", + " NON-ATOMIC AG - ATOMIC RS - PIPELINE", + " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE", + " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE", + " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE", ], ) -def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): +def test_atomic_gemm_overlaps(ag_type, rs_type, p2p): """ Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type) os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type) - _run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False) + _run_gemm_with_overlap("AG", False, p2p, True, True) @pytest.mark.parametrize( @@ -221,12 +195,12 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): ("RS", True, 8), ], ids=[ - "ALL-GATHER - BF16 - 1 connections", + "ALL-GATHER - BF16 - 1 connections", "REDUCE-SCATTER - BF16 - 1 connections", - "REDUCE-SCATTER - FP8 - 1 connections", - "ALL-GATHER - BF16 - 8 connections", + "REDUCE-SCATTER - FP8 - 1 connections", + "ALL-GATHER - BF16 - 8 connections", "REDUCE-SCATTER - BF16 - 8 connections", - "REDUCE-SCATTER - FP8 - 8 connections", + "REDUCE-SCATTER - FP8 - 8 connections", ], ) def test_bulk_overlaps(comm_type, fp8, connections): @@ -240,32 +214,39 @@ def test_bulk_overlaps(comm_type, fp8, connections): " 9.0 (HOPPER ARCH)." 
) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + _run_gemm_with_overlap(comm_type, True, False, False, fp8) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" else: - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + _run_gemm_with_overlap(comm_type, True, False, False, fp8) - -@pytest.mark.parametrize( - "layer_type", - [layer.__name__ for layer in TE_LAYERS], - ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], -) @pytest.mark.parametrize( - "fp8,fp8_init", + "layer_type,linear_parallel_mode,overlap_rs_dgrad", [ - (False, False), - (True, False), - (True, True), - ], + (te.Linear.__name__, "row", False), + (te.Linear.__name__, "column", False), + (te.Linear.__name__, "column", True), + (te.LayerNormLinear.__name__, "row", False), + (te.LayerNormLinear.__name__, "column", False), + (te.LayerNormLinear.__name__, "column", True), + ] + list(zip([layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + [None] * len(TE_LAYERS[2:]) * 2, + [False, True] * len(TE_LAYERS[2:]))), ids=[ - " BF16 GEMM - BF16 PARAMS ", - " FP8 GEMM - BF16 PARAMS ", - " FP8 GEMM - FP8 PARAMS ", + f" {te.Linear.__name__} - ROW-PARALLEL ", + f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.Linear.__name__} - COL-PARALLEL - DGRAD+RS ", + f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", + ] + [ + " " + " - ".join(test_name_parts) + " " for test_name_parts in + zip([layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:])) ], ) -def test_layers_with_overlap(layer_type, fp8, fp8_init): +@pytest.mark.parametrize("fp8", (False, True), ids=[" BF16 ", " FP8 "]) +def test_layers_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): """ Test Transformer Engine layers with comm+GEMM overlap.
""" - _run_layer_with_overlap(layer_type, fp8, fp8_init) + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index d03eff1c75..1344d5811e 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -21,6 +21,8 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim) + using namespace std::placeholders; namespace transformer_engine { @@ -137,6 +139,70 @@ CommOverlapCore::~CommOverlapCore() { } } +TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, + const std::vector<size_t> &chunk_shape) { + TensorWrapper chunk; + for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { + auto param_type = static_cast<NVTETensorParam>(param_id); + auto param = source.get_parameter(param_type); + auto param_dptr = reinterpret_cast<char *>(param.data_ptr); + auto param_dtype = static_cast<DType>(param.dtype); + auto param_shape = AS_VECTOR(param.shape); + if (param_dptr != nullptr) { + if (param_type == NVTETensorParam::kNVTERowwiseData + || param_type == NVTETensorParam::kNVTEColumnwiseData) { + // Offset data pointer + param_dptr += chunk_offset * typeToSize(param_dtype); + param_shape = chunk_shape; + if (param_type == NVTETensorParam::kNVTEColumnwiseData + && source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { + // Columnwise shape for non-block scaled tensors shifts the last dimension to the front + auto last_dim = param_shape.back(); + param_shape.pop_back(); + param_shape.insert(param_shape.begin(), last_dim); + } + } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING + && (param_type == NVTETensorParam::kNVTERowwiseScaleInv + || param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) { + // Calculate block scaling offset and size + auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) ? source.shape().data[0] : source.columnwise_shape().data[0]; + auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) ?
+ chunk_shape.front() : chunk_shape.back(); + auto chunk_scale_start = chunk_offset / 32; + auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32; + auto chunk_scale_size = chunk_scale_end - chunk_scale_start; + param_dptr += chunk_scale_start * typeToSize(param_dtype); + param_shape = std::vector<size_t>{chunk_scale_size}; + } + + // Set chunked source parameters into the chunked tensor output + chunk.set_parameter(param_type, reinterpret_cast<void *>(param_dptr), param_dtype, + param_shape); + } + } + return chunk; +} + +TensorWrapper CommOverlapCore::get_buffer_chunk_like( + const TensorWrapper &source, size_t chunk_offset, const std::vector<size_t> &chunk_shape) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); + + // Update chunk with offset data pointers from the communication buffer + auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), + chunk.columnwise_shape()); + } + return chunk; +} + /*************************************************************************************************** * Comm+GEMM Overlap Base (Pipelined / Collective) **************************************************************************************************/ @@ -146,11 +212,13 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool atomic_gemm) + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", @@ -177,8 +245,8 @@ CommOverlapBase::~CommOverlapBase() { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ -void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, @@ -205,7 +273,7 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); } else { @@ -230,20 +298,20 @@ void
CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, +void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; // Get GEMM dimensions - size_t m = A.size(0); - size_t k = A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); @@ -264,9 +332,8 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens assert(pre_gelu_out.numel() == 0); - auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, {n, m}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), @@ -278,11 +345,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, &counter_ptr[i], _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, @@ -291,11 +357,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens } } else if (_rs_kernel_type == 2) { if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, counter_ptr, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, @@ -308,7 +373,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens if (_ubuf.element_size() == 1) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, _ubuf_scale_inv, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, 
_stream_comm);); } else { @@ -330,32 +395,24 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t m = A.size(0); - size_t k = A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t input_a_chunk_size = m_chunk * k; size_t output_chunk_size = n * m_chunk; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); - - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); for (size_t i = 0; i < _stream_compute.size(); i++) { @@ -365,30 +422,21 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap assert(pre_gelu_out.numel() == 0); - if (gemm_overlap) { - auto input_a_chunk = - TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = - TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto workspace_chunk = TensorWrapper( - workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_rs_overlap_first_gemm) { + auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[0]); for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * D.element_size(); - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, - D.dtype(), D.amax(), D.scale(), nullptr); - workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + workspace_chunk = get_tensor_chunk( + 
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), @@ -401,11 +449,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Communication chunk if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, @@ -422,11 +469,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Last communication chunk with max SM _ub_comm->sms = UB_MAX_SM; if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, + rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, @@ -435,16 +481,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } } else { for (int i = 0; i < _num_splits; i++) { - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - auto input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), - {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), @@ -459,11 +499,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, + rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, @@ -471,8 +510,6 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); } } @@ -567,17 +604,30 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { for (size_t i = 0; i < 
_stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); } +TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, + size_t chunk_id) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape())); + + // Update chunk with offset data pointers from the communication buffer + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape()); + } + return chunk; +} + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ -void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -585,8 +635,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T // Get GEMM dimensions between TN and NN input layouts const size_t m = (transa) ? A.size(0) : A.size(1); - const size_t n = _ubuf.size(0); - const size_t n_chunk = n / _tp_size; + const size_t n_chunk = _ubufs[0].size(0); assert(pre_gelu_out.numel() == 0); // Get communication and GEMM output chunk sizes @@ -596,7 +645,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T void *D_buffer_ptr; int D_chunk_bytes = n_chunk * m * D.element_size(); NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); - auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), + D.scale_inv(), D.scale_inv_shape(), D.scaling_mode()); // Reset atomic counters int *counter_ptr = reinterpret_cast(_counter.dptr()); @@ -607,10 +657,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); + auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape())); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); for (int i = 0; i < _tp_size - 1; i++) { // Set the userbuffer id. 
Buffer under send is the input for the current @@ -676,11 +725,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ -void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -693,13 +743,8 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW // Get communication and GEMM output chunk sizes const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (n_chunk * m) * D.element_size(); - const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; - - // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.dptr()); - char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); @@ -710,7 +755,8 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW } if (_aggregate) { const int num_steps = _tp_size / 2; - char *input_b_ptr = reinterpret_cast(_ubuf.dptr()); + input_chunk_size *= 2; + output_chunk_size *= 2; // Initial 1X input chunk exchange between neighboring peers int send_chunk_id = _tp_id; @@ -738,27 +784,14 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW recv_offset = comm_bytes * recv_chunk_id; // GEMM - char *input_b_chunk_ptr = input_b_ptr + send_offset; - auto input_b_chunk = - TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), - {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = - (do_gelu) ? 
std::vector{n_chunk * 2, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, + {n_chunk * 2, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m}); + auto aux_chunk = (do_gelu) ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, + {n_chunk * 2, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -795,24 +828,13 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW int recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), {n_chunk, m}, - D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = (do_gelu) ? std::vector{n_chunk, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m}); + auto aux_chunk = (do_gelu) ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, + {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -853,13 +875,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, - cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -878,12 +898,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T // Atomic GEMM // Process GEMM chunks in the order that AG+GEMM places the output chunks. - auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape())); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, + transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), stream_main); @@ -909,10 +926,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); @@ -923,26 +939,26 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper 
&pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t k = A.size(1); - size_t n = B.size(0); // Get communication and GEMM input chunk sizes - size_t n_chunk = n / _tp_size; + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n_chunk = _ubufs[0].size(0); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); // Catch up the main stream @@ -960,18 +976,11 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW // GEMM chunk int stream_id = i % _stream_compute.size(); int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - - auto input_b_chunk = TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, - B.dtype(), nullptr, nullptr, B.scale_inv()); - - auto output_chunk = - TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); - char *workspace_chunk_ptr = workspace_ptr + stream_id * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); + auto output_chunk = get_buffer_chunk_by_id(D, i); + auto workspace_chunk = get_tensor_chunk(workspace, stream_id * workspace_size_chunk, + {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -1009,11 +1018,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 6c4fc23f86..f8e3f5496d 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -17,6 +17,10 @@ #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +#define NOT_IMPLEMENTED_ERROR() NVTE_ERROR("Operation is not implemented.") + +#define NOT_SUPPORTED_ERROR() NVTE_ERROR("Operation not supported.") + namespace transformer_engine { /* \brief 
Check if Userbufers bootstraps with direct calls to MPI collectives. @@ -67,6 +71,8 @@ class CommOverlapCore { cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; public: + CommOverlapCore() {} // dummy constructor for exposing type to Python + CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, @@ -80,26 +86,76 @@ class CommOverlapCore { _ubuf_scale_inv_initialized = true; } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, + const std::vector &shape); + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { protected: int _rs_kernel_type; + bool _rs_overlap_first_gemm; cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; public: + CommOverlapBase() {} // dummy constructor for exposing type to Python + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, - bool set_sm_margin = true, bool atomic_gemm = false); + bool set_sm_margin = true, bool atomic_gemm = false, + bool 
rs_overlap_first_gemm = false); virtual ~CommOverlapBase(); @@ -107,49 +163,66 @@ class CommOverlapBase : public CommOverlapCore { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NOT_SUPPORTED_ERROR(); + } + + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NOT_SUPPORTED_ERROR(); + } /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split FPROP GEMM + ReduceScatter */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main); + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { protected: bool _is_reduce_scatter{false}; bool _use_multiatomic_ag{false}; - + bool _aggregate; int _next_rank; int _prev_rank; int _rank_round_tp; - int _aggregate; int _num_ubuf_chunks; int _self_chunk_id; - std::vector _ubufs; - std::vector _stream_send; cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; public: + + CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, @@ -160,45 +233,55 @@ 
class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); + + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override { + NOT_SUPPORTED_ERROR(); + } + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ - void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; /* ** Split AllGather + GEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ - void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + cudaStream_t stream_main) override; /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + cudaStream_t stream_main) override; }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index de44d50757..f8862d6918 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ 
b/transformer_engine/common/util/pybind_helper.h @@ -15,7 +15,7 @@ #include "cuda_runtime.h" #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ + pybind11::enum_(m, "DType", pybind11::module_local()) \ .value("kByte", transformer_engine::DType::kByte) \ .value("kInt32", transformer_engine::DType::kInt32) \ .value("kFloat32", transformer_engine::DType::kFloat32) \ @@ -23,12 +23,12 @@ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ + pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ + pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) \ .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ @@ -36,7 +36,7 @@ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ + pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ @@ -52,15 +52,18 @@ .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", \ + pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ + pybind11::enum_(m, "CommOverlapType", \ + pybind11::module_local()) \ .value("RS", transformer_engine::CommOverlapType::RS) \ .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ + pybind11::enum_(m, "CommOverlapAlgo", \ + pybind11::module_local()) \ .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ .value("SPLIT_PIPELINED_AG_P2P", \ @@ -71,6 +74,29 @@ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + py::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(py::init([](){return new transformer_engine::CommOverlapCore();}), \ + py::call_guard()) \ + .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ + py::call_guard()) \ + .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ + 
py::call_guard<py::gil_scoped_release>())                                       \
+      .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf,                  \
+           py::call_guard<py::gil_scoped_release>());                                         \
+  py::class_<transformer_engine::CommOverlapBase, std::shared_ptr<transformer_engine::CommOverlapBase>, \
+             transformer_engine::CommOverlapCore>(m, "CommOverlapBase",                       \
+                                                  pybind11::module_local())                   \
+      .def(py::init([](){return new transformer_engine::CommOverlapBase();}),                 \
+           py::call_guard<py::gil_scoped_release>());                                         \
+  py::class_<transformer_engine::CommOverlapP2PBase, std::shared_ptr<transformer_engine::CommOverlapP2PBase>, \
+             transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase",                    \
+                                                  pybind11::module_local())                   \
+      .def(py::init([](){return new transformer_engine::CommOverlapP2PBase();}),              \
+           py::call_guard<py::gil_scoped_release>());                                         \
   m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast,           \
         py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1);                 \
   m.def(                                                                                      \
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3af1b99fb1..0f03331089 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -7815,11 +7815,11 @@ def __init__(
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
-        ub_bulk_wgrad: bool = False,
-        ub_bulk_dgrad: bool = False,
-        ub_overlap_rs_dgrad: bool = False,
-        ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
+        ub_bulk_dgrad: bool = False,
+        ub_bulk_wgrad: bool = False,
         bias: bool = True,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py
index 0685ca50be..ff475caf21 100644
--- a/transformer_engine/pytorch/constants.py
+++ b/transformer_engine/pytorch/constants.py
@@ -16,6 +16,8 @@
 """
 TE_DType = {
     torch.uint8: tex.DType.kByte,
+    torch.float8_e4m3fn: tex.DType.kFloat8E4M3,
+    torch.float8_e5m2: tex.DType.kFloat8E5M2,
     torch.int32: tex.DType.kInt32,
     torch.float32: tex.DType.kFloat32,
     torch.half: tex.DType.kFloat16,
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index 44914a620e..7f3a5dbdc8 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -81,9 +81,10 @@ def general_gemm(
     bias: Optional[torch.Tensor] = None,
     use_split_accumulator: bool = False,
     grad: bool = False,
-    ub_algo: tex.CommOverlapAlgo = None,
     ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None,
-    ub_buffer: Optional[torch.Tensor] = None,
+    ub_type: tex.CommOverlapType = None,
+    extra_output: Optional[torch.Tensor] = None,
+    bulk_overlap: bool = False,
 ) -> Iterable[Optional[torch.Tensor]]:
     """GEMM supporting fp8 inputs."""
@@ -91,15 +92,25 @@ def general_gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"
     # assert quantization_params is None, "FP8 output not supported yet"
+
+    if ub_type is not None:
+        assert ub is not None, (
+            f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires "
+            + "a valid `ub` communicator object."
+        )
+
+    if ub is not None:
+        assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
+        if ub_type == tex.CommOverlapType.RS:
+            if not (bulk_overlap and not ub.is_fp8_ubuf()):
+                assert extra_output is not None, "GEMM+RS overlap requires extra output tensor."
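+    # NOTE: When a userbuffers communicator is passed in, the comm+GEMM overlap is dispatched
+    # inside tex.generic_gemm through the `comm_overlap`, `comm_type`, `extra_output` and
+    # `bulk_overlap` kwargs below; the asserts above only validate that the caller supplied a
+    # consistent combination of arguments.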
+ if out is not None: if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 if bias is None else bias.dtype - bias_dtype = TE_DType[bias_dtype] - if bias is None and not grad: - bias = _empty_tensor() + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] args = ( A, @@ -119,105 +130,21 @@ def general_gemm( accumulate, use_split_accumulator, ) - - fn = tex.generic_gemm - if ub_algo is not None: - raise ValueError("Not implemented yet!") - if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.AG, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.RS, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: - fn = ub.split_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: - assert A_scaling_mode == [-1, -1, 1] and B_scaling_mode == [ - -1, - -1, - 1, - ], "Block scaling unsupported for atomic GEMM." - fn = ub.atomic_gemm_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: - fn = ub.split_overlap_rs - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: - fn = ub.split_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS: - assert A_scaling_mode == [-1, -1, 1] and B_scaling_mode == [ - -1, - -1, - 1, - ], "Block scaling unsupported for atomic GEMM." - fn = ub.atomic_gemm_overlap_rs - assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P: - assert A_scaling_mode == [-1, -1, 1] and B_scaling_mode == [ - -1, - -1, - 1, - ], "Block scaling unsupported for atomic GEMM." 
- fn = ub.atomic_gemm_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - if ub_algo is not None and ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: - out = fn(*args) - gelu_input = None - bias_grad = None + kwargs = { + 'comm_overlap': ub, + 'comm_type': ub_type, + 'extra_output': extra_output, + 'bulk_overlap': bulk_overlap, + } + + if ub_type == tex.CommOverlapType.AG and ub.is_p2p_overlap(): + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) else: original_scale_inverses = swizzle_inputs(A, B, layout) - out, bias_grad, gelu_input = fn(*args) + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) reset_swizzled_inputs(A, B, original_scale_inverses) - return out, bias_grad, gelu_input + return out, bias_grad, gelu_input, extra_output def general_grouped_gemm( diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index ada8c9d318..4d7b2dea76 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -143,14 +143,18 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, scaling_mode); } -size_t product(const std::vector& shape) { - size_t ret = 1; +template +T product(const std::vector& shape) { + T ret = 1; for (auto s : shape) { ret *= s; } return ret; } +template size_t product(const std::vector& shape); +template int64_t product(const std::vector& shape); + size_t product(const NVTEShape& shape, size_t begin, size_t end) { NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, " in a shape with ", shape.ndim, " entries"); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index e981eb9927..7c72c92741 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -169,9 +169,11 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { case transformer_engine::DType::kBFloat16: return at::kBFloat16; case transformer_engine::DType::kByte: + return at::kByte; case transformer_engine::DType::kFloat8E4M3: + return at::kFloat8_e4m3fn; case transformer_engine::DType::kFloat8E5M2: - return at::kByte; + return at::kFloat8_e5m2; default: NVTE_ERROR("Invalid type"); } @@ -179,6 +181,10 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { switch (t) { + case at::kFloat8_e4m3fn: + return transformer_engine::DType::kFloat8E4M3; + case at::kFloat8_e5m2: + return transformer_engine::DType::kFloat8E5M2; case at::kHalf: return transformer_engine::DType::kFloat16; case at::kFloat: @@ -232,7 +238,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); -size_t product(const std::vector& shape); +template +T product(const std::vector& shape); size_t product(const NVTEShape& shape, size_t begin, size_t end); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 93af90b4a0..949a1cfef3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -189,7 +189,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans 
py::handle quantizer, std::optional out_dtype, MaybeTensor bias, DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator); + bool use_split_accumulator, CommOverlapCore* comm_overlap = nullptr, + std::optional comm_type = std::nullopt, + MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false); /*************************************************************************************************** * Cast fusions @@ -396,74 +398,26 @@ class CommOverlapHelper : torch::CustomClassHolder { }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, - bool set_sm_margin = true, bool atomic_gemm = false); - - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, int comm_type); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - std::vector A_scaling_mode, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, std::vector B_scaling_mode, bool transb, - at::Tensor D, at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - transformer_engine::CommOverlapType comm_type, at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, - transformer_engine::DType A_type, std::vector A_scaling_mode, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, std::vector B_scaling_mode, - bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - std::vector A_scaling_mode, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, transformer_engine::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, at::Tensor rs_output); + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); + + 
~CommOverlap() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk); + + py::object get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape = std::nullopt); + }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, @@ -473,76 +427,15 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, bool aggregate = false); - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, bool chunk); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, - transformer_engine::DType A_type, std::vector A_scaling_mode, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, std::vector B_scaling_mode, - bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor B_copy); - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. 
- */ - void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - std::vector A_scaling_mode, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, transformer_engine::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor B_copy); - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, - transformer_engine::DType A_type, std::vector A_scaling_mode, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, std::vector B_scaling_mode, - bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output); - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - std::vector A_scaling_mode, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, transformer_engine::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output); + ~CommOverlapP2P() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk); + + py::object get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape = std::nullopt); + }; // CommOverlapP2P #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 8e63feffd1..b4667d5c94 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -15,35 +15,6 @@ using namespace std::placeholders; namespace te = transformer_engine; -// TODO: Actually take care of scaling modes -#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_scaling_mode, A_type, B, B_scale_inv, \ - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, \ - bias_type, pre_gelu_out, workspace) \ - A = A.contiguous(); \ - NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; \ - auto A_ = makeTransformerEngineTensor( \ - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, \ - nullptr, nullptr, A_scale_inv.data_ptr(), getTensorShape(A_scale_inv), nvte_scaling_modeA); \ - B = B.contiguous(); \ - NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; \ - auto B_ = makeTransformerEngineTensor( \ - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, \ - nullptr, nullptr, B_scale_inv.data_ptr(), getTensorShape(B_scale_inv), nvte_scaling_modeB); \ - auto D_ = makeTransformerEngineTensor( \ - D.data_ptr(), {static_cast(D.size(0)), 
static_cast(D.size(1))}, D_type, \ - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); \ - auto bias_ = makeTransformerEngineTensor( \ - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); \ - const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ - ? std::vector{static_cast(pre_gelu_out.size(0))} \ - : std::vector{static_cast(pre_gelu_out.size(0)), \ - static_cast(pre_gelu_out.size(1))}; \ - auto pre_gelu_out_ = makeTransformerEngineTensor( \ - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ - auto workspace_ = makeTransformerEngineTensor( \ - workspace.data_ptr(), std::vector{static_cast(workspace.size(0))}, \ - te::DType::kByte); - /*************************************************************************************************** * CommOverlapHelper **************************************************************************************************/ @@ -169,316 +140,179 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { * CommOverlap **************************************************************************************************/ -CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int gemm_priority, - int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm) +CommOverlap::CommOverlap( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, + int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) : te::CommOverlapBase( buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams, - comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, atomic_gemm) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. 
- _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Bulk GEMM + COMM -** This function assumes the communication input is pre-copied to _ubuf -*/ -std::vector CommOverlap::bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, te::DType A_type, std::vector A_scaling_mode, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, te::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, at::Tensor D_scale, - te::DType D_type, at::Tensor D_amax, at::Tensor bias, te::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, te::CommOverlapType comm_type, at::Tensor rs_output) { - using namespace transformer_engine::pytorch; - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_scaling_mode, A_type, B, B_scale_inverse, - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, - bias_type, pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, - grad, accumulate, use_split_accumulator, comm_type, rs_out_, - stream_main); - - // Get the current userbuf offset - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - if (comm_type == te::CommOverlapType::RS) { - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } - - // Generate output tensor from userbuf data pointer - int output_c_dim0 = - (comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - auto output_tensor = - torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); - - return {D, output_tensor}; -} // CommOverlap::bulk_overlap - -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, te::DType A_type, std::vector A_scaling_mode, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, te::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, at::Tensor D_scale, - te::DType D_type, at::Tensor D_amax, at::Tensor bias, te::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, at::Tensor rs_output) { - using namespace transformer_engine::pytorch; - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_scaling_mode, A_type, B, B_scale_inverse, - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, - bias_type, pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // CommOverlap::split_overlap_rs + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, atomic_gemm, + rs_overlap_first_gemm) {} -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, te::DType A_type, - std::vector A_scaling_mode, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, te::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, - at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, - at::Tensor bias, te::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - using namespace transformer_engine::pytorch; - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_scaling_mode, A_type, B, B_scale_inverse, - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, - bias_type, pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // CommOverlap::split_overlap_rs +void CommOverlap::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + _ubuf_scale_inv_initialized = true; +} /* ** Helper function to copy input to _ubuf */ -void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { +void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? 
input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type == te::CommOverlapType::AG) { - if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + if (local_chunk) { + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size(); } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); } + // Copy either row or columnwise data into the communication buffer's columnwise data + // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with + // the Userbuffers communicator. at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); + NVTE_CHECK_CUDA(cudaMemcpyAsync( + ubuf_ptr, input_tensor.dptr(), input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); } -torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { - using namespace transformer_engine::pytorch; +py::object CommOverlap::get_buffer( + py::handle quantizer, bool local_chunk, std::optional> shape) { + using namespace te::pytorch; char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) + if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel(); + NVTE_CHECK(requested == expected, + "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, + ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? 
_ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) + te_shape.emplace_back(static_cast(s)); + + auto is_internal = my_quantizer->internal; + my_quantizer->internal = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + return py_tensor; } /*************************************************************************************************** * CommOverlapP2P **************************************************************************************************/ -CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, - te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, - bool aggregate) +CommOverlapP2P::CommOverlapP2P( + const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, + int tp_size, te::CommOverlapType comm_type, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool use_ce, bool aggregate) : te::CommOverlapP2PBase( buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, - atomic_gemm, aggregate) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. - _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Split AllGather + AtomicGEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. 
-*/ -void CommOverlapP2P::atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, te::DType A_type, std::vector A_scaling_mode, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, te::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, at::Tensor D_scale, - te::DType D_type, at::Tensor D_amax, at::Tensor bias, te::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, at::Tensor B_copy) { - using namespace transformer_engine::pytorch; - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_scaling_mode, A_type, B, B_scale_inverse, - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, - bias_type, pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, - use_split_accumulator, B_copy_, stream_main); -} // atomic_gemm_overlap_ag - -/* -** Split AllGather + GEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. -*/ -void CommOverlapP2P::split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, te::DType A_type, - std::vector A_scaling_mode, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, te::DType B_type, - std::vector B_scaling_mode, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, - at::Tensor D_amax, at::Tensor bias, te::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, - bool use_split_accumulator, at::Tensor B_copy) { - using namespace transformer_engine::pytorch; - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_scaling_mode, A_type, B, B_scale_inverse, - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, - bias_type, pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - B_copy_, stream_main); -} // split_overlap_ag + atomic_gemm, aggregate) {} -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2P::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, te::DType A_type, std::vector A_scaling_mode, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, te::DType B_type, - std::vector B_scaling_mode, bool transb, at::Tensor D, at::Tensor D_scale, - te::DType D_type, at::Tensor D_amax, at::Tensor bias, te::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, at::Tensor rs_output) { - using namespace transformer_engine::pytorch; - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_scaling_mode, A_type, B, B_scale_inverse, - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, - bias_type, pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, 
accumulate, - use_split_accumulator, rs_out_, stream_main); -} - -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2P::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, te::DType A_type, - std::vector A_scaling_mode, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, te::DType B_type, - std::vector B_scaling_mode, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, - at::Tensor D_amax, at::Tensor bias, te::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, - bool use_split_accumulator, at::Tensor rs_output) { - using namespace transformer_engine::pytorch; - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_scaling_mode, A_type, B, B_scale_inverse, - B_scaling_mode, B_type, D, D_amax, D_scale, D_type, bias, - bias_type, pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - rs_out_, stream_main); +void CommOverlapP2P::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + for (size_t i = 0; i < _ubufs.size(); i++) + my_quantizer->set_quantization_params(&_ubufs[i]); } /* ** Copy input to _ubufs[0] */ -void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { +void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? 
input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { + if (local_chunk) { // Copy input to the target ubuf chunk by rank offset - if (input.numel() != (int64_t)_ubufs[0].numel() || - input.element_size() != (int64_t)_ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } } -torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); +py::object CommOverlapP2P::get_buffer( + py::handle quantizer, bool local_chunk, std::optional> shape) { + using namespace te::pytorch; + char *ubuf_wt_ptr = reinterpret_cast(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr()); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel(); + NVTE_CHECK(requested == expected, + "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, + ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? 
_ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) + te_shape.emplace_back(static_cast(s)); + + auto is_internal = my_quantizer->internal; + my_quantizer->internal = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + return py_tensor; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 39e21224f8..6222561d1a 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -86,7 +86,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans py::handle quantizer, std::optional out_dtype, MaybeTensor bias, DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator) { + bool use_split_accumulator, CommOverlapCore* comm_overlap, + std::optional comm_type, MaybeTensor extra_output, + bool bulk_overlap) { // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -121,15 +123,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans TensorWrapper bias_tensor; MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { - if (!bias->is_contiguous()) { - bias = bias->contiguous(); - } - if (!grad) { - bias_tensor = makeTransformerEngineTensor(*bias); - } else { + if (grad) { auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); - bias_grad = at::empty({B_shape.data[B_shape.ndim - 1]}, opts); + bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); + } else { + if (!bias->is_contiguous()) { + bias = bias->contiguous(); + } + bias_tensor = makeTransformerEngineTensor(*bias); } } @@ -166,29 +168,64 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { - // Launch GEMM - nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), - te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), accumulate, - use_split_accumulator, num_math_sms, at::cuda::getCurrentCUDAStream()); + if (comm_overlap) { + // Prepare extra output tensor + TensorWrapper extra_output_tensor; + if (extra_output.has_value()) { + extra_output_tensor = makeTransformerEngineTensor(*extra_output); + } else { + extra_output_tensor = makeTransformerEngineTensor( + nullptr, std::vector{0}, DType::kByte); + } + + // Direct GEMM call to the correct overlap + if (bulk_overlap) { + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, comm_type.value(), extra_output_tensor, + main_stream); + } else if (comm_type.value() == 
CommOverlapType::AG) { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_ag( + A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, + te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, + main_stream); + } else { + comm_overlap->split_overlap_ag( + A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, + te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, + main_stream); + } + } else { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_rs( + A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, + te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, + main_stream); + } else { + comm_overlap->split_overlap_rs( + A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, + te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, + main_stream); + } + } + } else { + // Launch GEMM + nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), + te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), + accumulate, use_split_accumulator, num_math_sms, main_stream); + } } else { if (D_tensor.numel() != 0 && !accumulate) { - D_tensor.zero_(at::cuda::getCurrentCUDAStream()); + D_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { bias_grad->zero_(); } } - std::vector out; - out.emplace_back(std::move(D)); - out.emplace_back(py::cast(bias_grad)); - if (gelu && !grad) { - out.emplace_back(py::cast(*pre_gelu_out)); - } else { - out.emplace_back(py::none()); - } - return out; } // Pack outputs @@ -200,6 +237,11 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { out.emplace_back(py::none()); } + if (extra_output.has_value()) { + out.emplace_back(py::cast(extra_output)); + } else { + out.emplace_back(py::none()); + } return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 42e496e83b..aa0110b76f 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -75,11 +75,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("otype")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); - m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply", + m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), - py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator")); + py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), + py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, + py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.", py::call_guard()); m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.", @@ -116,7 +118,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("dsrelu", 
transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", @@ -168,6 +169,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); + m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); m.def("fused_attn_fwd", &fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); @@ -284,31 +286,28 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard(), py::arg("world_group"), py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); - py::class_(m, "CommOverlap") + py::class_, + transformer_engine::CommOverlapBase, + transformer_engine::CommOverlapCore>(m, "CommOverlap") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, int, int, bool, bool>(), + int, int, int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, - py::arg("atomic_gemm") = false) - .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) - .def("split_overlap_rs", &CommOverlap::split_overlap_rs, - py::call_guard()) - .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlap::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, - py::call_guard()) - .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) - .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) - .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); + py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) + .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk")) + .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"), py::arg("local_chunk"), + py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlap::set_buffer_params); - py::class_(m, "CommOverlapP2P") + py::class_, + transformer_engine::CommOverlapP2PBase, + transformer_engine::CommOverlapCore>(m, "CommOverlapP2P") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, bool>(), @@ -318,23 +317,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) - .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, - py::call_guard()) - .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, - py::call_guard()) 
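The per-algorithm entry points removed here collapse into the three buffer-management methods bound above on CommOverlap (and below on CommOverlapP2P). A minimal sketch of how Python code is expected to drive them, assuming a communicator has already been registered and a quantizer is in scope; names follow this patch, but treat it as illustrative rather than the canonical API:

    # Sketch only: get_ub() is the accessor from transformer_engine/pytorch/module/base.py,
    # and `quantizer` stands in for whatever Quantizer object the calling module already owns.
    ub = get_ub("proj_fprop")                  # CommOverlap or CommOverlapP2P instance
    ub.copy_into_buffer(inp, quantizer, True)  # stage this rank's chunk (local_chunk=True)
    full = ub.get_buffer(quantizer, False)     # quantized view over the whole communication buffer
    ub.set_buffer_params(quantizer)            # attach scale/amax state to an FP8 buffer

Because both overlap classes expose the same three methods, module code no longer needs to know which overlap algorithm backs the buffer.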
- .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag, - py::call_guard()) - .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, - py::call_guard()) - .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) - .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, - py::call_guard()); + .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk")) + .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"), py::arg("local_chunk"), + py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlapP2P::set_buffer_params); } diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index e9c7767abf..13b1a15a09 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -172,7 +172,7 @@ std::pair MXFP8Quantizer::create_tensor( at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, columnwise_scale_inv; // TODO(pgadzinski) - change opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - auto last_dim = torch_shape.back(); + auto last_dim = static_cast(torch_shape.back()); at::Tensor data; if (rowwise_usage) { @@ -181,7 +181,8 @@ std::pair MXFP8Quantizer::create_tensor( } else { data = at::empty(torch_shape, opts); } - rowwise_scale_inv = at::empty({numel / last_dim, last_dim / MXFP8_BLOCK_SIZE}, opts); + rowwise_scale_inv = at::empty({static_cast(numel / last_dim), + static_cast(last_dim / MXFP8_BLOCK_SIZE)}, opts); tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv( rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, @@ -190,7 +191,8 @@ std::pair MXFP8Quantizer::create_tensor( } if (columnwise_usage) { columnwise_data = at::empty(torch_shape, opts); - columnwise_scale_inv = at::empty({numel / (last_dim * MXFP8_BLOCK_SIZE), last_dim}, opts); + columnwise_scale_inv = at::empty({static_cast(numel / (last_dim * MXFP8_BLOCK_SIZE)), + static_cast(last_dim)}, opts); tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); tensor.set_columnwise_scale_inv( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 19951bb2af..de318fe6f7 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -305,31 +305,33 @@ def get_default_config(name): "is_reduce_scatter": is_reduce_scatter, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": False, - "num_splits": 4 if method == "pipeline" else tp_size, + "set_sm_margin": False if method == "ring_exchange" else True, + "num_splits": tp_size if method == "ring_exchange" else 4, "aggregate": False, "atomic_gemm": False, "use_ce": True, "fp8_buf": name in layers_all_gather_overlap, "comm_priority": _MAX_STREAM_PRIORITY, "gemm_priority": _MIN_STREAM_PRIORITY, + "pipeline_rs_overlap_first_gemm": False, } return default_cfg def add_ub( name: str, method: str, - is_reduce_scatter: int, + is_reduce_scatter: bool, num_sm: int = 16, 
cga_size: int = 2, - set_sm_margin: int = 0, + set_sm_margin: bool = False, num_splits: int = 0, - aggregate: int = 0, - atomic_gemm: int = 0, + aggregate: bool = False, + atomic_gemm: bool = False, use_ce: bool = True, fp8_buf: bool = False, comm_priority: int = 0, gemm_priority: int = 0, + pipeline_rs_overlap_first_gemm: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -397,6 +399,7 @@ def add_ub( atomic_gemm=atomic_gemm, gemm_priority=gemm_priority, comm_priority=comm_priority, + rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) _ub_communicators[name] = ub_obj @@ -872,8 +875,8 @@ def grad_output_preprocess( if not ctx.ub_overlap_ag: grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) else: - ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True) - grad_output = ctx.ub_obj_gradout.get_ubuf_output(1) + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer, False) return grad_output, None # FP8 with all-gather: unfused bgrad, fused cast + transpose @@ -882,15 +885,20 @@ def grad_output_preprocess( if ctx.use_bias: grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) if ctx.ub_overlap_ag: - # TODO: Implement - raise NotImplementedError( - "Overlapped tensor parallelism with Userbuffers is not yet supported" + # Quantize the gradient if needed + if not isinstance(grad_output, + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + grad_output = quantizer(grad_output) + + # Copy into communication buffer, and replace original gradient with it + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer, False) + else: + grad_output, _ = gather_along_first_dim( + grad_output, + ctx.tp_group, + quantizer=quantizer, ) - grad_output, _ = gather_along_first_dim( - grad_output, - ctx.tp_group, - quantizer=quantizer, - ) return grad_output, grad_bias # FP8 without all-gather: fused bgrad + cast + transpose diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index eb4164947e..074e999ab8 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -6,12 +6,16 @@ import os import warnings from typing import Callable, Dict, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch from torch.nn import init import transformer_engine_torch as tex +from transformer_engine.common.recipe import BlockScaling + from .base import ( get_workspace, get_ub, @@ -98,10 +102,12 @@ def forward( bwd_ln_sm_margin: int, zero_centered_gamma: bool, normalization: str, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_ag: bool, ub_name: str, fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, @@ -122,24 +128,30 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_overlap_ag: - raise NotImplementedError - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1 or (not is_grad_enabled): - ub_overlap_ag = False - if ub_overlap_ag: - raise NotImplementedError - dim_size = list(inputmat.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ub_name + "_fprop") + tp_world_size = get_distributed_world_size(tp_group) + 
ub_overlap_ag_fprop = ( + ub_overlap_ag_fprop + and is_grad_enabled + and not return_layernorm_output + ) weight_requires_grad = weight.requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad - with_input_all_gather = parallel_mode == "column" and sequence_parallel + with_input_all_gather = ( + parallel_mode == "column" + and sequence_parallel + and not ub_overlap_ag_fprop + ) # Configure quantizer for normalization output - if fp8 and input_quantizer is None: - raise ValueError("Missing quantizer for input tensor") + if fp8: + if (any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) + and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)): + raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling") + + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + with_quantized_norm = fp8 and not return_layernorm_output if with_quantized_norm: if with_input_all_gather: @@ -152,10 +164,22 @@ def forward( columnwise=backward_needs_input, ) + ub_obj_fprop = None + ln_out = None + if ub_overlap_ag_fprop: + ub_obj_fprop = get_ub(ub_name + "_fprop") + ln_out = ub_obj_fprop.get_buffer(input_quantizer, True) + elif with_quantized_norm: + ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda") + else: + ln_out = torch.empty_like( + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format + ) + # Apply normalization - ln_out, mu, rsigma = apply_normalization( + _, mu, rsigma = apply_normalization( inputmat, - None, + ln_out, ln_weight, ln_bias, eps, @@ -234,7 +258,23 @@ def forward( if weight_quantizer is not None: weight_quantizer.calibrate(weight) - out, _, _ = general_gemm( + ub_obj = None + ub_type = None + rs_out = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG + if fp8: + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." 
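The forward pass now chooses the overlap mode once and threads it into general_gemm through the ub, ub_type, and extra_output keywords instead of the old ub_algo enum; the AG branch then reads the gathered activation back out of the buffer just below. Condensed into a hypothetical helper whose name and argument list are invented for illustration, with branches mirroring the code above:

    # Hypothetical helper, not part of the patch: summarizes the fprop overlap dispatch.
    def _fprop_overlap_args(ub_name, overlap_rs, overlap_ag, rs_shape, dtype, device):
        """Return (ub, ub_type, rs_out) keyword values for general_gemm."""
        if overlap_rs:  # GEMM output is reduce-scattered into rs_out by the overlap object
            ub = get_ub(ub_name + "_fprop")
            return ub, tex.CommOverlapType.RS, torch.empty(rs_shape, dtype=dtype, device=device)
        if overlap_ag:  # GEMM reads its all-gathered input out of the userbuffer
            return get_ub(ub_name + "_fprop"), tex.CommOverlapType.AG, None
        return None, None, None  # plain GEMM, no comm overlap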
+ ln_out_total = ub_obj.get_buffer(input_quantizer, False) + + out, *_ = general_gemm( weightmat, ln_out_total, get_workspace(), @@ -242,8 +282,9 @@ def forward( out_dtype=activation_dtype, bias=bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, ) if not weight.requires_grad: if not return_layernorm_output: @@ -312,9 +353,10 @@ def forward( ctx.return_layernorm_output_gathered = return_layernorm_output_gathered ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization @@ -326,10 +368,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if ub_overlap_rs_fprop: + out = rs_out + elif parallel_mode == "row": + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP out = out.view(-1, *inp_shape[1:-1], out_features) @@ -349,6 +394,14 @@ def backward( # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_LayerNormLinear_backward"): + if (ctx.fp8 + and any([ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad]) + and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)): + raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling") + saved_tensors = ctx.saved_tensors inputmat, weight, _, bias, ln_weight, ln_out, mu, rsigma = restore_from_saved( ctx.tensor_objects, saved_tensors @@ -386,30 +439,46 @@ def backward( if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: weight.main_grad = main_grad - if ctx.ub_overlap_rs_dgrad: - raise NotImplementedError - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - if ctx.ub_bulk_dgrad: - raise NotImplementedError - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - raise NotImplementedError - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ctx.ub_name + "_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - - if ctx.ub_bulk_wgrad: - raise NotImplementedError - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_wgrad = False + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + rs_out = None + dgrad_bulk = None + if ctx.ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + 
ub_type_dgrad = tex.CommOverlapType.AG + + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype, + device=inputmat.device) + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, True) + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, False) ( grad_output, @@ -425,12 +494,15 @@ def backward( # Note: Perform tensor-parallel communication if needed ln_out_total = None ln_out_total_work = None - if ctx.requires_wgrad and ctx.parallel_mode == "column" and ctx.sequence_parallel: + if (ctx.requires_wgrad + and ctx.parallel_mode == "column" + and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad): quantizer = None if ctx.fp8: quantizer = ctx.input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) - ln_out_total, ln_out_total_async = gather_along_first_dim( + ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, @@ -455,20 +527,29 @@ def backward( if grad_output._transpose is None: grad_output._create_transpose() - dgrad, _, _ = general_gemm( + dgrad, *_= general_gemm( weight, grad_output, get_workspace(), layout="NN", grad=True, quantization_params=ctx.grad_input_quantizer, + out=dgrad_bulk, out_dtype=ctx.activation_dtype, use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, ) # Launch tensor-parallel communication dgrad_work = None - if ctx.parallel_mode == "column": + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer, False) + elif ctx.parallel_mode == "column": if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) @@ -483,17 +564,37 @@ def backward( # Compute grad weight tensor wgrad = None if ctx.requires_wgrad: - # Synchronize tensor-parallel communication - if ln_out_total_work is not None: - ln_out_total_work.wait() - ln_out_total_work = None + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer, False) + if ctx.fp8: + # FP8 GEMM on Hopper only supports TN layout so the gathered input must have + # a valid transpose. + if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + ln_out_total._fix_gathered_transpose(tp_size=ctx.tp_size, + from_rowwise=True) + else: + # Otherwise, we would have all-gathered rowwise data and would need to + # create the transpose (on Hopper). 
+ ln_out_total._create_transpose() + + else: + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None + + if hasattr(ln_out_total, "_create_transpose"): + ln_out_total._create_transpose() # TODO(pgadzinski) - temporary - if hasattr(ln_out_total, "_create_transpose"): - ln_out_total._create_transpose() # TODO(pgadzinski) - temporary + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype, + device=inputmat.device) # wgrad GEMM # Note: Fuse with bgrad computation if needed - wgrad, grad_bias_, _ = general_gemm( + wgrad, grad_bias_, *_ = general_gemm( ln_out_total, grad_output, get_workspace(), @@ -506,7 +607,18 @@ def backward( out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, ) + + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out + else: + dgrad = ub_obj_wgrad.get_buffer(None, True) + if grad_bias is None: grad_bias = grad_bias_ del grad_bias_ @@ -616,10 +728,12 @@ def backward( None, # bwd_ln_sm_margin None, # zero_centered_gamma None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad None, # ub_overlap_rs_dgrad - None, # ub_overlap_ag + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name None, # fsdp_group None, # module @@ -734,10 +848,11 @@ def __init__( parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -754,13 +869,6 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_ag = ub_overlap_ag - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]): - assert ub_name is not None, "Userbuffer name [string] is not set." 
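Every general_gemm call site in this patch moves from a 3-tuple to a 4-tuple because the binding now also returns the communication extra output. A hedged sketch of the convention, with the keyword arguments trimmed to the overlap-related ones and tensor names chosen for illustration:

    # Sketch of the 4-output convention used at the call sites in this patch; the usual
    # kwargs (layout, out_dtype, quantization_params, bias, ...) are omitted for brevity.
    out, bias_grad, gelu_out, extra_out = general_gemm(
        weight, inp, get_workspace(),
        ub=ub_obj,                       # CommOverlap/CommOverlapP2P object, or None
        ub_type=tex.CommOverlapType.RS,  # or .AG, or None when not overlapping
        extra_output=rs_out,             # e.g. receives the reduce-scattered result
        bulk_overlap=False,              # True only for the bulk AG/RS overlaps
    )

When no extra output tensor is supplied, the fourth element comes back as None, which is why several call sites simply unpack it into *_.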
- self.ub_name = ub_name if tp_group is None: self.tp_size = tp_size @@ -786,9 +894,55 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column-parallel overlaps + self.ub_overlap_ag_fprop = ( + ub_overlap_ag + and self.sequence_parallel + and self.parallel_mode == "column" + ) + self.ub_overlap_rs_dgrad = ( + ub_overlap_rs_dgrad + and self.sequence_parallel + and self.parallel_mode == "column" + ) + self.ub_bulk_wgrad = ( + ub_bulk_wgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + + # Row-parallel overlaps + self.ub_overlap_rs_fprop = ( + ub_overlap_rs + and self.sequence_parallel + and self.parallel_mode == "row" + ) + self.ub_overlap_ag_dgrad = ( + ub_overlap_ag + and self.sequence_parallel + and self.parallel_mode == "row" + ) + if any([ + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + ]): + assert ub_name is not None, "Userbuffer name [string] is not set." + self.ub_name = ub_name + self.eps = eps layer_norm_weight = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_weight", @@ -797,7 +951,7 @@ def __init__( ) if self.normalization != "RMSNorm": layer_norm_bias = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) @@ -1074,10 +1228,12 @@ def forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_ag, self.ub_name, self.fsdp_group, self, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 647ff3f980..60c9a64635 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -6,6 +6,8 @@ import os import warnings from typing import Callable, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch from torch.nn.parameter import Parameter @@ -13,6 +15,8 @@ import transformer_engine_torch as tex +from transformer_engine.common.recipe import BlockScaling + from .base import ( get_workspace, _ub_communicators, @@ -134,11 +138,11 @@ def forward( zero_centered_gamma: bool, activation: str, normalization: str, + ub_overlap_ag: bool, + ub_overlap_rs: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, gemm_gelu_fusion: bool, fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, @@ -152,6 +156,9 @@ def forward( inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + if (any([ub_overlap_ag, ub_overlap_rs]) + and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)): + raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling") 
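This recipe guard is now duplicated across linear.py, layernorm_linear.py, and layernorm_mlp.py in both forward and backward. If it spreads further, a small shared helper could keep the check and the message in one place; a sketch of such a helper, purely illustrative and not part of the patch:

    # Hypothetical helper, not in the patch: one home for the unsupported-recipe check.
    def _check_overlap_supports_recipe(fp8: bool, *overlap_flags: bool) -> None:
        if (fp8 and any(overlap_flags)
                and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)):
            raise NotImplementedError(
                "Comm+GEMM overlap does not support MXFP8 block scaling"
            )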
activation_func = _act_func(activation)[0] device = inp.device @@ -169,23 +176,9 @@ def forward( with_quantized_norm = fp8 and not return_layernorm_output tp_world_size = get_distributed_world_size(tp_group) - ln_out_gathered = False - if ub_overlap_ag: - raise NotImplementedError - if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: - ub_overlap_ag = False - if ub_overlap_ag: - raise NotImplementedError - ub_obj_lnout = get_ub("fc1_fprop") - ln_out = ub_obj_lnout.get_ubuf_output(0) - else: - ln_out_dtype = torch.uint8 if with_quantized_norm else inputmat.dtype - ln_out = torch.empty_like( - inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format - ) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs - - with_input_all_gather = tp_world_size > 1 and sequence_parallel + ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output + ub_overlap_rs = ub_overlap_rs and is_grad_enabled + with_input_all_gather = sequence_parallel and not ub_overlap_ag # Configure quantizer for normalization output if fp8 and fc1_input_quantizer is None: @@ -201,10 +194,23 @@ def forward( columnwise=(is_grad_enabled and fc1_weight.requires_grad), ) + ub_obj_lnout = None + ln_out = None + if ub_overlap_ag: + ub_obj_lnout = get_ub("fc1_fprop") + ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, True) + elif with_quantized_norm: + ln_out = fc1_input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, + device="cuda") + else: + ln_out = torch.empty_like( + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format + ) + # Apply normalization - ln_out, mu, rsigma = apply_normalization( + _, mu, rsigma = apply_normalization( inputmat, - None, + ln_out, ln_weight, ln_bias, eps, @@ -231,34 +237,28 @@ def forward( ) ln_out_gathered = True else: - ln_out_total = ln_out with_quantized_all_gather = False + if ub_overlap_ag: + ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False) + else: + ln_out_total = ln_out # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. if return_layernorm_output: ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out - if fp8: - if ub_overlap_ag: - raise NotImplementedError - ln_out = pytex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - elif not with_quantized_all_gather: - ln_out_total = fc1_input_quantizer(ln_out_total) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[ - slice_start:slice_end, ... - ] # TODO(pgadzinski) - check this - else: - ln_out = ln_out_total + if fp8 and not with_quantized_all_gather: + ln_out_total = fc1_input_quantizer(ln_out_total) + if ln_out_gathered: + rank = torch.distributed.get_rank(tp_group) + slice_start = rank * ln_out.size(0) + slice_end = (rank + 1) * ln_out.size(0) + ln_out = ln_out_total[ + slice_start:slice_end, ... 
+ ] # TODO(pgadzinski) - check this + else: + ln_out = ln_out_total # Cast weights to expected dtype fc1_weight_final = fc1_weight @@ -335,9 +335,9 @@ def forward( fc1_bias if not bias_gelu_fusion else None ), # otherwise bias is added later (fused with gelu) gelu=gemm_gelu_fusion, - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, accumulate=_2X_ACC_FPROP, + ub=ub_obj_lnout, + ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, ) if not is_grad_enabled and (ln_out_total is not ln_out_return): clear_tensor_data(ln_out_total) @@ -348,12 +348,12 @@ def forward( if bias_gelu_fusion: fc1_out = None - fc1_out_without_bias, _, _ = fc1_outputs + fc1_out_without_bias, _, _, _ = fc1_outputs act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) elif gemm_gelu_fusion: - act_out, _, fc1_out = fc1_outputs + act_out, _, fc1_out, _ = fc1_outputs else: - fc1_out, _, _ = fc1_outputs + fc1_out, _, _, _ = fc1_outputs act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: @@ -363,17 +363,16 @@ def forward( fc2_input_quantizer.calibrate(act_out) fc2_weight_quantizer.calibrate(fc2_weight) + ub_obj_fc2out = None + rs_out = None + fc2_out = None if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") - fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(act_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device) - if ub_obj_fc2out.is_p2p_overlap(): - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + fc2_out = ub_obj_fc2out.get_buffer(output_quantizer, False) else: dim_size = list(act_out.size()) dim_size[1] = fc2_weight.size(0) @@ -389,8 +388,9 @@ def forward( quantization_params=output_quantizer, out=fc2_out, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo_rs if ub_overlap_rs else None, - ub=ub_obj_fc2out if ub_overlap_rs else None, + ub=ub_obj_fc2out, + ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None, + extra_output=rs_out, ) if not is_grad_enabled: clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) @@ -440,7 +440,7 @@ def forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out, + ln_out.clone() if ub_overlap_ag else ln_out, # avoid saving a UB buffer fc1_weight_final, fc1_bias, fc1_out, @@ -513,7 +513,6 @@ def forward( # Row Parallel Linear if ub_overlap_rs: - raise NotImplementedError fc2_out = rs_out elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) @@ -537,6 +536,14 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_LayerNormMLP_backward"): + if (ctx.fp8 + and any([ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad]) + and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)): + raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling") + saved_tensors = ctx.saved_tensors ( # pylint: disable=unbalanced-tuple-unpacking inputmat, @@ -590,35 +597,9 @@ def backward( # fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None, # ) - if ctx.ub_overlap_rs_dgrad: - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - 
if ctx.ub_bulk_dgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not ctx.fc1_weight_requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub("fc1_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - if ctx.ub_overlap_ag: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_ag = False - - ub_algo = None - if ctx.ub_overlap_ag: - dim_size = list(grad_outputs[0].size()) - dim_size[0] = dim_size[0] * tp_world_size - ctx.ub_obj_gradout = get_ub("fc2_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + # No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required + ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad # Prepare grad output tensor # Note: Cast to expected dtype and perform tensor-parallel communication @@ -628,6 +609,10 @@ def backward( columnwise=True, # TODO(pgadzinski) - remove ) + ub_obj_fc2_dgrad = None + if ctx.ub_overlap_ag: + ub_obj_fc2_dgrad = get_ub("fc2_dgrad") + ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, fc2_bias_grad, @@ -635,17 +620,14 @@ def backward( ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer ) - if ctx.ub_bulk_wgrad: - raise NotImplementedError - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not ctx.fc1_weight_requires_grad: - ctx.ub_bulk_wgrad = False - # Prepare FC1 GEMM input # Note: Perform tensor-parallel communication if needed ln_out_total = None ln_out_total_work = None - if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: + if (ctx.fc1_weight_requires_grad + and ctx.tensor_parallel + and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad): quantizer = None if ctx.fp8: quantizer = ctx.fc1_input_quantizer @@ -676,21 +658,24 @@ def backward( not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) ) - fc2_wgrad = None # FC2 DGRAD; Unconditional - gemm_output, _, _ = general_gemm( + gemm_output, _, _, _ = general_gemm( fc2_weight, grad_output, get_workspace(), layout="NN", grad=True, - quantization_params=None, # high precision to activation + quantization_params=( + ctx.grad_fc1_output_quantizer + if fc2_dgrad_gemm_gelu_fusion + else None + ), # high precision to activation out_dtype=ctx.activation_dtype, gelu=fc2_dgrad_gemm_gelu_fusion, gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None, use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=(tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None), - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ub=ub_obj_fc2_dgrad, + ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None, ) if fc2_dgrad_gemm_gelu_fusion: dact = gemm_output @@ -702,7 +687,7 @@ def backward( if ctx.fc2_weight_requires_grad: if ctx.fc2_input_quantizer is not None and hasattr(act_out, "_create_transpose"): act_out._create_transpose() - fc2_wgrad, fc2_bias_grad_, _ = general_gemm( + fc2_wgrad, fc2_bias_grad_, _, _ = general_gemm( act_out, grad_output, get_workspace(), @@ -758,76 +743,97 @@ def backward( # Overwrite data. Deleting the tensor does not release underlying memory. 
clear_tensor_data(fc1_out, fc1_out_without_bias) - fc1_dgrad_size = list(inputmat.size()) - fc1_dgrad_size[1] = fc1_weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - raise NotImplementedError - ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - raise NotImplementedError - ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - - # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap - if ctx.ub_bulk_dgrad: - raise NotImplementedError - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - raise NotImplementedError - dim_size = list(inputmat.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc1_weight.size(1) - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=ctx.device) - if ub_obj_dgrad.is_p2p_overlap(): - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad + # Set UB algo and UB obj for fc1_dgrad/wgrad bulk/pipelined overlap + ub_obj_fc1_dgrad = None + ub_obj_fc1_wgrad = None + ub_type_fc1_dgrad = None + fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] + fc1_dgrad_rs_out = None + fc1_dgrad_bulk = None + if ctx.ub_overlap_rs_dgrad: + # Overlap DGRAD+RS + ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.RS + fc1_dgrad_rs_out = torch.empty(fc1_dgrad_shape, dtype=ctx.activation_dtype, + device="cuda") + else: - ub_algo = None - ub_obj = None + if ctx.ub_bulk_dgrad: + # Overlap ln_out all-gather with DGRAD compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. 
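Concretely, the bulk-overlap wiring set up in the code that follows amounts to: stage ln_out into the fc1_dgrad buffer so its all-gather rides along with the DGRAD GEMM, then point the DGRAD output at the fc1_wgrad buffer so its reduce-scatter rides along with the WGRAD GEMM. A hedged sketch of that pattern in isolation, with variable names following the patch and control flow simplified:

    # Sketch of the bulk AG + bulk RS pattern used below; simplified, not a drop-in replacement.
    ub_dgrad = get_ub("fc1_dgrad")
    ub_dgrad.copy_into_buffer(ln_out, fc1_input_quantizer, True)    # local chunk in, AG overlapped
    ub_wgrad = get_ub("fc1_wgrad")
    dgrad_bulk = ub_wgrad.get_buffer(None, False)                   # DGRAD writes straight into it
    fc1_dgrad, *_ = general_gemm(fc1_weight, dact, get_workspace(),
                                 out=dgrad_bulk, layout="NN", grad=True,
                                 ub=ub_dgrad, ub_type=tex.CommOverlapType.AG,
                                 bulk_overlap=True)
    ln_out_total = ub_dgrad.get_buffer(fc1_input_quantizer, False)  # gathered input for WGRAD
    # The WGRAD GEMM then runs with ub=ub_wgrad, ub_type=RS, bulk_overlap=True, and the
    # reduce-scattered DGRAD is read back from ub_wgrad (or from extra_output for FP8 buffers).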
+ ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.AG + ub_obj_fc1_dgrad.copy_into_buffer(ln_out, ctx.fc1_input_quantizer, True) + + if ctx.ub_bulk_wgrad: + # Overlap FC1 DGRAD reduce-scatter with WGRAD compute + ub_obj_fc1_wgrad = get_ub("fc1_wgrad") + fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None, False) + # FC1 DGRAD: Unconditional - fc1_dgrad, _, _ = general_gemm( + fc1_dgrad, _, _, fc1_dgrad_rs_out = general_gemm( fc1_weight, dact, get_workspace(), + out=fc1_dgrad_bulk, out_dtype=ctx.activation_dtype, layout="NN", grad=True, - ub_algo=ub_algo, - ub=ub_obj, - # extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, + ub=ub_obj_fc1_dgrad, + ub_type=ub_type_fc1_dgrad, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, ) - if ctx.ub_bulk_dgrad: - raise NotImplementedError - ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad fc1_dgrad_work = None - if ctx.set_parallel_mode and ctx.sequence_parallel: - if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: - fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) - fc1_dgrad, fc1_dgrad_work = reduce_scatter_along_first_dim( - fc1_dgrad, - ctx.tp_group, - async_op=True, - ) - elif ctx.set_parallel_mode and ctx.tensor_parallel: - fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + if ctx.ub_overlap_rs_dgrad: + fc1_dgrad = fc1_dgrad_rs_out + elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: + if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: + fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) + fc1_dgrad, fc1_dgrad_work = reduce_scatter_along_first_dim( + fc1_dgrad, + ctx.tp_group, + async_op=True, + ) + elif ctx.tensor_parallel: + fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) # FC1 WGRAD fc1_wgrad = None if ctx.fc1_weight_requires_grad: + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer, False) + if ctx.fp8: + # FP8 GEMM on Hopper only supports TN layout so the gathered input must have + # a valid transpose. + if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + ln_out_total._fix_gathered_transpose(tp_size=ctx.tp_size, + from_rowwise=True) + else: + # Otherwise, we would have all-gathered rowwise data and would need to + # create the transpose (on Hopper). 
+ ln_out_total._create_transpose() + + else: + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None - # Synchronize tensor-parallel communication - if ln_out_total_work is not None: - ln_out_total_work.wait() - ln_out_total_work = None + if hasattr(ln_out_total, "_create_transpose"): + ln_out_total._create_transpose() # TODO(pgadzinski) - temporary - if hasattr(ln_out_total, "_create_transpose"): - ln_out_total._create_transpose() # TODO(pgadzinski) - temporary + if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad_rs_out = torch.empty(fc1_dgrad_shape, dtype=ctx.activation_dtype, + device="cuda") fc1_wgrad_outputs = general_gemm( ln_out_total, @@ -839,19 +845,24 @@ def backward( bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + ub=ub_obj_fc1_wgrad, + ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, ) clear_tensor_data(ln_out_total, dact) if fuse_gemm_and_bias_fc1_wgrad: - fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs + fc1_wgrad, fc1_bias_grad, _, _ = fc1_wgrad_outputs else: - fc1_wgrad, _, _ = fc1_wgrad_outputs + fc1_wgrad, _, _, _ = fc1_wgrad_outputs if ctx.ub_bulk_wgrad: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + if ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad = fc1_dgrad_rs_out + else: + fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, True) # Synchronize tensor parallel communication if ln_out_total_work is not None: @@ -977,11 +988,11 @@ def backward( None, # zero_centered_gamma None, # activation None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad - None, # ub_overlap_rs_dgrad - None, # ub_overlap_rs None, # ub_overlap_ag + None, # ub_overlap_rs + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # gemm_gelu_fusion None, # fsdp_group None, # module @@ -1106,11 +1117,11 @@ def __init__( set_parallel_mode: bool = False, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ) -> None: super().__init__() @@ -1129,11 +1140,7 @@ def __init__( ) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag + # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) @@ -1158,6 +1165,20 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.size_per_partition = divide(ffn_hidden_size, self.tp_size) + self.ub_overlap_ag = ub_overlap_ag and self.sequence_parallel + self.ub_overlap_rs = ub_overlap_rs and self.sequence_parallel + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad and self.sequence_parallel + self.ub_bulk_wgrad = ( + ub_bulk_wgrad + 
and self.sequence_parallel + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad + and self.sequence_parallel + and not self.ub_overlap_rs_dgrad + ) + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1385,11 +1406,11 @@ def forward( self.zero_centered_gamma, self.activation, self.normalization, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_rs, self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.gemm_gelu_fusion, self.fsdp_group, self, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 96de3861b8..cc62d2c655 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -4,11 +4,15 @@ """Linear API""" from typing import Callable, Dict, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch import transformer_engine_torch as tex +from transformer_engine.common.recipe import BlockScaling + from .base import ( get_workspace, get_ub, @@ -83,8 +87,12 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_dgrad: bool, + ub_bulk_dgrad: bool, + ub_bulk_wgrad: bool, ub_name: str, fp8_output: bool, fsdp_group: Union[dist_group_type, None], @@ -94,25 +102,31 @@ def forward( # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible - _, in_features = weight.shape + out_features, in_features = weight.shape inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs - backward_needs_input = is_grad_enabled and weight.requires_grad # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication inputmat = inp inputmat_total = None - with_input_all_gather = parallel_mode == "column" and sequence_parallel + with_input_all_gather_nccl = ( + parallel_mode == "column" + and sequence_parallel + and not ub_overlap_ag_fprop + ) own_quantized_input = False if fp8: + if (any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) + and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)): + raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling") + if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if with_input_all_gather: + if with_input_all_gather_nccl: assert not isinstance( inputmat, QuantizedTensor ), "All gather of fp8 input is not supported" @@ -134,7 +148,7 @@ def forward( inputmat_total = inputmat else: inputmat = cast_if_needed(inp, activation_dtype) - if with_input_all_gather: + if with_input_all_gather_nccl: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat @@ -183,35 +197,35 @@ def forward( if weight_quantizer is not None: weight_quantizer.calibrate(weight) - if ub_overlap_rs: - # I think this should be inside the gemm call rather than linear - ub_obj_projout = get_ub(ub_name + "_fprop") - ub_buffer = ub_obj_projout.get_ubuf_output(1) - if ub_obj_projout.is_p2p_overlap(): - if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - 
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS - if fp8 and ub_obj_projout.is_fp8_ubuf(): - assert fp8_output - ub_obj_projout.set_ubuf_scale_inv(torch.reciprocal(output_quantizer.scale)) - - out, _, _ = general_gemm( + ub_obj = None + ub_type = None + rs_out = None + out_dtype = activation_dtype + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG + if fp8: + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." + ub_obj.copy_into_buffer(inputmat_total, input_quantizer, True) + inputmat_total = ub_obj.get_buffer(input_quantizer, False) + + out, _, _, rs_out = general_gemm( weightmat, inputmat_total, get_workspace(), quantization_params=output_quantizer, - out_dtype=activation_dtype, + out_dtype=out_dtype, bias=bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - ub_buffer=ub_buffer if ub_overlap_rs else None, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, ) if is_grad_enabled: @@ -263,7 +277,10 @@ def forward( ctx.inp_shape = inp_shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group - ctx.ub_overlap_ag = ub_overlap_ag + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad @@ -278,12 +295,15 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if not ub_overlap_rs: - if parallel_mode == "row" and sequence_parallel: + if ub_overlap_rs_fprop: + out = rs_out + elif parallel_mode == "row": + if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: + elif tensor_parallel: out, _ = allreduce(out, tp_group) + out = out.view(-1, *inp_shape[1:-1], out_features) return out @staticmethod @@ -291,6 +311,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_Linear_backward"): + if (ctx.fp8 + and any([ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad]) + and isinstance(FP8GlobalStateManager.get_fp8_recipe(), BlockScaling)): + recipe = FP8GlobalStateManager.get_fp8_recipe() + print(f"FP8 Recipe: {type(recipe)} -> {recipe}") + raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling") + saved_tensors = ctx.saved_tensors inputmat, weight_fp8, weight, bias = ( restore_from_saved( # pylint: disable=unbalanced-tuple-unpacking @@ -319,17 +349,46 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_fp8, ) - tp_world_size = get_distributed_world_size(ctx.tp_group) - ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag - ub_algo = None + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, 
+            rs_out = None
+            dgrad_bulk = None
             if ctx.ub_overlap_ag:
-                dim_size = list(grad_output.size())
-                dim_size[0] = dim_size[0] * tp_world_size
+                # Overlap grad_output all-gather with dgrad compute
                 ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
-                if ctx.ub_obj_gradout.is_atomic_gemm():
-                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
-                else:
-                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                ub_obj_dgrad = ctx.ub_obj_gradout
+                ub_type_dgrad = tex.CommOverlapType.AG
+
+            elif ctx.ub_overlap_rs_dgrad:
+                # Overlap dgrad reduce-scatter with dgrad compute
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ub_obj_dgrad = ctx.ub_obj_gradout
+                ub_type_dgrad = tex.CommOverlapType.RS
+                rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
+                                     device=grad_output.device)
+
+            else:
+                if ctx.ub_bulk_dgrad:
+                    # Overlap inputmat all-gather with dgrad compute
+                    # NOTE: Copying into communication buffer will always prefer rowwise data,
+                    # and will copy columnwise data if rowwise does not exist. In that case,
+                    # the all-gather will apply to the leading dimension of the transpose,
+                    # which then needs to be interleaved correctly before WGRAD.
+                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                    ub_obj_dgrad = ctx.ub_obj_gradout
+                    ub_type_dgrad = tex.CommOverlapType.AG
+                    ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, True)
+
+                if ctx.ub_bulk_wgrad:
+                    # Overlap dgrad reduce-scatter with wgrad compute
+                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                    ub_type_wgrad = tex.CommOverlapType.RS
+                    ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
+                    dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, False)

             # Prepare grad output tensor
             # Note: Cast to expected dtype and perform tensor-parallel communication
@@ -352,7 +411,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             # Note: Perform tensor-parallel communication if needed
             inputmat_total = None
             inputmat_total_work = None
-            if ctx.requires_wgrad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
+            if (ctx.requires_wgrad
+                and ctx.parallel_mode == "column"
+                and ctx.sequence_parallel
+                and not ctx.ub_bulk_dgrad):
                 quantizer = None
                 if ctx.fp8:
                     quantizer = ctx.input_quantizer
@@ -384,21 +446,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

             # dgrad GEMM
-            dgrad, _, _ = general_gemm(
+            dgrad, _, _, rs_out = general_gemm(
                 weight_fp8,
                 grad_output,
                 get_workspace(),
                 layout="NN",
                 grad=True,
                 quantization_params=ctx.grad_input_quantizer,
+                out=dgrad_bulk,
                 out_dtype=ctx.activation_dtype,
                 use_split_accumulator=_2X_ACC_DGRAD,
-                ub_algo=ub_algo if ctx.ub_overlap_ag else None,
-                ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
+                ub=ub_obj_dgrad,
+                ub_type=ub_type_dgrad,
+                extra_output=rs_out,
+                bulk_overlap=ctx.ub_bulk_dgrad,
             )

             # Launch tensor-parallel communication
-            if ctx.parallel_mode == "column":
+            if ctx.ub_overlap_rs_dgrad:
+                dgrad = rs_out
+            elif ctx.parallel_mode == "column":
                 if ctx.sequence_parallel:
                     dgrad, dgrad_work = reduce_scatter_along_first_dim(
                         dgrad,
@@ -411,28 +478,38 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             # Compute grad weight tensor
             wgrad = None
             if ctx.requires_wgrad:
-
-                # Synchronize tensor-parallel communication
-                if inputmat_total_work is not None:
-                    inputmat_total_work.wait()
-                    inputmat_total_work = None
-
-                if ctx.fp8:
-                    # TODO: deal with this
-                    if ctx.ub_overlap_ag:
-                        raise NotImplementedError
-                    if isinstance(grad_output_c, QuantizedTensor):
-                        grad_output_t = grad_output_c.transpose_2d()
+                if ctx.ub_bulk_dgrad:
+                    inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer, False)
+                    if ctx.fp8:
+                        # FP8 GEMM on Hopper only supports TN layout so the gathered input must have
+                        # a valid transpose.
+                        if inputmat._data is None:
+                            # All-gather executed on columnwise data and result is in rowwise data,
+                            # so we need to fix the interleaving before WGRAD.
+                            inputmat_total._fix_gathered_transpose(tp_size=ctx.tp_size,
+                                                                   from_rowwise=True)
                         else:
-                        grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
+                            # Otherwise, we would have all-gathered rowwise data and would need to
+                            # create the transpose (on Hopper).
+                            inputmat_total._create_transpose()
+
+                else:
+                    if inputmat_total_work is not None:
+                        # Synchronize tensor-parallel communication
+                        inputmat_total_work.wait()
+                        inputmat_total_work = None

                 if isinstance(grad_output, QuantizedTensor):
                     if grad_output._transpose is None:
                         grad_output._create_transpose()

+                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
+                    rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
+                                         device=grad_output.device)
+
                 # wgrad GEMM
                 # Note: Fuse with bgrad computation if needed
-                wgrad, grad_bias_, _ = general_gemm(
+                wgrad, grad_bias_, _, rs_out = general_gemm(
                     inputmat_total,
                     grad_output,
                     get_workspace(),
@@ -445,7 +522,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                     out=main_grad if ctx.fuse_wgrad_accumulation else None,
                     use_split_accumulator=_2X_ACC_WGRAD,
                     accumulate=accumulate_wgrad_into_param_main_grad,
+                    ub=ub_obj_wgrad,
+                    ub_type=ub_type_wgrad,
+                    extra_output=rs_out,
+                    bulk_overlap=ctx.ub_bulk_wgrad,
                 )
+
+                if ctx.ub_bulk_wgrad:
+                    if ub_obj_wgrad.is_fp8_ubuf():
+                        dgrad = rs_out
+                    else:
+                        dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, True)
+
                 if grad_bias is None:
                     grad_bias = grad_bias_
                 del grad_bias_
@@ -515,8 +603,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             None,  # activation_dtype
             None,  # parallel_mode
             None,  # is_grad_enabled
-            None,  # ub_overlap_rs
-            None,  # ub_overlap_ag
+            None,  # ub_overlap_rs_fprop
+            None,  # ub_overlap_ag_dgrad
+            None,  # ub_overlap_ag_fprop
+            None,  # ub_overlap_rs_dgrad
+            None,  # ub_bulk_dgrad
+            None,  # ub_bulk_wgrad
             None,  # ub_name
             None,  # fp8_output
             None,  # fsdp_group
@@ -612,8 +704,11 @@ def __init__(
         parallel_mode: Optional[str] = None,
         parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
         device: Union[torch.device, str] = "cuda",
-        ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_overlap_rs_dgrad: bool = False,
+        ub_bulk_dgrad: bool = False,
+        ub_bulk_wgrad: bool = False,
         ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -625,11 +720,6 @@ def __init__(
         self.use_bias = bias
         self.return_bias = return_bias
         self.apply_bias = bias and not return_bias
-        self.ub_overlap_rs = ub_overlap_rs
-        self.ub_overlap_ag = ub_overlap_ag
-        if ub_overlap_rs or ub_overlap_ag:
-            assert ub_name is not None, "Userbuffer name [string] is not set."
-        self.ub_name = ub_name
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name

@@ -656,6 +746,55 @@ def __init__(
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

+        # Column parallel TP overlap options
+        self.ub_overlap_ag_fprop = (
+            self.parallel_mode == "column"
+            and self.sequence_parallel
+            and ub_overlap_ag
+        )
+        self.ub_overlap_rs_dgrad = (
+            self.parallel_mode == "column"
+            and self.sequence_parallel
+            and ub_overlap_rs_dgrad
+        )
+        self.ub_bulk_dgrad = (
+            self.parallel_mode == "column"
+            and self.sequence_parallel
+            and ub_bulk_dgrad
+            and not self.ub_overlap_rs_dgrad
+        )
+        self.ub_bulk_wgrad = (
+            self.parallel_mode == "column"
+            and self.sequence_parallel
+            and ub_bulk_wgrad
+            and not self.ub_overlap_rs_dgrad
+        )
+
+        # Row parallel TP overlap options
+        self.ub_overlap_rs_fprop = (
+            self.parallel_mode == "row"
+            and self.sequence_parallel
+            and ub_overlap_rs
+        )
+        self.ub_overlap_ag_dgrad = (
+            self.parallel_mode == "row"
+            and self.sequence_parallel
+            and ub_overlap_ag
+        )
+
+        if any(
+            [
+                self.ub_overlap_rs_fprop,
+                self.ub_overlap_ag_dgrad,
+                self.ub_overlap_ag_fprop,
+                self.ub_overlap_rs_dgrad,
+                self.ub_bulk_dgrad,
+                self.ub_bulk_wgrad,
+            ]
+        ):
+            assert ub_name is not None, "Comm+GEMM overlap requires `ub_name` to be set."
+        self.ub_name = ub_name
+
         # Initialize params in FP8
         with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

@@ -893,8 +1032,12 @@ def forward(
             self.activation_dtype,
             self.parallel_mode,
             torch.is_grad_enabled(),
-            self.ub_overlap_rs,
-            self.ub_overlap_ag,
+            self.ub_overlap_rs_fprop,
+            self.ub_overlap_ag_dgrad,
+            self.ub_overlap_ag_fprop,
+            self.ub_overlap_rs_dgrad,
+            self.ub_bulk_dgrad,
+            self.ub_bulk_wgrad,
             self.ub_name,
             fp8_output,
             self.fsdp_group,
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index b90e1ad707..a35735c252 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -261,6 +261,29 @@ def _create_transpose(self):
         self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
         self._transpose_invalid = False

+    def _fix_gathered_transpose(self, tp_size=1, from_rowwise=False):
+        assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
+        if from_rowwise:
+            assert self._data is not None, "The tensor does not hold any rowwise data"
+            data = self._data
+        else:
+            assert self._transpose is not None, "The tensor does not hold any columnwise data"
+            assert not self._transpose_invalid, "The tensor's columnwise data is not valid"
+            data = self._transpose
+
+        if tp_size == 1:
+            self._transpose = data
+        else:
+            assert data.shape[0] % tp_size == 0, (
+                "Leading dimension of data is not divisible by TP size"
+            )
+            interleaved_shape = [tp_size, data.shape[0] // tp_size, *data.shape[1:]]
+            self._transpose = data.view(interleaved_shape).transpose(0, 1).contiguous()
+        self._transpose_invalid = False
+
+        if from_rowwise:
+            self._data = None
+
     def update_usage(self, rowwise_usage=True, columnwise_usage=True):
         assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor"
         if rowwise_usage:
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 7c3da9a73f..97b1361163 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -267,11 +267,11 @@ def __init__(
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
         ub_tp_comm_overlap: bool = False,
-        ub_bulk_wgrad: bool = True,
-        ub_bulk_dgrad: bool = True,
         ub_overlap_ag: bool = True,
         ub_overlap_rs: bool = True,
         ub_overlap_rs_dgrad: bool = False,
+        ub_bulk_dgrad: bool = True,
+        ub_bulk_wgrad: bool = True,
         bias: bool = True,
         activation: str = "gelu",
         normalization: str = "LayerNorm",