Skip to content

Commit

Permalink
Fix arg name in numerics test
Browse files Browse the repository at this point in the history
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
  • Loading branch information
ksivaman committed Jan 19, 2024
1 parent 051db0d commit 2ceeee7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
hidden_states_format="sbhd"
attn_input_format="sbhd"
)
.to(dtype=dtype)
.cuda()
Expand All @@ -1248,7 +1248,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
hidden_states_format="bshd"
attn_input_format="bshd"
)
.to(dtype=dtype)
.cuda()
Expand Down

0 comments on commit 2ceeee7

Please sign in to comment.