Commit

Fixed merge conflict
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Selvaraj Anandaraj committed Jan 18, 2024
2 parents 3e4b3d5 + bcdc562 commit 7b70947
Showing 47 changed files with 2,331 additions and 1,177 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-1.3.0dev
+1.4.0dev
53 changes: 53 additions & 0 deletions examples/pytorch/fsdp/README.md
@@ -0,0 +1,53 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Basic Example for Using PyTorch Fully Sharded Data Parallel mode with Transformer Engine

```bash
# FSDP without deferred initialization:
# Duplicate modules are initialized on each device. The load on device memory is reduced only
# after torch.distributed.fsdp.FullyShardedDataParallel shards the model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# [GPU-0] TransformerEngine Model:
# TransformerLayer(
#   (self_attention): MultiheadAttention(
#     (layernorm_qkv): LayerNormLinear()
#     (core_attention): DotProductAttention(
#       (flash_attention): FlashAttention()
#       (fused_attention): FusedAttention()
#       (unfused_attention): UnfusedDotProductAttention(
#         (scale_mask_softmax): FusedScaleMaskSoftmax()
#         (attention_dropout): Dropout(p=0.1, inplace=False)
#       )
#     )
#     (proj): Linear()
#   )
#   (layernorm_mlp): LayerNormMLP()
# )
# [GPU-0] Pre-FSDP memory use = 83.935232MiB
# [GPU-0] Post-FSDP memory use = 10.491904MiB
# [GPU-0] Iter. 1
# [GPU-0] Iter. 2
# [GPU-0] Iter. 3
# [GPU-0] Training Time: 6.647654296875s
# [GPU-0] Avg. Iter. Time: 2.2158847656250003s
# [GPU-0] Peak memory use = 3000MiB

# FSDP with deferred initialization:
# Modules are initialized with empty parameters via the `device='meta'` option. Zero load on
# device memory until torch.distributed.fsdp.FullyShardedDataParallel triggers a reset of the
# already sharded model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# ...
# [GPU-0] Pre-FSDP memory use = 0.0MiB
# [GPU-0] Post-FSDP memory use = 10.491904MiB
# ...
```

**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without FP8 support
(e.g. A100), add the `--no-fp8` option to the commands shown above.
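
For quick reference, the pattern exercised by `fsdp.py` can be reduced to the minimal sketch below
(a sketch with illustrative sizes, assuming a `torchrun` launch so that `LOCAL_RANK` is set;
see `fsdp.py` for the complete, configurable version):

```python
# Minimal sketch: deferred init on the 'meta' device, FSDP sharding, then FP8 autocast.
import os
from functools import partial

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
all_gpus = dist.new_group(backend="nccl")

# Deferred initialization: parameters live on the 'meta' device, so no GPU memory is
# allocated until FSDP materializes the already-sharded parameters.
model = te.TransformerLayer(1024, 3072, 16, fuse_qkv_params=True,
                            params_dtype=torch.bfloat16, device="meta")
model = FullyShardedDataParallel(model,
                                 process_group=all_gpus,
                                 use_orig_params=True,
                                 mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                                                reduce_dtype=torch.float32),
                                 sync_module_states=True,
                                 auto_wrap_policy=partial(transformer_auto_wrap_policy,
                                                          transformer_layer_cls={te.TransformerLayer}))

# FP8 execution with a delayed-scaling recipe; amax reductions run over the FSDP process group.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")
x = torch.rand(512, 4, 1024, dtype=torch.bfloat16, device="cuda")  # (seq, batch, hidden)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
    loss = model(x).sum()
loss.backward()
```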
195 changes: 195 additions & 0 deletions examples/pytorch/fsdp/fsdp.py
@@ -0,0 +1,195 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import argparse
from functools import partial

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

def lowercase(s):
    return str(s).lower()

def torch_dtype(d):
    typemap = {
        'fp32' : torch.float32,
        'float32' : torch.float32,
        'fp16' : torch.float16,
        'float16' : torch.float16,
        'bf16' : torch.bfloat16,
        'bfloat16' : torch.bfloat16
    }
    if lowercase(d) not in typemap.keys():
        raise TypeError
    return typemap[lowercase(d)]

te_layer_map = {
    'linear': te.Linear,
    'layernorm': te.LayerNorm,
    'rmsnorm': te.RMSNorm,
    'layernormlinear': te.LayerNormLinear,
    'layernormmlp': te.LayerNormMLP,
    'multiheadattention': te.MultiheadAttention,
    'transformerlayer': te.TransformerLayer
}
def te_layer(l):
    if lowercase(l) not in te_layer_map.keys():
        raise TypeError
    return te_layer_map[lowercase(l)]

def get_layer_args(args):
    hidden_size = args.num_heads * args.head_dim
    layer_args = (hidden_size, )
    layer_kwargs = {
        'params_dtype': args.dtype,
        'device': 'meta' if args.defer_init else 'cuda'
    }
    if args.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
        ffn_hidden_size = 3 * hidden_size if args.num_layers == 1 else hidden_size
        layer_args += (ffn_hidden_size, )
        layer_kwargs['bias'] = True
        if args.layer_type == te.LayerNormMLP:
            layer_kwargs['seq_length'] = args.seq_length
    elif args.layer_type == te.MultiheadAttention:
        layer_args += (args.num_heads, )
        layer_kwargs['fuse_qkv_params'] = True
    elif args.layer_type == te.TransformerLayer:
        layer_args += (3 * hidden_size, args.num_heads)
        layer_kwargs['fuse_qkv_params'] = True
        layer_kwargs['seq_length'] = args.seq_length
    return layer_args, layer_kwargs

def parse_fsdp_args():
    parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " +
                                     "torch.distributed.fsdp.FullyShardedDataParallel strategy.")
    parser.add_argument("-t", "--layer-type", type=te_layer, default=te.TransformerLayer,
                        choices=list(te_layer_map.values()),
                        help="TE module type used to construct the test model.")
    parser.add_argument("--no-fp8", action="store_true", default=False,
                        help="Disables the te.fp8_autocast() context.")
    parser.add_argument('-i', "--num-iters", type=int, default=3,
                        help="Number of dummy 'training' iterations.")
    parser.add_argument('-b', "--batch-size", type=int, default=32,
                        help="Input batch size.")
    parser.add_argument('-s', "--seq-length", type=int, default=1048,
                        help="Input sequence length.")
    parser.add_argument('-n', "--num-heads", type=int, default=16,
                        help="Number of attention heads.")
    parser.add_argument('-d', "--head-dim", type=int, default=128,
                        help="Dimension of each attention head (number of KV channels).")
    parser.add_argument('-l', "--num-layers", type=int, default=1,
                        help="Number of modules chained together with nn.Sequential.")
    parser.add_argument("--seed", type=int, default=1234,
                        help="PyTorch RNG seed.")
    parser.add_argument("--defer-init", action="store_true",
                        help="Defer module parameter initialization until after FSDP sharding.")
    parser.add_argument('-v', "--verbose", action="store_true", default=False,
                        help="Print out information from all GPUs instead of only the root GPU-0.")
    parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16,
                        help="Data type for input tensor and Transformer Engine module parameters.")
    return parser.parse_args()

def train(args):
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    if local_rank == 0:
        print(f"[GPU-0] WORLD_SIZE = {world_size}\n\n", end='')
    torch.manual_seed(args.seed)

    # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
    layer_args, layer_kwargs = get_layer_args(args)
    if args.num_layers > 1:
        te_layer_list = []
        for i in range(args.num_layers):
            if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                layer_kwargs['layer_number'] = i+1
            te_layer_list.append(args.layer_type(*layer_args, **layer_kwargs))
        te_model = nn.Sequential(*te_layer_list)
    else:
        # Single layer model
        te_model = args.layer_type(*layer_args, **layer_kwargs)
    if local_rank == 0:
        print(f"[GPU-0] TransformerEngine Model:\n{te_model}\n", end='')

    # Print out allocated device memory before the model parameters are sharded by FSDP
    pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Pre-FSDP memory use = {pre_mem_use}MiB\n", end='')

    # Wrap the model with FSDP
    # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
    #       controls all communication.
    all_gpus = dist.new_group(backend='nccl')
    fsdp_wrap_policy = always_wrap_policy
    if args.layer_type == te.TransformerLayer:
        # NOTE: FSDP causes illegal memory access without this special policy for Transformers
        fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
                                   transformer_layer_cls={te.TransformerLayer})
    te_model = FullyShardedDataParallel(te_model,
                                        process_group=all_gpus,
                                        use_orig_params=True,
                                        mixed_precision=MixedPrecision(
                                            param_dtype=args.dtype,
                                            reduce_dtype=torch.float32,
                                        ),
                                        sync_module_states=True,
                                        auto_wrap_policy=fsdp_wrap_policy)

    # Print out allocated device memory after the model parameters are sharded
    post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Post-FSDP memory use = {post_mem_use}MiB\n", end='')

    # Fp8 setup for TE
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

    # Start and time dummy "training" iterations
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for i in range(args.num_iters):
        # Generate a random input batch
        x = torch.rand(args.seq_length, args.batch_size,
                       args.num_heads*args.head_dim).to(dtype=args.dtype).cuda()
        # fp8_autocast needs to be given the FSDP process group for amax reductions
        with te.fp8_autocast(enabled=not args.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
        # calculate gradient and take training step outside the fp8_autocast context
        loss.backward()
        optim.step()
        del x
        if local_rank == 0:
            print(f"[GPU-0] Iter. {i+1}\n", end='')
    end.record()
    torch.cuda.synchronize()

    # Print out "training" time and peak memory use stats
    train_time = start.elapsed_time(end)/1000.
    max_memory_alloc = int(torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") * 1e-6)
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Training Time: {train_time}s\n" +
              f"[GPU-{local_rank}] Avg. Iter. Time: {train_time / args.num_iters}s\n" +
              f"[GPU-{local_rank}] Peak memory use = {max_memory_alloc}MiB\n\n", end='')


if __name__ == "__main__":
    args = parse_fsdp_args()
    train(args)
3 changes: 3 additions & 0 deletions qa/L0_jax_distributed_unittest/test.sh
@@ -4,6 +4,9 @@

set -xe

# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"

: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*

2 changes: 2 additions & 0 deletions qa/L0_jax_unittest/test.sh
@@ -14,5 +14,7 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist

# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
2 changes: 1 addition & 1 deletion tests/jax/test_custom_call_compute.py
@@ -55,7 +55,7 @@ def test_qdq(self):
        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
        scale_inv = (1 / scale).reshape(1)

-        y = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
+        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)

        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)