Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert aten.sum.dim_IntList to ttnn.sum #264

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/lowering/reduction/test_sum_dim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import assert_with_pcc


class SumDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, dim, keepdim=False):
return torch.sum(input, dim, keepdim=keepdim)


@pytest.mark.parametrize(
"input_shape, dim",
(
((1, 32, 32), (-1, -2)),
((1, 1, 768), (0, 1)),
((1, 1000), (0,)),
((1, 1024, 256), (0, 1)),
((1, 1024, 7, 7), (2, 3)),
((1, 12, 16), (1,)),
((1, 12, 16), (2,)),
((1, 512), (1,)),
((1, 64), (0,)),
((1024, 160), (0,)),
((1024, 640), (0,)),
((14, 2048), (0,)),
((14, 512), (0,)),
((16384, 128), (0,)),
((16384, 32), (0,)),
((197, 1024), (0,)),
((197, 3072), (0,)),
((197, 4096), (0,)),
((197, 768), (0,)),
((2, 512), (1,)),
((2, 7, 512), (0,)),
((50, 768), (0,)),
((768, 196), (0,)),
),
)
def test_sum_dim(device, input_shape, dim):
m = SumDimModule()
input = torch.empty(input_shape, dtype=torch.bfloat16).uniform_(-1, 1)
keepdim = True
result_before = m.forward(input, dim, keepdim)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input, dim, keepdim)
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
trivial = all(input_shape[n] == 1 for n in dim)
assert [node.target for node in nodes].count(ttnn.sum) >= (not trivial)
# Check inference result
assert_with_pcc(result_before, result_after)
5 changes: 3 additions & 2 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,14 @@ def is_tt_compute(node) -> bool:
ttnn.zeros_like,
ttnn.mean,
ttnn.moreh_cumsum,
ttnn.sum,
ttnn.global_avg_pool2d,
ttnn.clip,
ttnn.squeeze,
ttnn.unsqueeze,
ttnn.full,
ttnn.as_tensor,
ttnn.expand,
ttnn.moreh_cumsum,
ttnn.sum,
ttnn.typecast,
ttnn.argmax,
]
Expand Down
17 changes: 17 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,23 @@ def reshape_1d(code, args=args, kwargs=kwargs):
else:
return None

if node.target == torch.ops.aten.sum.dim_IntList:
tensor, dims, keepdim = args

if (shape := get_shape(gm, tensor)) is not None:
dims = (n if n >= 0 else len(shape) + n for n in dims)
dims = [n for n in dims if shape[n] > 1]

if len(dims) == 0:
return tensor

tensor = g.call_function(ttnn.sum, (tensor, dims))

if not keepdim:
tensor = g.call_function(ttnn.squeeze, (tensor, dims))

return tensor

if node.target == torch.ops.aten.select.int:
tensor, dim, start = args

Expand Down
Loading