Fix failing CI due to PR #557 merge (#616)
fix failing tests due to PR #557

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
sudhakarsingh27 and cyanguwa authored Jan 20, 2024
1 parent e4f506a commit bacefdb
Showing 2 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -1225,7 +1225,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
-            hidden_states_format="sbhd"
+            attn_input_format="sbhd"
         )
         .to(dtype=dtype)
         .cuda()
@@ -1248,7 +1248,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
-            hidden_states_format="bshd"
+            attn_input_format="bshd"
         )
         .to(dtype=dtype)
         .cuda()
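The test changes above track the keyword rename referenced in the commit message (PR #557): TransformerLayer now accepts attn_input_format ("sbhd" or "bshd") where it previously accepted hidden_states_format. A minimal sketch of constructing a layer with the new keyword follows; the sizes, dtype, and the te alias are illustrative placeholders rather than values taken from this commit:

import torch
import transformer_engine.pytorch as te

# Build a TransformerLayer that declares batch-first inputs via the renamed keyword.
layer = (
    te.TransformerLayer(
        hidden_size=1024,           # illustrative size
        ffn_hidden_size=4096,       # illustrative size
        num_attention_heads=16,     # illustrative count
        attn_input_format="bshd",   # renamed from hidden_states_format
    )
    .to(dtype=torch.bfloat16)
    .cuda()
)

# For "bshd" the hidden states are assumed to be laid out as [batch, seq, hidden];
# with the default "sbhd" the same layer would expect [seq, batch, hidden].
x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
out = layer(x)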
12 changes: 9 additions & 3 deletions transformer_engine/pytorch/attention.py
@@ -1034,7 +1034,11 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd") -> torch.Tensor:
+def apply_rotary_pos_emb(
+    t: torch.Tensor,
+    freqs: torch.Tensor,
+    tensor_format: str = "sbhd"
+) -> torch.Tensor:
     """
     Parameters
     ----------
@@ -1056,8 +1060,10 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: st

     # Only apply the rotary embeddings up to the sequence length of the running
     # input.
-    assert cur_seq_len <= max_seq_len, (f"Rotary Embeddings only supported "
-                                        "upto {max_seq_len} sequence length!")
+    if cur_seq_len > max_seq_len:
+        raise Exception(f"Rotary Embeddings only supported upto {max_seq_len} "
+                        "sequence length!")

     freqs = freqs[:cur_seq_len].to(t.dtype)
     if tensor_format == "bshd":
         freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
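Two things change in apply_rotary_pos_emb: the signature is wrapped across multiple lines, and the out-of-range check becomes an explicit raise instead of an assert. As a side effect the error message now interpolates correctly, because the f-prefix sits on the string literal that actually contains {max_seq_len}; in the old assert that placeholder lived in the plain second string and was printed literally. A standalone sketch of the length check follows; check_rope_length is a hypothetical helper written only for illustration, and the shape conventions in the comments are assumptions based on the hunk above:

import torch

def check_rope_length(t: torch.Tensor, freqs: torch.Tensor,
                      tensor_format: str = "sbhd") -> None:
    # Sequence length sits in dim 0 for "sbhd" inputs and dim 1 for "bshd" inputs.
    cur_seq_len = t.shape[0] if tensor_format == "sbhd" else t.shape[1]
    # freqs is assumed to be laid out as [max_seq_len, 1, 1, dim].
    max_seq_len = freqs.shape[0]
    if cur_seq_len > max_seq_len:
        raise Exception(f"Rotary Embeddings only supported upto {max_seq_len} "
                        "sequence length!")

freqs = torch.randn(2048, 1, 1, 64)    # frequencies precomputed for 2048 positions
t_ok = torch.randn(1024, 2, 16, 64)    # [seq, batch, head, dim] input, within range
check_rope_length(t_ok, freqs)         # passes

t_long = torch.randn(4096, 2, 16, 64)  # longer than the precomputed 2048 positions
# check_rope_length(t_long, freqs)     # would raise Exception with the interpolated limit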
