
Commit f1dcf35

C++ code and TE/PyTorch general_gemm updated to support TP overlap with cppqtensor

Signed-off-by: Alp Dener <adener@nvidia.com>

CommOverlap objects can now return overlap buffers to PyTorch as QuantizedTensors

Signed-off-by: Alp Dener <adener@nvidia.com>

updated comm+GEMM overlap test for pure GEMM, both BF16 and FP8 working with QuantizedTensor

Signed-off-by: Alp Dener <adener@nvidia.com>

te.Linear and te.LayerNormMLP updated for TP overlap w/ QuantizedTensor. All overlaps work in BF16. All overlaps except bulk WGRAD work in FP8.

Signed-off-by: Alp Dener <adener@nvidia.com>

completed TP overlap QuantizedTensor updates for LayerNormLinear, but there are still issues with quantized normalization

Signed-off-by: Alp Dener <adener@nvidia.com>

all overlaps working with bf16, all but bulk WGRAD working with FP8

Signed-off-by: Alp Dener <adener@nvidia.com>

all overlaps work with Float8Tensor, except bulk wgrad in LayerNormMLP (works in other modules)

Signed-off-by: Alp Dener <adener@nvidia.com>

all overlaps working with QuantizedTensor in BF16 and FP8

Signed-off-by: Alp Dener <adener@nvidia.com>

cleaned up pytest formatting

Signed-off-by: Alp Dener <adener@nvidia.com>
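
In short, these commits let the userbuffers-based tensor-parallel communication overlap operate on QuantizedTensor/Float8Tensor buffers. Below is a minimal end-to-end sketch of the path the tests exercise, assuming a single node where the tensor-parallel group spans all GPUs; the sizes, params_dtype, and default FP8 recipe are illustrative assumptions, while the te.Linear kwargs mirror the row-parallel branch of run_layer_with_overlap.py shown later in this diff:

    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    tp_size = dist.get_world_size()
    tp_group = dist.new_group()  # tensor-parallel group spanning the world (assumption)

    seq_length, batch_size, hidden_size = 2048, 2, 4096  # illustrative sizes

    # Allocate userbuffers for the overlap (same call signature as in the test below)
    te.module.base.initialize_ub(
        [seq_length * batch_size, hidden_size],
        tp_size,
        use_fp8=True,
        dtype=torch.bfloat16,
    )

    # Row-parallel projection with AG and RS overlaps enabled
    proj = te.Linear(
        hidden_size,
        hidden_size,
        params_dtype=torch.bfloat16,
        parallel_mode="row",
        sequence_parallel=True,
        tp_group=tp_group,
        tp_size=tp_size,
        ub_overlap_ag=True,
        ub_overlap_rs=True,
        ub_name="proj",
    )

    # Row-parallel input is sharded along the hidden dimension
    x = torch.randn(
        seq_length, batch_size, hidden_size // tp_size,
        dtype=torch.bfloat16, device="cuda", requires_grad=True,
    )
    with te.fp8_autocast(enabled=True):  # requires FP8-capable GPUs
        out = proj(x)
    out.sum().backward()

    te.module.base.destroy_ub()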
denera committed Jan 28, 2025
1 parent b653134 commit f1dcf35
Showing 22 changed files with 1,556 additions and 1,441 deletions.
284 changes: 112 additions & 172 deletions tests/pytorch/distributed/run_gemm_with_overlap.py

Large diffs are not rendered by default.

96 changes: 63 additions & 33 deletions tests/pytorch/distributed/run_layer_with_overlap.py
@@ -9,6 +9,7 @@
import socket
import argparse
import warnings
import pprint

import torch
import torch.distributed as dist
@@ -39,6 +40,8 @@ def _te_layer_argtype(name):

def _get_layer_args(config, tp_group, tp_size, reference=False):
hidden_size = config.num_heads * config.head_dim
ffn_hidden_size = 4 * hidden_size
qkv_size = 3 * hidden_size
input_shape = [config.seq_length, config.batch_size, hidden_size]
args = [hidden_size]
kwargs = {
@@ -47,38 +50,41 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
"tp_group": tp_group,
"tp_size": tp_size,
"sequence_parallel": True,
"ub_overlap_ag": not reference,
"ub_overlap_rs": not reference,
}
kwargs["ub_overlap_ag"] = not reference

if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not reference
kwargs["ub_name"] = "proj"

if config.layer_type in [te.Linear, te.LayerNormLinear]:
if config.linear_parallel_mode == "row":
input_shape[-1] = ffn_hidden_size // tp_size
args = [ffn_hidden_size, hidden_size]
kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2"
elif config.linear_parallel_mode == "column":
input_shape[0] = config.seq_length // tp_size
args.append(qkv_size)
kwargs["ub_name"] = "qkv"
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["parallel_mode"] = config.linear_parallel_mode
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not reference
kwargs["ub_bulk_dgrad"] = not reference
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
kwargs["ub_name"] = "qkv"
else:
kwargs["set_parallel_mode"] = True
kwargs["ub_overlap_rs"] = not reference
if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
args.append(4 * hidden_size)
kwargs["seq_length"] = config.seq_length
if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args.append(config.num_heads)
kwargs["attention_dropout"] = 0.0
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
else:
kwargs["ub_tp_comm_overlap"] = not reference
kwargs["hidden_dropout"] = 0.0
if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
args.append(ffn_hidden_size)
kwargs["seq_length"] = config.seq_length
if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args.append(config.num_heads)
kwargs["attention_dropout"] = 0.0
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
else:
kwargs["ub_tp_comm_overlap"] = not reference
kwargs["hidden_dropout"] = 0.0
kwargs["set_parallel_mode"] = True
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference

return args, kwargs, input_shape
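
For reference, the column-parallel te.LayerNormLinear case above assembles into a call like the following. This is a partial reconstruction with illustrative values (hidden_size = num_heads * head_dim = 4096, tp_size = 8, --overlap-rs-dgrad set); kwargs from the elided top of the function are omitted:

    args = [4096, 12288]  # [hidden_size, qkv_size], qkv_size = 3 * hidden_size
    kwargs = {
        "tp_group": tp_group,              # process group from dist.new_group()
        "tp_size": 8,
        "sequence_parallel": True,
        "ub_overlap_ag": True,             # AG overlap, set for all layer types
        "parallel_mode": "column",
        "ub_name": "qkv",
        "ub_overlap_rs_dgrad": True,       # DGRAD+RS overlap requested...
        "ub_bulk_dgrad": False,            # ...so both bulk overlaps flip off
        "ub_bulk_wgrad": False,
    }
    input_shape = [2048 // 8, 2, 4096]     # sequence dimension sharded across the TP group
    layer = te.LayerNormLinear(*args, **kwargs)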

@@ -125,6 +131,19 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
)
parser.add_argument(
"--linear-parallel-mode",
type=str.lower,
default="row",
choices=["row", "column"],
help="Parallel mode for te.Linear.",
)
parser.add_argument(
"--overlap-rs-dgrad",
action="store_true",
default=False,
help="Replace bulk DGRAD/WGRAD overlaps with DGRAD+RS in the backward pass for AG+GEMM."
)
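    # Hypothetical launch exercising both new options; the launcher and the
    # --layer-type flag spelling are assumptions, not part of this diff:
    #   torchrun --nproc-per-node=8 run_layer_with_overlap.py \
    #       --layer-type LayerNormLinear --linear-parallel-mode column --overlap-rs-dgrad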
parser.add_argument(
"--debug",
action="store_true",
@@ -154,7 +173,7 @@ def _compare_tensors(name, test, ref, rtol, atol):
)
return 1, numerics_info

diff = torch.abs(test - ref).flatten()
diff = torch.abs(test.flatten() - ref.flatten())
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5)
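
The flatten fix above lets tensors that differ in shape but match in numel be compared elementwise. In sketch form, the check amounts to the following; the failure criterion is an assumption based on the usual rtol/atol pattern, since those lines are elided from this hunk:

    diff = torch.abs(test.flatten() - ref.flatten())
    m = torch.argmax(diff)
    abs_err = diff[m].item()
    rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5)
    # Assumed criterion: fail only when both tolerances are exceeded
    failed = rel_err > rtol and abs_err > atol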
@@ -230,19 +249,30 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

# Initialize userbuffers
ub_cfgs = None
if opts.overlap_rs_dgrad:
ub_cfgs = {
"qkv_dgrad" : { "method" : "ring_exchange" },
"fc1_dgrad" : { "method" : "ring_exchange" },
}
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
WORLD_SIZE,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs,
)

# Initialize the Transformer Engine layer with overlap
args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE)
with te.fp8_model_init(enabled=opts.fp8_init):
test_model = opts.layer_type(*args, **kwargs)
dist_print("Initialized test model...", debug=True)
if WORLD_RANK == 0:
pprint.pprint(kwargs)
sys.stdout.write("\n")
dist.barrier()

# Initialize the reference model and copy all parameters
ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True)
@@ -277,8 +307,8 @@ def run_fwd_bwd(model, x):
out, *_ = y
else:
out = y
loss = out.sum()
loss.backward()
loss = out.sum()
loss.backward()
return out

torch_rng_state = torch.get_rng_state()
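
The dedent above moves the loss computation out of the else branch so it also runs when the layer returns a tuple. The resulting flow looks roughly like this sketch; the tuple check and any surrounding context managers are assumptions reconstructed from the visible lines:

    def run_fwd_bwd(model, x):
        y = model(x)
        if isinstance(y, tuple):
            out, *_ = y  # e.g. layers that also return auxiliary outputs
        else:
            out = y
        loss = out.sum()   # now executed for both branches
        loss.backward()
        return out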
@@ -333,7 +363,7 @@ def run_fwd_bwd(model, x):
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
if bool(numerics_failed.item()):
if bool(numerics_failed.item()) and not opts.debug:
break

te.module.base.destroy_ub()
Diffs for the remaining 20 changed files are not rendered.
