Fix passing arguments to ttnn.transformer.scaled_dot_product_attention
jdh8 committed Dec 13, 2024
1 parent 2f08f39 commit ea1841f
Showing 3 changed files with 66 additions and 30 deletions.
50 changes: 50 additions & 0 deletions tests/lowering/misc/test_scaled_dot_product_attention.py
@@ -0,0 +1,50 @@
import pytest
import torch
import torch_ttnn
import ttnn

from tests.utils import assert_with_pcc


class ScaledDotProductAttentionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs)


@pytest.mark.parametrize(
    "input_shape, is_causal",
    (
        ((1, 16, 197, 64), False),
        ((1, 12, 197, 64), False),
        ((1, 16, 50, 64), False),
        ((1, 8, 4096, 40), False),
        ((1, 8, 1024, 80), False),
        ((1, 8, 256, 160), False),
        ((1, 8, 64, 160), False),
        ((1, 12, 50, 64), False),
        ((1, 16, 1370, 80), False),
        ((1, 12, 1, 64), False),
        ((1, 12, 4, 64), True),
    ),
)
def test_sdpa(device, input_shape, is_causal):
    module = ScaledDotProductAttentionModule()
    query = torch.rand(input_shape, dtype=torch.bfloat16)
    key = torch.rand(input_shape, dtype=torch.bfloat16)
    value = torch.rand(input_shape, dtype=torch.bfloat16)
    result_before = module.forward(query, key, value, is_causal=is_causal)

    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    module = torch.compile(module, backend=torch_ttnn.backend, options=option)
    result_after = module.forward(query, key, value, is_causal=is_causal)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
    assert torch.ops.aten._scaled_dot_product_flash_attention.default not in nodes
    assert nodes.count(ttnn.transformer.scaled_dot_product_attention) == 1
    assert_with_pcc(result_before, result_after)
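
For reference, assert_with_pcc from tests.utils compares the eager and compiled outputs. A minimal sketch of such a check, assuming it computes the Pearson correlation coefficient (PCC) of the flattened tensors and asserts it exceeds a threshold near 0.99 (the real helper may differ in details):

import torch

def assert_with_pcc_sketch(expected, actual, threshold=0.99):
    # Illustrative stand-in for tests.utils.assert_with_pcc (assumed behavior).
    a = expected.flatten().to(torch.float32)
    b = actual.flatten().to(torch.float32)
    pcc = torch.corrcoef(torch.stack([a, b]))[0, 1].item()
    assert pcc >= threshold, f"PCC {pcc:.6f} is below {threshold}"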
1 change: 1 addition & 0 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -191,6 +191,7 @@ def is_tt_compute(node) -> bool:
ttnn.moreh_cumsum,
ttnn.sum,
ttnn.typecast,
ttnn.transformer.scaled_dot_product_attention,
]
)

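Why this one-line registration matters: is_tt_compute tells the data-move pass which nodes run on device, so their inputs and outputs get host/device conversions inserted around them. A simplified, illustrative sketch of that idea (not the actual torch_ttnn pass; insert_data_moves and the exact conversion arguments are assumptions):

import torch.fx as fx
import ttnn

TT_COMPUTE_OPS = {ttnn.transformer.scaled_dot_product_attention}

def insert_data_moves(graph: fx.Graph, device):
    # Wrap each registered ttnn op with from_torch/to_torch so host tensors
    # are moved to the device before the op and back to torch afterwards.
    for node in list(graph.nodes):
        if node.op == "call_function" and node.target in TT_COMPUTE_OPS:
            with graph.inserting_before(node):
                node.args = tuple(
                    graph.call_function(
                        ttnn.from_torch, (arg,), {"device": device, "layout": ttnn.TILE_LAYOUT}
                    )
                    if isinstance(arg, fx.Node)
                    else arg
                    for arg in node.args
                )
            with graph.inserting_after(node):
                back = graph.call_function(ttnn.to_torch, (node,))
            node.replace_all_uses_with(back)
            back.args = (node,)  # restore the op itself as to_torch's input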
45 changes: 15 additions & 30 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -389,6 +389,21 @@ def __init__(self, target, args, kwargs):
    # TODO: no ttnn op can convert _adaptive_avg_pool2d
    return self.call_function_prop_meta(target, args, kwargs)

if target == torch.ops.aten._scaled_dot_product_flash_attention.default:

    def select(dropout_p=0.0, is_causal=False):
        # TODO(jdh8): Add support for training mode
        if dropout_p > 0.0:
            return self.call_function_prop_meta(target, args, kwargs)

        return self.call_function_prop_meta(
            ttnn.transformer.scaled_dot_product_attention,
            args[:3],
            {"is_causal": is_causal},
        )

    return select(*args[3:])

return self.call_function_prop_meta(target, args, kwargs)
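
The select(*args[3:]) call is the actual fix: aten._scaled_dot_product_flash_attention receives dropout_p and is_causal as trailing positional arguments after query, key, and value, so forwarding args[3:] into a closure with matching defaults picks them up whether or not they were supplied. A standalone illustration of the forwarding pattern (placeholder tensors, simplified signature):

import torch

# Simplified positional layout of the aten op:
#   (query, key, value, dropout_p=0.0, is_causal=False, ...)
q = k = v = torch.rand(1, 12, 4, 64)
args = (q, k, v, 0.0, True)

def select(dropout_p=0.0, is_causal=False):
    # Only the trailing scalars reach the closure; q, k, v stay in args[:3].
    return dropout_p, is_causal

assert select(*args[3:]) == (0.0, True)  # trailing args forwarded positionally
assert select() == (0.0, False)          # defaults used when only q, k, v are given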


@@ -1147,36 +1162,6 @@ def lower_binary_eltwise(fn, args):

    return g.call_function(ttnn.concat, (tensors_to_concat, dim))

if node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
    query, key, value = args
    query_shape = query.meta["val"].size()
    key_shape = key.meta["val"].size()
    value_shape = value.meta["val"].size()

    attn_mask = kwargs.get("attn_mask")
    dropout_p = kwargs.get("dropout_p", 0.0)
    scale = kwargs.get("scale", 1.0 / math.sqrt(query_shape[-1]))

    if kwargs.get("is_causal", False):
        attn_mask = torch.ones(query_shape[-2], key_shape[-2], dtype=torch.bool).tril()

    key_perm = [*range(len(key_shape))]
    key_perm[-2], key_perm[-1] = key_perm[-1], key_perm[-2]
    key = g.call_function(ttnn.permute, (key, key_perm))

    attn_weight = g.call_function(ttnn.matmul, (query, key))
    attn_weight = g.call_function(ttnn.mul, (attn_weight, scale))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_weight = g.call_function(ttnn.where, (attn_mask, attn_weight, -math.inf))
        else:
            attn_weight = g.call_function(ttnn.add, (attn_weight, attn_mask))

    attn_weight = g.call_function(ttnn.softmax, (attn_weight,), {"dim": -1, "numeric_stable": True})
    attn_weight = g.call_function(ttnn.dropout, (attn_weight,), {"p": dropout_p})
    return g.call_function(ttnn.matmul, (attn_weight, value))

# PEP 8 suggests this explicit statement
return None
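
The deleted branch above decomposed the op by hand (transpose, matmul, scale, mask, softmax, dropout, matmul); the new lowering maps it to a single ttnn.transformer.scaled_dot_product_attention call instead. For reference, an eager PyTorch sketch of the formula the removed code implemented (inference only, dropout omitted):

import math
import torch

def sdpa_reference(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # softmax(scale * Q @ K^T + mask) @ V, built op by op like the removed lowering.
    scale = 1.0 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale
    if is_causal:
        attn_mask = torch.ones(query.size(-2), key.size(-2), dtype=torch.bool).tril()
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_weight = attn_weight.masked_fill(~attn_mask, float("-inf"))
        else:
            attn_weight = attn_weight + attn_mask
    return torch.softmax(attn_weight, dim=-1) @ value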

