Broadcast maximum input, remove aten_maximum_default_blocklist
swimdi committed Dec 11, 2024
1 parent 95332f8 commit cce8552
Showing 4 changed files with 71 additions and 21 deletions.
16 changes: 4 additions & 12 deletions tests/lowering/eltwise/binary/test_maximum.py
@@ -16,18 +16,10 @@ def forward(self, x, y):
     "input_shapes",
     (
         ((32, 32), (32, 32)),
-        pytest.param(
-            ((64,), (32, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 32), (64, 1)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
-        pytest.param(
-            ((64, 1), (1, 64)),
-            marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
-        ),
+        ((64,), (32, 64)),
+        ((64, 32), (64, 1)),
+        ((64, 1), (1, 64)),
+        ((1, 16, 59, 59), ()),
     ),
 )
 def test_maximum(device, input_shapes):
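Note: the previously xfailed cases all follow standard NumPy/PyTorch broadcasting rules. As a quick sanity check (illustrative, not part of the commit), the expected broadcast shapes can be verified with np.broadcast_shapes, the same helper the new lowering code uses:

    import numpy as np

    # Each pair of input shapes broadcasts to the shape ttnn.maximum will see.
    assert np.broadcast_shapes((64,), (32, 64)) == (32, 64)
    assert np.broadcast_shapes((64, 32), (64, 1)) == (64, 32)
    assert np.broadcast_shapes((64, 1), (1, 64)) == (64, 64)
    assert np.broadcast_shapes((1, 16, 59, 59), ()) == (1, 16, 59, 59)
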
3 changes: 2 additions & 1 deletion torch_ttnn/passes/lowering/to_tt_guard.py
@@ -182,7 +182,7 @@
 # ttnn.from_torch not support scalar
 # RuntimeError: TT_FATAL @ tensor/types.cpp:209: normalized_index >= 0 and normalized_index < rank
 # not lowering ttnn.maximum to avoid ttnn.from_torch of scalar
-aten_maximum_default_blocklist += [["Tensor<[1, 16, 59, 59]> self = ?", "Tensor other = ?"]]
+aten_maximum_default_blocklist = [["Tensor<[1, 16, 59, 59]> self = ?", "Tensor other = ?"]]

 # torch._dynamo.exc.BackendCompilerFailed: backend='ttnn_backend' raised:
 # RuntimeError: aten::clone() Expected a value of type 'Tensor' for argument 'self' but instead found type 'SymInt'.
@@ -318,6 +318,7 @@
 GUARD[torch.ops.aten.gt.Scalar] = partial(guard_aten, aten_gt_Scalar_blocklist)
 GUARD[torch.ops.aten.unsqueeze.default] = partial(guard_aten, aten_unsqueeze_default_blocklist)
 GUARD[torch.ops.aten.cumsum.default] = partial(guard_aten, aten_cumsum_default_blocklist)
+GUARD[torch.ops.aten.maximum.default] = partial(guard_aten, aten_maximum_default_blocklist)


 def can_lowering_to_ttnn(node):
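Note: GUARD maps an aten op to a predicate built from its blocklist. A hedged sketch of how such a lookup might gate lowering — the names other than GUARD and guard_aten are illustrative, and the actual matching logic lives in guard_aten elsewhere in this file:

    def can_lower_sketch(node):
        # No guard registered for this op: lowering is allowed.
        guard = GUARD.get(node.target)
        if guard is None:
            return True
        # guard is partial(guard_aten, blocklist); presumably it returns
        # False when the node's input variation matches a blocklist entry.
        return guard(node)
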
6 changes: 0 additions & 6 deletions torch_ttnn/passes/lowering/to_tt_guard_autogen.py
@@ -56,11 +56,6 @@
     ["Tensor<[16, 1, 1]> self = ?", "Optional[number] min = ?", "Optional[number] max = 4.605170185988092"],
     ["Tensor<[32, 1, 1]> self = ?", "Optional[number] min = ?", "Optional[number] max = 4.605170185988092"],
 ]
-aten_maximum_default_blocklist = [
-    ["Tensor<[1, 16, 19, 19]> self = ?", "Tensor other = ?"],
-    ["Tensor<[1, 16, 59, 59]> self = ?", "Tensor<[]> other = ?"],
-    ["Tensor<[1, 16, 1, 60]> self = ?", "Tensor<[]> other = ?"],
-]
 aten__log_softmax_default_blocklist = [["Tensor<[19, 256008]> self = ?", "int dim = 1", "bool half_to_float = False"]]
 aten_full_default_blocklist = [
     [
@@ -1498,7 +1493,6 @@ def guard_aten(blocklist, node):

 GUARD = {
     torch.ops.aten.clamp.default: partial(guard_aten, aten_clamp_default_blocklist),
-    torch.ops.aten.maximum.default: partial(guard_aten, aten_maximum_default_blocklist),
     torch.ops.aten._log_softmax.default: partial(guard_aten, aten__log_softmax_default_blocklist),
     torch.ops.aten.full.default: partial(guard_aten, aten_full_default_blocklist),
     torch.ops.aten._scaled_dot_product_flash_attention.default: partial(
67 changes: 65 additions & 2 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1,6 +1,7 @@
 import torch
 import ttnn
 import math
+import numpy as np
 from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch_ttnn.utils import (
     GraphCleanup,
@@ -17,6 +18,7 @@

 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 import torch.fx.traceback as fx_traceback
+from torch._subclasses.fake_tensor import FakeTensorMode
 from . import target_wrappers
 from .to_tt_guard import can_lowering_to_ttnn

@@ -444,13 +446,19 @@ def __init__(self, node):
         self.g = node.graph
         self.node = node

-    def call_function(self, target, args=(), kwargs={}):
+    def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None):
         new_node = self.g.call_function(target, args, kwargs)
-        new_node.meta = self.node.meta
+        new_node.meta = self.node.meta.copy()
         if hasattr(self.node.target, "_schema"):
             new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
         if target == ttnn.layer_norm:
             new_node.meta["val"] = new_node.meta["val"][0]
+        if new_shape is not None or new_dtype is not None:
+            shape = new_shape if new_shape is not None else new_node.meta["val"].size()
+            dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
+            fake_mode = FakeTensorMode()
+            fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
+            new_node.meta["val"] = fake_tensor
         return new_node

     def inserting_before(self, node):
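Note: FakeTensorMode builds metadata-only tensors, which lets the pass stamp a correct shape and dtype onto the synthesized node's "val" without allocating real data. A small standalone illustration (not from the commit):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    fake_mode = FakeTensorMode()
    # The fake tensor carries size/dtype metadata only, no backing storage.
    fake = fake_mode.from_tensor(torch.zeros((64, 64), dtype=torch.bfloat16))
    assert fake.size() == torch.Size([64, 64])
    assert fake.dtype == torch.bfloat16
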
@@ -612,6 +620,8 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
     if not (hasattr(node, "meta") and "val" in node.meta and hasattr(node.meta["val"], "size")):
         return None
     input_tensor_shape = args[0].meta["val"].size()
+    if input_tensor_shape == torch.Size([]):
+        input_tensor_shape = torch.Size([1])
     output_shape = node.meta["val"].size()
     if input_tensor_shape.numel() == output_shape.numel():
         if input_tensor_shape != output_shape:
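Note: the special case exists because a 0-d tensor (torch.Size([])) still holds one element, so treating it as shape [1] preserves numel while giving downstream code a rank to work with. Illustration (not from the commit):

    import torch

    scalar = torch.tensor(3.0)
    # A rank-0 tensor has an empty size but exactly one element ...
    assert scalar.size() == torch.Size([]) and scalar.numel() == 1
    # ... so viewing it as shape [1] is loss-free.
    assert scalar.reshape(1).size() == torch.Size([1])
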
@@ -1131,12 +1141,65 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
     return gm


+def broadcast_tensors(g, tensors):
+    tensors_shapes = [get_shape(tensors[i]) for i in range(len(tensors))]
+    broadcasted_shape = torch.Size(np.broadcast_shapes(*tensors_shapes))
+    broadcasted_tensors = []
+    for i in range(len(tensors)):
+        if tensors_shapes[i] == broadcasted_shape:
+            broadcasted_tensors.append(tensors[i])
+        else:
+            broadcasted_tensors.append(
+                g.call_function(
+                    torch.ops.aten.expand.default,
+                    (tensors[i], broadcasted_shape),
+                    new_shape=broadcasted_shape,
+                    new_dtype=tensors[i].meta["val"].dtype,
+                )
+            )
+    return broadcasted_shape, broadcasted_tensors
+
+
+def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    nodes = list(gm.graph.nodes)
+    for node in nodes:
+        g = GraphWrapper(node)
+
+        def rewrite_node(node):
+            args = node.args
+            kwargs = node.kwargs
+
+            # workaround for issue #64
+            if node.target == torch.ops.aten.maximum.default:
+                self_tensor = args[0]
+                if len(args) > 1:
+                    other_tensor = args[1]
+                else:
+                    other_tensor = kwargs["other"]
+                if get_shape(self_tensor) is None or get_shape(other_tensor) is None:
+                    return None
+                broadcasted_shape, broadcasted_tensors = broadcast_tensors(g, [self_tensor, other_tensor])
+                return g.call_function(torch.ops.aten.maximum.default, tuple(broadcasted_tensors))
+
+        with g.inserting_before(node):
+            new_node = rewrite_node(node)
+            if new_node is not None:
+                node.replace_all_uses_with(
+                    new_node,
+                    delete_user_cb=lambda node: node != new_node,
+                )
+
+    gm = GraphCleanup(gm)
+    return gm
+
+
 class ToTtPass(PassBase):
     def __init__(self, device, use_less_ttnn_op_types):
         self.device = device
         self.use_less_ttnn_op_types = use_less_ttnn_op_types

     def call(self, gm: torch.fx.GraphModule):
+        gm = DigestAtenOps(gm)
         # Replace more patterns with torch.fx.Transformer
         gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform()
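Note: the net effect of DigestAtenOps on a mismatched-shape maximum is to materialize the broadcast with aten.expand, so ttnn.maximum only ever sees equal-shaped operands. The equivalence is easy to check in eager mode (illustrative, not part of the commit):

    import torch

    a, b = torch.rand(64, 1), torch.rand(1, 64)
    # What the graph computed before the rewrite ...
    ref = torch.maximum(a, b)
    # ... and what it computes afterwards: expand both inputs to the
    # broadcast shape, then take an elementwise maximum.
    out = torch.maximum(a.expand(64, 64), b.expand(64, 64))
    assert torch.equal(ref, out)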
