-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix passing arguments to
ttnn.transformer.scaled_dot_product_attention
- Loading branch information
Showing
3 changed files
with
66 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import pytest | ||
import torch | ||
import torch_ttnn | ||
import ttnn | ||
|
||
from tests.utils import assert_with_pcc | ||
|
||
|
||
class ScaledDotProductAttentionModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, *args, **kwargs): | ||
return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, is_causal", | ||
( | ||
((1, 16, 197, 64), False), | ||
((1, 12, 197, 64), False), | ||
((1, 16, 50, 64), False), | ||
((1, 8, 4096, 40), False), | ||
((1, 8, 1024, 80), False), | ||
((1, 8, 256, 160), False), | ||
((1, 8, 64, 160), False), | ||
((1, 12, 50, 64), False), | ||
((1, 16, 1370, 80), False), | ||
((1, 12, 1, 64), False), | ||
((1, 12, 4, 64), True), | ||
), | ||
) | ||
def test_sdpa(device, input_shape, is_causal): | ||
module = ScaledDotProductAttentionModule() | ||
query = torch.rand(input_shape, dtype=torch.bfloat16) | ||
key = torch.rand(input_shape, dtype=torch.bfloat16) | ||
value = torch.rand(input_shape, dtype=torch.bfloat16) | ||
result_before = module.forward(query, key, value, is_causal=is_causal) | ||
|
||
option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False) | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
module = torch.compile(module, backend=torch_ttnn.backend, options=option) | ||
result_after = module.forward(query, key, value, is_causal=is_causal) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = [node.target for node in option._out_fx_graphs[0].nodes] | ||
assert torch.ops.aten._scaled_dot_product_flash_attention.default not in nodes | ||
assert nodes.count(ttnn.transformer.scaled_dot_product_attention) == 1 | ||
assert_with_pcc(result_before, result_after) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters