make TransformerLayer accept a bshd or sbhd tensor format #557

77 changes: 77 additions & 0 deletions tests/pytorch/test_numerics.py
@@ -1197,3 +1197,80 @@ def test_gpt_fp8_parameters(dtype, bs, model):
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
assert_all_equal(outputs, outputs_fp8_params)

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer_hidden_states_format(dtype, bs, model):
config = model_configs[model]

sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

# Set `torch.manual_seed` to make sure the weights are identical to the
# other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer.
torch.manual_seed(0)
block_sbhd = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
hidden_states_format="sbhd"
)
.to(dtype=dtype)
.cuda()
)

# Set `torch.manual_seed` to make sure the weights are identical to the
# other layer. Set `*dropout` values to 0 to make sure the forward pass
# is identical to the other layer.
torch.manual_seed(0)
block_bshd = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
hidden_states_format="bshd"
)
.to(dtype=dtype)
.cuda()
)

for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()):
assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical"

x_sbhd = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
)

x_bshd = x_sbhd.transpose(0,1).contiguous()

# Reset the seed so the forward pass is also identical (in case any module
# uses randomness internally).
torch.manual_seed(0)
y_sbhd = block_sbhd(x_sbhd)

# Reset the seed so the forward pass is also identical (in case any module
# uses randomness internally).
torch.manual_seed(0)
y_bshd = block_bshd(x_bshd)

assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
46 changes: 40 additions & 6 deletions transformer_engine/pytorch/attention.py
@@ -845,11 +845,32 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd") -> torch.Tensor:
"""
input tensor t is of shape [seq_length, ..., dim]
rotary positional embedding tensor `freqs` is of shape [seq_length, ..., dim]
Parameters
----------
t: torch.Tensor
input tensor on which rotary positional embedding will be applied
freqs: torch.Tensor
rotary positional embedding tensor `freqs` is of shape
`[seq_length, ..., dim]`
tensor_format: str, default = "sbhd"
format of the input tensor `t`, {`sbhd`, `bshd`}; `bshd` means `t` is of shape
`[bs, seq, ...]` and `sbhd` means `t` is of shape `[seq, bs, ...]`

"""
assert tensor_format in ("sbhd", "bshd"), (
    "Only formats `sbhd` or `bshd` are supported for input tensor `t`."
)
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert cur_seq_len <= max_seq_len
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]

rot_dim = freqs.shape[-1]
# Ideally t_pass is empty, so the rotary pos embedding is applied to the entire tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
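
As a quick illustration of the new `tensor_format` argument (not part of the diff; it assumes transformer_engine is installed and that `freqs` has the `[max_seq_len, 1, 1, dim]` shape noted in the comment above), applying the embedding in `bshd` should match applying it in `sbhd` and transposing:

import torch
from transformer_engine.pytorch.attention import apply_rotary_pos_emb

seq_len, batch, heads, dim = 16, 2, 4, 32
t_sbhd = torch.randn(seq_len, batch, heads, dim)   # [s, b, h, d]
freqs = torch.randn(seq_len, 1, 1, dim)            # rotary angles, [s, 1, 1, d]

out_sbhd = apply_rotary_pos_emb(t_sbhd, freqs, tensor_format="sbhd")
out_bshd = apply_rotary_pos_emb(
    t_sbhd.transpose(0, 1).contiguous(), freqs, tensor_format="bshd"
)

# The two formats should agree once the bshd result is transposed back.
torch.testing.assert_close(out_bshd.transpose(0, 1), out_sbhd)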
@@ -2463,6 +2484,16 @@ class MultiheadAttention(torch.nn.Module):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
qkv_format: str, default = `sbhd`
dimension format for `query_layer`, `key_layer` and `value_layer`,
{`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
`h` the number of heads, `d` head size, and `t` the total number of tokens
in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
are used when sequences in a batch are of equal length or padded to equal
length, and the `thd` format is used when sequences in a batch have
different lengths. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `_get_qkv_layout` to gain the layout information.
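
For reference, a shape-only sketch of the three formats (the sizes below are arbitrary and purely illustrative):

import torch

b, s, h, d = 2, 5, 8, 64                   # batch size, padded sequence length, heads, head size
seq_lens = [3, 5]                          # actual per-sequence lengths

q_sbhd = torch.empty(s, b, h, d)           # sequence-first, padded to equal length
q_bshd = torch.empty(b, s, h, d)           # batch-first, padded to equal length
q_thd = torch.empty(sum(seq_lens), h, d)   # packed: t = sum(s_i) = 8, no padding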

Parallelism parameters
----------------------
@@ -2540,9 +2571,11 @@ def __init__(
bias: bool = True,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd",
) -> None:
super().__init__()

self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type
self.layer_number = layer_number
self.input_layernorm = input_layernorm
@@ -2678,6 +2711,7 @@ def __init__(
kv_channels,
num_gqa_groups=self.num_gqa_groups,
attention_dropout=attention_dropout,
qkv_format=self.qkv_format,
tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker,
sequence_parallel=sequence_parallel,
@@ -3038,14 +3072,14 @@ def forward(
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format)

context_layer = self.core_attention(
query_layer,
key_layer,
value_layer,
qkv_format='sbhd',
qkv_format=self.qkv_format,
cu_seqlens_q=None,
cu_seqlens_kv=None,
attention_mask=attention_mask,
13 changes: 13 additions & 0 deletions transformer_engine/pytorch/transformer.py
@@ -158,6 +158,15 @@ class TransformerLayer(torch.nn.Module):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
hidden_states_format: str, default = 'sbhd'
This controls whether the dimension ordering of the
intermediate hidden states is 'batch first' ('bshd') or
'sequence first' ('sbhd'). `s` stands for the sequence
length, `b` the batch size, `h` the number of heads, and
`d` the head size. Note that these formats are closely
related to the `qkv_format` in the `MultiheadAttention`
and `DotProductAttention` modules.
Options are: 'sbhd' and 'bshd'
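
A minimal usage sketch of the new option (not part of the diff; it assumes a CUDA device, an installed transformer_engine, and that the argument keeps the `hidden_states_format` name used in this PR):

import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(
    256,                              # hidden_size
    1024,                             # ffn_hidden_size
    4,                                # num_attention_heads
    hidden_states_format="bshd",      # accept batch-first hidden states
).cuda()

x = torch.randn(2, 32, 256, device="cuda")   # [batch, seq, hidden]
y = layer(x)                                 # output keeps the bshd layout
assert y.shape == x.shape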

Parallelism parameters
----------------------
@@ -242,6 +251,7 @@ def __init__(
activation: str = 'gelu',
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
hidden_states_format: str = "sbhd",
) -> None:
super().__init__()

@@ -318,6 +328,8 @@ def __init__(

self.get_rng_state_tracker = get_rng_state_tracker

self.hidden_states_format = hidden_states_format

attention_args = (
hidden_size,
num_attention_heads,
@@ -347,6 +359,7 @@ def __init__(
"ub_split_rs" : ub_split_rs,
"ub_atomic_gemm_rs" : ub_atomic_gemm_rs,
"ub_atomic_gemm_ag" : ub_atomic_gemm_ag,
"qkv_format" : self.hidden_states_format,
}

self.self_attention = MultiheadAttention(