Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to convert aten.prod.dim_int to ttnn.prod #245

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions tests/lowering/reduction/test_prod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
import torch_ttnn
import pytest
import ttnn


class ProdDimIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, dim, keep_dim):
return torch.prod(input, dim, keep_dim)


@pytest.mark.parametrize(
"input_shape, dim, keep_dim",
[
((1, 2, 32, 32), 3, True),
((1, 1, 32, 32), 2, True),
((1, 2, 32, 32), 1, True),
((2, 1, 32, 32), 0, True),
((2, 1, 32, 32), 0, True),
pytest.param(
(1, 1, 2, 32, 32), -1, True, marks=pytest.mark.xfail(reason="Unexpected output shape [1, 1, 2, 1] (#244)")
),
pytest.param((1, 2, 32, 32), 3, False, marks=pytest.mark.xfail(reason="Not support keep_dim = False (#244)")),
pytest.param(
(1, 1, 1, 32, 32),
4,
True,
marks=pytest.mark.xfail(
reason='dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]" (#244)'
),
),
pytest.param(
(1, 1, 32, 16), -1, True, marks=pytest.mark.xfail(reason="Need to pad with 1.0 instead of 0 (#244)")
),
pytest.param((32, 32), 1, True, marks=pytest.mark.xfail(reason="Input rank can't < 4")),
],
)
def test_prod_dim_int(device, input_shape, dim, keep_dim):
m = ProdDimIntModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) + 0.5
result_before = m.forward(input, dim, keep_dim)

option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input, dim, keep_dim)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.prod) == 1
# Check inference result
assert result_before.shape == result_after.shape
# Give higher tolerance for product as it's not associative with float
assert torch.allclose(result_before, result_after, rtol=0.1)
6 changes: 6 additions & 0 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def is_function_call(node) -> bool:
ttnn.where,
]

TTNN_REDUCTION_OPS = [
ttnn.mean,
ttnn.prod,
]

TTNN_MATRIX_MULPIPLICATION_OPS = [
ttnn.matmul,
ttnn.linear,
Expand Down Expand Up @@ -147,6 +152,7 @@ def is_tt_compute(node) -> bool:
TTNN_POINTWISE_UNARY_OPS
+ TTNN_POINTWISE_BINARY_OPS
+ TTNN_POINTWISE_TRINARY_OPS
+ TTNN_REDUCTION_OPS
+ TTNN_MATRIX_MULPIPLICATION_OPS
+ TTNN_TARGET_WRAPPERS
+ TTNN_DATAMOVE_OPS
Expand Down
14 changes: 14 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,20 @@ def rewrite_node(node):
input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
return g.call_function(ttnn.pad, args=(input, full_pad, value))

if node.target == torch.ops.aten.prod.dim_int:
input_shape = args[0].meta["val"].size()
if len(input_shape) != 4:
return None
# TODO(#244): Not support keepdim = False (default value)
if len(args) < 3 or args[2] == False:
return None
# TODO(#244): Not support non-tile-aligned shape
if len(input_shape) < 2 or any(size % ttnn.TILE_SIZE != 0 for size in input_shape[-2:]):
return None
# Args: input, all_dimensions=false, dim
new_args = (args[0], False, args[1])
return g.call_function(ttnn.prod, new_args, kwargs)

with g.inserting_before(node):
new_node = rewrite_node(node)
if new_node is not None:
Expand Down