From 7b4c81b5dfa26446720dc5aa09483f77cb91f97f Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Thu, 19 Sep 2024 07:05:40 +0000
Subject: [PATCH 1/6] Convert aten.prod.dim_int to ttnn.prod

---
 tests/lowering/reduction/test_prod.py        | 47 +++++++++++++++++++
 .../passes/lowering/add_data_move_pass.py    |  6 +++
 torch_ttnn/passes/lowering/to_tt_pass.py     |  5 ++
 3 files changed, 58 insertions(+)
 create mode 100644 tests/lowering/reduction/test_prod.py

diff --git a/tests/lowering/reduction/test_prod.py b/tests/lowering/reduction/test_prod.py
new file mode 100644
index 000000000..1232aad8e
--- /dev/null
+++ b/tests/lowering/reduction/test_prod.py
@@ -0,0 +1,47 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+
+class ProdDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, dim, keep_dim):
+        return torch.prod(input, dim, keep_dim)
+
+
+@pytest.mark.parametrize(
+    "input_shape, dim, keep_dim",
+    [
+        ((1, 2, 32, 32), -1, True),
+        # ((1, 2, 32, 32), -1, False),  # Not support keep_dim = False
+        ((1, 1, 32, 32), 2, True),
+        ((1, 2, 32, 32), 1, True),
+        ((2, 1, 32, 32), 0, True),
+        # ((2, 1, 1, 32, 32), -1, True),  # Output size cannot fit input with offset
+        # ((1, 1, 32, 16), -1, True),  # Need 1.0 padding
+        # ((1, 1, 16, 32), -1, True),  # Need to crop
+        # ((32, 32), -1, True),  # Need 4d shape
+    ],
+)
+def test_prod_dim(device, input_shape, dim, keep_dim):
+    m = ProdDimModule()
+    input = torch.rand(input_shape, dtype=torch.bfloat16) + 0.5
+    result_before = m.forward(input, dim, keep_dim)
+
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input, dim, keep_dim)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check the graph has been rewritten
+    nodes = list(option._out_fx_graphs[0].nodes)
+    # There should be no op
+    assert [node.target for node in nodes].count(ttnn.prod) == 1
+    # Check inference result
+    assert result_before.shape == result_after.shape
+    # Give higher tolerance for product as it's not associative with float
+    assert torch.allclose(result_before, result_after, rtol=0.1)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 7add48c36..0f0974627 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -109,6 +109,11 @@ def is_function_call(node) -> bool:
     ttnn.where,
 ]
 
+TTNN_REDUCTION_OPS = [
+    ttnn.mean,
+    ttnn.prod,
+]
+
 TTNN_MATRIX_MULPIPLICATION_OPS = [
     ttnn.matmul,
     ttnn.linear,
@@ -147,6 +152,7 @@ def is_tt_compute(node) -> bool:
         TTNN_POINTWISE_UNARY_OPS
         + TTNN_POINTWISE_BINARY_OPS
         + TTNN_POINTWISE_TRINARY_OPS
+        + TTNN_REDUCTION_OPS
         + TTNN_MATRIX_MULPIPLICATION_OPS
         + TTNN_TARGET_WRAPPERS
         + TTNN_DATAMOVE_OPS
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index b0bb2c8f2..5481f229b 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -289,6 +289,11 @@ def call_function(self, target, args, kwargs):
         if target == torch.ops.aten.min.default:
             return self.call_function_prop_meta(ttnn.min, args, kwargs)
 
+        if target == torch.ops.aten.prod.dim_int:
+            # Args: input, all_dimensions=false, dim
+            new_args = (args[0], False, args[1])
+            return self.call_function_prop_meta(ttnn.prod, new_args, kwargs)
+
         ############################################################
         # Data movement
         ############################################################

From 792f816dc4339b81f476f2a0401b77de9d971ff9 Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Fri, 20 Sep 2024 07:18:49 +0000
Subject: [PATCH 2/6] Fallback unsupported cases

---
 tests/lowering/reduction/test_prod.py    | 30 ++++++++++++++----------
 torch_ttnn/passes/lowering/to_tt_pass.py | 19 +++++++++++----
 2 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/tests/lowering/reduction/test_prod.py b/tests/lowering/reduction/test_prod.py
index 1232aad8e..0fef594ae 100644
--- a/tests/lowering/reduction/test_prod.py
+++ b/tests/lowering/reduction/test_prod.py
@@ -13,20 +13,26 @@ def forward(self, input, dim, keep_dim):
 
 
 @pytest.mark.parametrize(
-    "input_shape, dim, keep_dim",
+    "input_shape, dim, keep_dim, converted",
     [
-        ((1, 2, 32, 32), -1, True),
-        # ((1, 2, 32, 32), -1, False),  # Not support keep_dim = False
-        ((1, 1, 32, 32), 2, True),
-        ((1, 2, 32, 32), 1, True),
-        ((2, 1, 32, 32), 0, True),
-        # ((2, 1, 1, 32, 32), -1, True),  # Output size cannot fit input with offset
-        # ((1, 1, 32, 16), -1, True),  # Need 1.0 padding
-        # ((1, 1, 16, 32), -1, True),  # Need to crop
-        # ((32, 32), -1, True),  # Need 4d shape
+        ((1, 2, 32, 32), 3, True, True),
+        ((1, 1, 32, 32), 2, True, True),
+        ((1, 2, 32, 32), 1, True, True),
+        ((2, 1, 32, 32), 0, True, True),
+        ((2, 1, 32, 32), 0, True, True),
+        # TODO(TODO): Cannot get the device from a tensor with host storage
+        ((1, 1, 1, 32, 32), 3, True, False),
+        # TODO(TODO): Not support keep_dim = False
+        ((1, 2, 32, 32), 3, False, False),
+        # TODO(TODO): dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]
+        ((1, 1, 1, 32, 32), 4, True, False),
+        # TODO(TODO): Need to pad with 1.0 instead of 0
+        ((1, 1, 32, 16), -1, True, False),
+        # TODO(TODO): Need 4d shape
+        ((32, 32), 1, True, False),
     ],
 )
-def test_prod_dim(device, input_shape, dim, keep_dim):
+def test_prod_dim(device, input_shape, dim, keep_dim, converted):
     m = ProdDimModule()
     input = torch.rand(input_shape, dtype=torch.bfloat16) + 0.5
     result_before = m.forward(input, dim, keep_dim)
@@ -40,7 +46,7 @@ def test_prod_dim(device, input_shape, dim, keep_dim):
     # Check the graph has been rewritten
     nodes = list(option._out_fx_graphs[0].nodes)
     # There should be no op
-    assert [node.target for node in nodes].count(ttnn.prod) == 1
+    assert [node.target for node in nodes].count(ttnn.prod) == (1 if converted else 0)
     # Check inference result
     assert result_before.shape == result_after.shape
     # Give higher tolerance for product as it's not associative with float
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 5481f229b..c05c7bdd1 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -289,11 +289,6 @@ def call_function(self, target, args, kwargs):
         if target == torch.ops.aten.min.default:
             return self.call_function_prop_meta(ttnn.min, args, kwargs)
 
-        if target == torch.ops.aten.prod.dim_int:
-            # Args: input, all_dimensions=false, dim
-            new_args = (args[0], False, args[1])
-            return self.call_function_prop_meta(ttnn.prod, new_args, kwargs)
-
         ############################################################
         # Data movement
         ############################################################
@@ -628,6 +623,20 @@ def rewrite_node(node):
             input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
             return g.call_function(ttnn.pad, args=(input, full_pad, value))
 
+        if node.target == torch.ops.aten.prod.dim_int:
+            input_shape = args[0].meta["val"].size()
+            if len(input_shape) != 4:
+                return None
+            # TODO(TODO): Not support keepdim = False (default value)
+            if len(args) < 3 or args[2] == False:
+                return None
+            # TODO(TODO): Not support non-tile-aligned shape
+            if len(input_shape) < 2 or any(size % ttnn.TILE_SIZE != 0 for size in input_shape[-2:]):
+                return None
+            # Args: input, all_dimensions=false, dim
+            new_args = (args[0], False, args[1])
+            return g.call_function(ttnn.prod, new_args, kwargs)
+
     with g.inserting_before(node):
         new_node = rewrite_node(node)
         if new_node is not None:

From 9cb3f033ae2c509cb68985fc0d3cf8595154b7ab Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Fri, 20 Sep 2024 07:32:55 +0000
Subject: [PATCH 3/6] Add TODO comment

---
 tests/lowering/reduction/test_prod.py    | 12 ++++++------
 torch_ttnn/passes/lowering/to_tt_pass.py |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/lowering/reduction/test_prod.py b/tests/lowering/reduction/test_prod.py
index 0fef594ae..6599a83d9 100644
--- a/tests/lowering/reduction/test_prod.py
+++ b/tests/lowering/reduction/test_prod.py
@@ -20,15 +20,15 @@ def forward(self, input, dim, keep_dim):
         ((1, 2, 32, 32), 1, True, True),
         ((2, 1, 32, 32), 0, True, True),
         ((2, 1, 32, 32), 0, True, True),
-        # TODO(TODO): Cannot get the device from a tensor with host storage
-        ((1, 1, 1, 32, 32), 3, True, False),
-        # TODO(TODO): Not support keep_dim = False
+        # TODO(#244): Unexpected output shape [1, 1, 2, 1]
+        ((1, 1, 2, 32, 32), -1, True, False),
+        # TODO(#244): Not support keep_dim = False
         ((1, 2, 32, 32), 3, False, False),
-        # TODO(TODO): dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]
+        # TODO(#244): dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]"
         ((1, 1, 1, 32, 32), 4, True, False),
-        # TODO(TODO): Need to pad with 1.0 instead of 0
+        # TODO(#244): Need to pad with 1.0 instead of 0
         ((1, 1, 32, 16), -1, True, False),
-        # TODO(TODO): Need 4d shape
+        # TODO(#244): Input rank can't < 4
         ((32, 32), 1, True, False),
     ],
 )
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index c05c7bdd1..87876d410 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -627,10 +627,10 @@ def rewrite_node(node):
             input_shape = args[0].meta["val"].size()
             if len(input_shape) != 4:
                 return None
-            # TODO(TODO): Not support keepdim = False (default value)
+            # TODO(#244): Not support keepdim = False (default value)
             if len(args) < 3 or args[2] == False:
                 return None
-            # TODO(TODO): Not support non-tile-aligned shape
+            # TODO(#244): Not support non-tile-aligned shape
             if len(input_shape) < 2 or any(size % ttnn.TILE_SIZE != 0 for size in input_shape[-2:]):
                 return None
             # Args: input, all_dimensions=false, dim

From 1b160e72f8db63f518c1728ba4e4dd63481a3444 Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Fri, 20 Sep 2024 07:35:42 +0000
Subject: [PATCH 4/6] Fix test names

---
 tests/lowering/reduction/test_prod.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/lowering/reduction/test_prod.py b/tests/lowering/reduction/test_prod.py
index 6599a83d9..1664e1e30 100644
--- a/tests/lowering/reduction/test_prod.py
+++ b/tests/lowering/reduction/test_prod.py
@@ -4,7 +4,7 @@
 import ttnn
 
 
-class ProdDimModule(torch.nn.Module):
+class ProdDimIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -32,8 +32,8 @@ def forward(self, input, dim, keep_dim):
         ((32, 32), 1, True, False),
     ],
 )
-def test_prod_dim(device, input_shape, dim, keep_dim, converted):
-    m = ProdDimModule()
+def test_prod_dim_int(device, input_shape, dim, keep_dim, converted):
+    m = ProdDimIntModule()
     input = torch.rand(input_shape, dtype=torch.bfloat16) + 0.5
     result_before = m.forward(input, dim, keep_dim)
 

From 356ce4ca9589bf458f3daab5ffe203e5a23f0869 Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Fri, 20 Sep 2024 10:45:55 +0000
Subject: [PATCH 5/6] Fix comments

---
 tests/lowering/reduction/test_prod.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/lowering/reduction/test_prod.py b/tests/lowering/reduction/test_prod.py
index 1664e1e30..2547e1818 100644
--- a/tests/lowering/reduction/test_prod.py
+++ b/tests/lowering/reduction/test_prod.py
@@ -45,7 +45,6 @@ def test_prod_dim_int(device, input_shape, dim, keep_dim, converted):
 
     # Check the graph has been rewritten
     nodes = list(option._out_fx_graphs[0].nodes)
-    # There should be no op
     assert [node.target for node in nodes].count(ttnn.prod) == (1 if converted else 0)
     # Check inference result
     assert result_before.shape == result_after.shape

From e985ae9d7b01cd2bfe9a4834a5afde18a638d95d Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Mon, 23 Sep 2024 02:08:57 +0000
Subject: [PATCH 6/6] Use xfail

---
 tests/lowering/reduction/test_prod.py | 42 +++++++++++++++------------
 1 file changed, 24 insertions(+), 18 deletions(-)

diff --git a/tests/lowering/reduction/test_prod.py b/tests/lowering/reduction/test_prod.py
index 2547e1818..5b38597d5 100644
--- a/tests/lowering/reduction/test_prod.py
+++ b/tests/lowering/reduction/test_prod.py
@@ -13,26 +13,32 @@ def forward(self, input, dim, keep_dim):
 
 
 @pytest.mark.parametrize(
-    "input_shape, dim, keep_dim, converted",
+    "input_shape, dim, keep_dim",
     [
-        ((1, 2, 32, 32), 3, True, True),
-        ((1, 1, 32, 32), 2, True, True),
-        ((1, 2, 32, 32), 1, True, True),
-        ((2, 1, 32, 32), 0, True, True),
-        ((2, 1, 32, 32), 0, True, True),
-        # TODO(#244): Unexpected output shape [1, 1, 2, 1]
-        ((1, 1, 2, 32, 32), -1, True, False),
-        # TODO(#244): Not support keep_dim = False
-        ((1, 2, 32, 32), 3, False, False),
-        # TODO(#244): dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]"
-        ((1, 1, 1, 32, 32), 4, True, False),
-        # TODO(#244): Need to pad with 1.0 instead of 0
-        ((1, 1, 32, 16), -1, True, False),
-        # TODO(#244): Input rank can't < 4
-        ((32, 32), 1, True, False),
+        ((1, 2, 32, 32), 3, True),
+        ((1, 1, 32, 32), 2, True),
+        ((1, 2, 32, 32), 1, True),
+        ((2, 1, 32, 32), 0, True),
+        ((2, 1, 32, 32), 0, True),
+        pytest.param(
+            (1, 1, 2, 32, 32), -1, True, marks=pytest.mark.xfail(reason="Unexpected output shape [1, 1, 2, 1] (#244)")
+        ),
+        pytest.param((1, 2, 32, 32), 3, False, marks=pytest.mark.xfail(reason="Not support keep_dim = False (#244)")),
+        pytest.param(
+            (1, 1, 1, 32, 32),
+            4,
+            True,
+            marks=pytest.mark.xfail(
+                reason='dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]" (#244)'
+            ),
+        ),
+        pytest.param(
+            (1, 1, 32, 16), -1, True, marks=pytest.mark.xfail(reason="Need to pad with 1.0 instead of 0 (#244)")
+        ),
+        pytest.param((32, 32), 1, True, marks=pytest.mark.xfail(reason="Input rank can't < 4")),
     ],
 )
-def test_prod_dim_int(device, input_shape, dim, keep_dim, converted):
+def test_prod_dim_int(device, input_shape, dim, keep_dim):
     m = ProdDimIntModule()
     input = torch.rand(input_shape, dtype=torch.bfloat16) + 0.5
     result_before = m.forward(input, dim, keep_dim)
@@ -45,7 +51,7 @@ def test_prod_dim_int(device, input_shape, dim, keep_dim, converted):
 
     # Check the graph has been rewritten
     nodes = list(option._out_fx_graphs[0].nodes)
-    assert [node.target for node in nodes].count(ttnn.prod) == (1 if converted else 0)
+    assert [node.target for node in nodes].count(ttnn.prod) == 1
     # Check inference result
     assert result_before.shape == result_after.shape
     # Give higher tolerance for product as it's not associative with float
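
Usage sketch (not part of the patch series): the snippet below shows how the new lowering is expected to be exercised from user code, mirroring the test added above. It assumes an open ttnn device handle (here `device`, as provided by the test suite's pytest fixture); the module name `ProdLastDim` is illustrative only.

    import torch
    import torch_ttnn

    class ProdLastDim(torch.nn.Module):
        def forward(self, x):
            # Supported path after these patches: rank-4 input, tile-aligned
            # (multiple-of-32) last two dims, and keepdim=True, so
            # aten.prod.dim_int is lowered to ttnn.prod by to_tt_pass.
            return torch.prod(x, dim=-1, keepdim=True)

    option = torch_ttnn.TorchTtnnOption(device=device)
    m = torch.compile(ProdLastDim(), backend=torch_ttnn.backend, options=option)
    out = m(torch.rand(1, 2, 32, 32, dtype=torch.bfloat16) + 0.5)
    assert out.shape == (1, 2, 32, 1)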