diff --git a/tests/lowering/reduction/test_prod.py b/tests/lowering/reduction/test_prod.py
new file mode 100644
index 000000000..5b38597d5
--- /dev/null
+++ b/tests/lowering/reduction/test_prod.py
@@ -0,0 +1,58 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+
+class ProdDimIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, dim, keep_dim):
+        return torch.prod(input, dim, keep_dim)
+
+
+@pytest.mark.parametrize(
+    "input_shape, dim, keep_dim",
+    [
+        ((1, 2, 32, 32), 3, True),
+        ((1, 1, 32, 32), 2, True),
+        ((1, 2, 32, 32), 1, True),
+        ((2, 1, 32, 32), 0, True),
+        ((2, 1, 32, 32), 0, True),
+        pytest.param(
+            (1, 1, 2, 32, 32), -1, True, marks=pytest.mark.xfail(reason="Unexpected output shape [1, 1, 2, 1] (#244)")
+        ),
+        pytest.param((1, 2, 32, 32), 3, False, marks=pytest.mark.xfail(reason="Not support keep_dim = False (#244)")),
+        pytest.param(
+            (1, 1, 1, 32, 32),
+            4,
+            True,
+            marks=pytest.mark.xfail(
+                reason='dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]" (#244)'
+            ),
+        ),
+        pytest.param(
+            (1, 1, 32, 16), -1, True, marks=pytest.mark.xfail(reason="Need to pad with 1.0 instead of 0 (#244)")
+        ),
+        pytest.param((32, 32), 1, True, marks=pytest.mark.xfail(reason="Input rank can't < 4")),
+    ],
+)
+def test_prod_dim_int(device, input_shape, dim, keep_dim):
+    m = ProdDimIntModule()
+    input = torch.rand(input_shape, dtype=torch.bfloat16) + 0.5
+    result_before = m.forward(input, dim, keep_dim)
+
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input, dim, keep_dim)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check that the graph has been rewritten
+    nodes = list(option._out_fx_graphs[0].nodes)
+    assert [node.target for node in nodes].count(ttnn.prod) == 1
+    # Check inference result
+    assert result_before.shape == result_after.shape
+    # Give a higher tolerance for prod since float multiplication is not associative
+    assert torch.allclose(result_before, result_after, rtol=0.1)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 7add48c36..0f0974627 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -109,6 +109,11 @@ def is_function_call(node) -> bool:
     ttnn.where,
 ]
 
+TTNN_REDUCTION_OPS = [
+    ttnn.mean,
+    ttnn.prod,
+]
+
 TTNN_MATRIX_MULPIPLICATION_OPS = [
     ttnn.matmul,
     ttnn.linear,
@@ -147,6 +152,7 @@ def is_tt_compute(node) -> bool:
         TTNN_POINTWISE_UNARY_OPS
         + TTNN_POINTWISE_BINARY_OPS
         + TTNN_POINTWISE_TRINARY_OPS
+        + TTNN_REDUCTION_OPS
         + TTNN_MATRIX_MULPIPLICATION_OPS
         + TTNN_TARGET_WRAPPERS
         + TTNN_DATAMOVE_OPS
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index b0bb2c8f2..87876d410 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -623,6 +623,20 @@ def rewrite_node(node):
                     input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
                 return g.call_function(ttnn.pad, args=(input, full_pad, value))
 
+            if node.target == torch.ops.aten.prod.dim_int:
+                input_shape = args[0].meta["val"].size()
+                if len(input_shape) != 4:
+                    return None
+                # TODO(#244): keepdim = False (the default) is not supported
+                if len(args) < 3 or args[2] == False:
+                    return None
+                # TODO(#244): non-tile-aligned shapes are not supported
+                if len(input_shape) < 2 or any(size % ttnn.TILE_SIZE != 0 for size in input_shape[-2:]):
+                    return None
+                # Args: input, all_dimensions=False, dim
+                new_args = (args[0], False, args[1])
+                return g.call_function(ttnn.prod, new_args, kwargs)
+
         with g.inserting_before(node):
            new_node = rewrite_node(node)
            if new_node is not None:
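
For reference, a minimal usage sketch of the lowering added above, outside the test harness. This is an illustration only: the ProdModule name is made up, the device handle is assumed to come from ttnn.open_device (the tests get `device` from a pytest fixture instead), and only a 4D, tile-aligned input with keepdim=True is expected to lower to ttnn.prod.

import torch
import torch_ttnn
import ttnn

# Assumption: open a device directly; the test suite obtains `device` from a fixture.
device = ttnn.open_device(device_id=0)


class ProdModule(torch.nn.Module):
    def forward(self, x):
        # dim=3, keepdim=True: the form handled by the aten.prod.dim_int rewrite above
        return torch.prod(x, 3, True)


option = torch_ttnn.TorchTtnnOption(device=device)
m = torch.compile(ProdModule(), backend=torch_ttnn.backend, options=option)

x = torch.rand(1, 2, 32, 32, dtype=torch.bfloat16) + 0.5
print(m(x).shape)  # torch.Size([1, 2, 32, 1])

ttnn.close_device(device)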