
Commit

Merge branch 'main' into main
ptrendx authored Jan 11, 2024
2 parents ca1808c + e547f8e commit c8ad95b
Showing 16 changed files with 967 additions and 200 deletions.
224 changes: 224 additions & 0 deletions tests/paddle/parallel_tests/attention_tp.py
@@ -0,0 +1,224 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unittest for MultiHeadAttention layer in tensor parallel"""

import unittest

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops

from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks
import transformer_engine.paddle as te


class TestAttentionTp(unittest.TestCase):
    """Tests MultiHeadAttention layer with model parallel in BF16"""

    def setUp(self):
        self.set_attr()
        self.init_dist_env()
        paddle.set_default_dtype(self.global_dtype)

    def init_dist_env(self):
        """Init Paddle Fleet environment"""
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
        }
        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()
        self.tp_group = self.hcg.get_model_parallel_group()
        self.world_size = self.hcg.get_model_parallel_world_size()

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-3
        self.atol = 5e-3
        self.eps = 1e-3
        self.fp8 = False
        self.sequence_parallel = False

    def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False):
        inp, mask = inp_list
        if sequence_parallel:
            split_size = inp.shape[0] // self.world_size
            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
        else:
            input_parallel = inp
        with te.fp8_autocast(enabled=fp8_enabled):
            out = layer(input_parallel, mask)
        if sequence_parallel:
            total_out = mp_ops._c_concat(out, group=self.tp_group)
            total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0)
        else:
            total_out = out
        loss = total_out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        return loss, total_out

    def test_parallel_layer(self):
        """Tests the parallel MultiHeadAttention layer against a single-GPU reference"""
        set_random_seed(1024)
        common_args = (
            self.hidden_size,
            self.num_heads,
        )
        common_kwargs = {
            'layernorm_epsilon': self.eps,
            'attention_dropout': 0.0,
            'attn_mask_type': self.mask_type,
            'attention_type': 'self',
            "tp_group": self.tp_group,
            "input_layernorm": True,
        }

        layer_tp = te.MultiHeadAttention(*common_args,
                                         **common_kwargs,
                                         set_parallel_mode=True,
                                         sequence_parallel=self.sequence_parallel)
        layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False)

        def _get_total_weight(local_weight, tp_group, axis, interleave=False):
            total_weight = []
            partial_weight = local_weight.clone().detach()
            paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group)
            if interleave:
                # Due to the interleaved qkv layout, need to concat on num_head
                # dimension for column parallel linear in MultiHeadAttention layer
                assert axis == 0
                assert [3 * self.hidden_size // self.world_size,
                        self.hidden_size] == partial_weight.shape
                local_num_head = self.num_heads // self.world_size
                for idx, _ in enumerate(total_weight):
                    total_weight[idx] = total_weight[idx].reshape(
                        [3, local_num_head, -1, self.hidden_size])
                total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
            else:
                total_weight = paddle.concat(total_weight, axis=axis)
            return total_weight

        def _get_weight(obj, weight_names):
            for name in weight_names:
                obj = getattr(obj, name)
            return obj

        def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False):
            weight_src = _get_weight(layer_src, weight_names)
            weight_dst = _get_weight(layer_dst, weight_names)
            if partition_mode is None:
                total_weight = weight_src
            elif partition_mode == 'column':
                total_weight = _get_total_weight(weight_src,
                                                 tp_group=self.tp_group,
                                                 axis=0,
                                                 interleave=interleave)
            elif partition_mode == 'row':
                total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
            else:
                raise ValueError(f"Partition Mode {partition_mode} is not supported.")
            assert weight_dst.shape == total_weight.shape, \
                f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
            weight_dst.copy_(total_weight, True)

        copy_weight(layer_tp, layer_single, None, ['layernorm_qkv', 'ln_weight'])
        copy_weight(layer_tp, layer_single, 'column', ['layernorm_qkv', 'weight'], interleave=True)
        copy_weight(layer_tp, layer_single, 'row', ['proj', 'weight'])

        if self.sequence_parallel:
            register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)

        optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
        optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
                                                parameters=layer_single.parameters())

        layer_tp = fleet.distributed_model(layer_tp)
        optimizer_tp = fleet.distributed_optimizer(optimizer_tp)

        for _ in range(5):
            inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
                                 self.global_dtype)
            mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
                                dtype='bool')
            loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
                                                   self.sequence_parallel)
            loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
                                                           optimizer_single, self.fp8)
            assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
            assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)


class TestAttentionTpFp8(TestAttentionTp):
    """Tests MultiHeadAttention layer with model parallel in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-2
        self.atol = 5e-2
        self.eps = 1e-3
        self.fp8 = True
        self.sequence_parallel = False


class TestAttentionSp(TestAttentionTp):
    """Tests MultiHeadAttention layer with sequence parallel in BF16"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-3
        self.atol = 5e-3
        self.eps = 1e-3
        self.fp8 = False
        self.sequence_parallel = True


class TestAttentionSpFp8(TestAttentionTp):
    """Tests MultiHeadAttention layer with sequence parallel in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.hidden_size = 1024
        self.num_heads = 16
        self.q_seqlen = 128
        self.kv_seqlen = 128
        self.mask_type = 'padding'
        self.global_dtype = 'bfloat16'
        self.rtol = 5e-2
        self.atol = 1e-1
        self.eps = 1e-3
        self.fp8 = True
        self.sequence_parallel = True


if __name__ == '__main__':
    unittest.main()
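Editor's note: the least obvious step in the test above is _get_total_weight(..., interleave=True). Because the fused QKV weight of the column-parallel linear is stored with an interleaved layout (q/k/v blocks per head group on each rank), the shards gathered from the tensor-parallel ranks cannot simply be concatenated along axis 0. The following standalone numpy sketch (not part of the commit; sizes and variable names are illustrative, and the all_gather result is emulated by pre-splitting a full weight) reproduces the reshape-then-concat logic and checks it against a naive gather:

# Illustrative sketch of the interleaved-QKV weight reconstruction (assumption: numpy only).
import numpy as np

hidden_size = 8
num_heads = 4
world_size = 2
head_dim = hidden_size // num_heads
local_heads = num_heads // world_size

# Full QKV weight [3 * hidden_size, hidden_size], viewed as [3 (q/k/v), num_heads, head_dim, hidden_size].
full = np.arange(3 * hidden_size * hidden_size, dtype=np.float32).reshape(3 * hidden_size, hidden_size)
full4d = full.reshape(3, num_heads, head_dim, hidden_size)

# Each tensor-parallel rank holds the q/k/v rows of its own heads only,
# flattened to [3 * hidden_size // world_size, hidden_size]; this list is
# what paddle.distributed.all_gather would return, one entry per rank.
shards = [
    full4d[:, r * local_heads:(r + 1) * local_heads].reshape(-1, hidden_size)
    for r in range(world_size)
]

# Naive gather: concatenating along axis 0 interleaves the q/k/v blocks of
# different ranks and does NOT reproduce the single-GPU layout.
naive = np.concatenate(shards, axis=0)
assert not np.array_equal(naive, full)

# Reconstruction used by the test: expose the head axis, concatenate the
# per-rank head groups along it, then flatten back to 2-D.
restored = np.concatenate(
    [s.reshape(3, local_heads, -1, hidden_size) for s in shards],
    axis=1).reshape(-1, hidden_size)
assert np.array_equal(restored, full)

The row-parallel projection weight, by contrast, is reassembled with a plain axis=1 concatenation, which is why copy_weight only passes interleave=True for the fused QKV weight.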
93 changes: 77 additions & 16 deletions tests/paddle/parallel_tests/layernorm_linear_tp.py
@@ -30,9 +30,12 @@ def init_dist_env(self):
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
        }
        strategy.hybrid_configs["mp_configs"].need_broadcast_data = False
        fleet.init(is_collective=True, strategy=strategy)
        self.rank = fleet.worker_index()
        self.hcg = fleet.get_hybrid_communicate_group()
        self.tp_group = self.hcg.get_model_parallel_group()
        self.world_size = self.hcg.get_model_parallel_world_size()

    def set_attr(self):
        """Set test configs"""
@@ -44,6 +47,39 @@ def set_attr(self):
        self.atol = 1e-3
        self.eps = 1e-3
        self.fp8 = False
        self.sequence_parallel = False

    def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
        inp = paddle.to_tensor(inp, stop_gradient=True)
        assert split_input in ['none', 'column', 'row']
        if split_input == 'column':
            split_size = inp.shape[1] // self.world_size
            input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
        elif split_input == 'row':
            split_size = inp.shape[0] // self.world_size
            input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
        else:
            input_parallel = inp
        input_parallel.stop_gradient = False
        out = layer(input_parallel)
        if gather_output:
            total_out = mp_ops._c_concat(out, group=self.tp_group)
        else:
            total_out = out
        loss = total_out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        if split_input != 'none':
            grad_input = []
            paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
            if split_input == 'column':
                grad_input = paddle.concat(grad_input, axis=1)
            elif split_input == 'row':
                grad_input = paddle.concat(grad_input, axis=0)
        else:
            grad_input = input_parallel.grad
        return loss, grad_input

    def test_column_parallel_layer(self):
        """Tests column parallel LayerNormLinear"""
@@ -53,6 +89,7 @@ def test_column_parallel_layer(self):
            self.out_features,
            eps=self.eps,
            parallel_mode='column',
            sequence_parallel=self.sequence_parallel,
        )
        layer_pd = te.LayerNormLinear(
            self.in_features,
@@ -77,25 +114,16 @@ def test_column_parallel_layer(self):
        layer_te = fleet.distributed_model(layer_te)
        optimizer_te = fleet.distributed_optimizer(optimizer_te)

        def train_one_step(layer, inp, optimizer, gather=False):
            inp = paddle.to_tensor(inp)
            inp.stop_gradient = False
            out = layer(inp)
            if gather:
                total_out = mp_ops._c_concat(out, group=self.tp_group)
            else:
                total_out = out
            loss = total_out.mean()
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            return loss, inp.grad

        for _ in range(5):
            inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
            with te.fp8_autocast(enabled=self.fp8):
                loss_tp, grad_input = train_one_step(layer_te, inp, optimizer_te, gather=True)
                loss_ref, grad_input_ref = train_one_step(layer_pd, inp, optimizer_pd)
                loss_tp, grad_input = self._train_one_step(
                    layer_te,
                    inp,
                    optimizer_te,
                    split_input='row' if self.sequence_parallel else 'none',
                    gather_output=True)
                loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
            assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
            assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)

@@ -113,6 +141,39 @@ def set_attr(self):
        self.atol = 1e-2
        self.eps = 1e-3
        self.fp8 = True
        self.sequence_parallel = False


class TestLayerNormLinearSp(TestLayerNormLinearTp):
    """Tests LayerNormLinear layer with sequence parallelism"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = 'bfloat16'
        self.rtol = 1e-3
        self.atol = 1e-3
        self.eps = 1e-3
        self.fp8 = False
        self.sequence_parallel = True


class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
    """Tests LayerNormLinear layer with sequence parallelism in FP8"""

    def set_attr(self):
        """Set test configs"""
        self.batch_size = 16
        self.in_features = 32
        self.out_features = 64
        self.global_dtype = 'bfloat16'
        self.rtol = 1e-2
        self.atol = 1e-2
        self.eps = 1e-3
        self.fp8 = True
        self.sequence_parallel = True


if __name__ == '__main__':
(rest of this file and the remaining changed files in this commit are not shown)
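Editor's note: when sequence parallelism is enabled, each rank only computes the output for its own slice of the leading (batch/sequence) dimension, and the tests gather the partial outputs before comparing against the single-GPU reference. The following standalone numpy sketch (not part of the commit; shapes are illustrative, and mp_ops._c_concat is emulated as a plain last-axis concatenation of the per-rank shards) shows why attention_tp.py splits the gathered tensor along the last axis and re-concatenates along the first:

# Illustrative sketch of undoing the last-axis gather performed by _c_concat (assumption: numpy only).
import numpy as np

world_size = 2
rows, hidden = 8, 4                      # rows stands in for the batch*seq dimension
local_rows = rows // world_size

full = np.arange(rows * hidden, dtype=np.float32).reshape(rows, hidden)

# Each rank produces the output for its own slice of the first dimension.
shards = [full[r * local_rows:(r + 1) * local_rows] for r in range(world_size)]

# _c_concat gathers along the *last* axis, so the result has shape
# [local_rows, hidden * world_size] rather than [rows, hidden].
c_concat = np.concatenate(shards, axis=-1)

# The test recovers the full output by splitting the last axis back into one
# chunk per rank and stacking those chunks along the first axis.
recovered = np.concatenate(np.split(c_concat, world_size, axis=-1), axis=0)
assert np.array_equal(recovered, full)

In the non-sequence-parallel case the attention output is already replicated across ranks, which is why the test skips this gather entirely and compares the local output directly.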
