Fix pipeline parallelism with FusedAttn #635

Merged 1 commit on Jan 26, 2024
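Summary (inferred from the diff and title, as the PR carries no description): this change removes the `if self.layer_number == 1:` guards around the computation of `_cu_seqlens_q`/`_cu_seqlens_kv` and the token-packing indices, so every layer recomputes its cumulative sequence lengths instead of reusing module-level globals that only layer 1 populated. That caching pattern is the likely source of the pipeline-parallelism breakage: a pipeline stage that holds only later layers never executes a layer with `layer_number == 1`, so on that rank the cached globals were never filled in.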
86 changes: 39 additions & 47 deletions transformer_engine/pytorch/attention.py
@@ -1587,32 +1587,30 @@ def forward(
                 assert (
                     max_seqlen_q == max_seqlen_kv
                 ), "Maximum sequence length for Q and KV should be the same."
-                if self.layer_number == 1:
-                    if cu_seqlens_q is None:
-                        assert (attention_mask is not None
-                            ), "Please provide attention_mask for padding!"
-                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
-                    else:
-                        _cu_seqlens_q = cu_seqlens_q
-                        _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
+                if cu_seqlens_q is None:
+                    assert (attention_mask is not None
+                        ), "Please provide attention_mask for padding!"
+                    _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
+                else:
+                    _cu_seqlens_q = cu_seqlens_q
+                    _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
                 _cu_seqlens_kv = _cu_seqlens_q
                 query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
                     _indices_q, query_layer, key_layer, value_layer
                 )
             else:
-                if self.layer_number == 1:
-                    if cu_seqlens_q is None or cu_seqlens_kv is None:
-                        assert (attention_mask is not None
-                            ), "Please provide attention_mask for padding!"
-                        _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
-                            attention_mask[0])
-                        _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
-                            attention_mask[1])
-                    else:
-                        _cu_seqlens_q = cu_seqlens_q
-                        _cu_seqlens_kv = cu_seqlens_kv
-                        _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
-                        _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
+                if cu_seqlens_q is None or cu_seqlens_kv is None:
+                    assert (attention_mask is not None
+                        ), "Please provide attention_mask for padding!"
+                    _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
+                        attention_mask[0])
+                    _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
+                        attention_mask[1])
+                else:
+                    _cu_seqlens_q = cu_seqlens_q
+                    _cu_seqlens_kv = cu_seqlens_kv
+                    _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
+                    _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
                 query_layer_packed = PackTensors.apply(_indices_q, query_layer)
                 key_layer_packed, value_layer_packed = PackTensors.apply(
                     _indices_kv, key_layer, value_layer
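For context between the two hunks: `get_cu_seqlens_and_indices` turns a padding mask into the cumulative sequence lengths (`cu_seqlens`) that varlen fused attention expects, plus the flat positions of valid tokens that `PackTensors` gathers. Below is a minimal sketch of that idea, illustrative only (not TE's implementation), assuming a `[batch, 1, 1, seqlen]` boolean mask in which `True` marks padded positions:

import torch

def cu_seqlens_and_indices_sketch(attention_mask: torch.Tensor):
    # Assumed layout: [batch, 1, 1, seqlen] boolean, True = padded position
    # (hypothetical convention; TE's helper may differ).
    mask = attention_mask.squeeze(1).squeeze(1)        # -> [batch, seqlen]
    seqlens = (~mask).sum(dim=1, dtype=torch.int32)    # valid tokens per sequence
    cu_seqlens = torch.cat([
        torch.zeros(1, dtype=torch.int32, device=mask.device),
        torch.cumsum(seqlens, dim=0, dtype=torch.int32),
    ])                                                 # [batch + 1], starting at 0
    # Flat indices of valid tokens in the padded [batch * seqlen] layout;
    # gathering rows at these indices packs Q/K/V the way PackTensors does.
    indices = torch.nonzero(~mask.flatten(), as_tuple=False).squeeze(1)
    return cu_seqlens, indices

For two sequences with 3 and 5 valid tokens this yields cu_seqlens = [0, 3, 8], the offset format the fused varlen kernels consume.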
@@ -2030,39 +2028,33 @@ def forward(
             global _cu_seqlens_q, _cu_seqlens_kv
             if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
                 # use cu_seqlens when both cu_seqlens and attention_mask are present
-                if self.layer_number == 1:
-                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
+                _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
             elif attention_mask is not None:
                 if self.attention_type == "self":
-                    if self.layer_number == 1:
-                        _cu_seqlens_q = get_cu_seqlens(attention_mask)
-                        _cu_seqlens_kv = _cu_seqlens_q
+                    _cu_seqlens_q = get_cu_seqlens(attention_mask)
+                    _cu_seqlens_kv = _cu_seqlens_q
                 else:
-                    if self.layer_number == 1:
-                        _cu_seqlens_q = get_cu_seqlens(attention_mask[0])
-                        _cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
+                    _cu_seqlens_q = get_cu_seqlens(attention_mask[0])
+                    _cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
             else:
                 raise Exception("Please provide attention_mask or cu_seqlens for padding!")
             cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
         else:
-            if self.layer_number == 1:
-                if cu_seqlens_q is None:
-                    cu_seqlens_q = torch.arange(
-                        0,
-                        (batch_size + 1) * max_seqlen_q,
-                        step=max_seqlen_q,
-                        dtype=torch.int32,
-                        device=query_layer.device)
-                if cu_seqlens_kv is None:
-                    cu_seqlens_kv = torch.arange(
-                        0,
-                        (batch_size + 1) * max_seqlen_kv,
-                        step=max_seqlen_kv,
-                        dtype=torch.int32,
-                        device=key_layer.device)
-                _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
-            else:
-                cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
+            if cu_seqlens_q is None:
+                cu_seqlens_q = torch.arange(
+                    0,
+                    (batch_size + 1) * max_seqlen_q,
+                    step=max_seqlen_q,
+                    dtype=torch.int32,
+                    device=query_layer.device)
+            if cu_seqlens_kv is None:
+                cu_seqlens_kv = torch.arange(
+                    0,
+                    (batch_size + 1) * max_seqlen_kv,
+                    step=max_seqlen_kv,
+                    dtype=torch.int32,
+                    device=key_layer.device)
+            _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
 
         qkv_dtype = TE_DType[query_layer.dtype]

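In the no-padding branch at the end of the second hunk, cu_seqlens is just a uniform ramp with a step of max_seqlen per batch element. A quick illustration with made-up sizes (batch_size=2 and max_seqlen_q=4 are hypothetical, not values from the PR):

import torch

batch_size, max_seqlen_q = 2, 4  # hypothetical example values
cu_seqlens_q = torch.arange(
    0,
    (batch_size + 1) * max_seqlen_q,
    step=max_seqlen_q,
    dtype=torch.int32,
)
print(cu_seqlens_q)  # tensor([0, 4, 8], dtype=torch.int32)

Since this tensor is cheap to build, recomputing it in every layer costs little, and it keeps each pipeline stage self-contained instead of depending on globals that only layer 1 used to fill in.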