Skip to content

Commit

Permalink
Broadcast minimum input
Browse files Browse the repository at this point in the history
  • Loading branch information
swimdi committed Dec 11, 2024
1 parent cce8552 commit a6822d3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
16 changes: 4 additions & 12 deletions tests/lowering/eltwise/binary/test_minimum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,10 @@ def forward(self, x, y):
"input_shapes",
(
((32, 32), (32, 32)),
pytest.param(
((64,), (32, 64)),
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
),
pytest.param(
((64, 32), (64, 1)),
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
),
pytest.param(
((64, 1), (1, 64)),
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
),
((64,), (32, 64)),
((64, 32), (64, 1)),
((64, 1), (1, 64)),
((1, 16, 59, 59), ()),
),
)
def test_minimum(device, input_shapes):
Expand Down
4 changes: 2 additions & 2 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ def rewrite_node(node):
kwargs = node.kwargs

# workaround for issue #64
if node.target == torch.ops.aten.maximum.default:
if node.target in [torch.ops.aten.maximum.default, torch.ops.aten.minimum.default]:
self_tensor = args[0]
if len(args) > 1:
other_tensor = args[1]
Expand All @@ -1179,7 +1179,7 @@ def rewrite_node(node):
if get_shape(self_tensor) is None or get_shape(other_tensor) is None:
return None
broadcasted_shape, broadcasted_tensors = broadcast_tensors(g, [self_tensor, other_tensor])
return g.call_function(torch.ops.aten.maximum.default, tuple(broadcasted_tensors))
return g.call_function(node.target, tuple(broadcasted_tensors))

with g.inserting_before(node):
new_node = rewrite_node(node)
Expand Down

0 comments on commit a6822d3

Please sign in to comment.