Add repeat op support from Forge to TTIR Lowering (#1214)
### Ticket
#1215

### Problem description
Add repeat op support from Forge to TTIR Lowering

### What's changed
- Added repeat op support from Forge to TTIR Lowering (see the usage sketch below)
- Added support for `DenseI64ArrayAttr` TargetType mapping for the repeat op
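
A minimal usage sketch of what this enables, adapted from the updated test in `forge/test/mlir/operators/tm/test_tm.py` (the `verify` import path is an assumption; it is not shown in this diff):

```python
import torch
import torch.nn as nn

import forge
from forge.verify.verify import verify  # assumed import path for the test helper


class Repeat(nn.Module):
    def __init__(self, repeats):
        super().__init__()
        self.repeats = repeats

    def forward(self, x):
        return x.repeat(*self.repeats)


# torch.Tensor.repeat on the framework side now lowers to ttir.repeat.
inputs = [torch.rand(1, 2)]
framework_model = Repeat(repeats=(10, 1))
compiled_model = forge.compile(framework_model, sample_inputs=inputs)
verify(inputs, framework_model, compiled_model)
```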
ashokkumarkannan1 authored Feb 26, 2025
1 parent 487d25d commit 33258a7
Showing 4 changed files with 51 additions and 8 deletions.
7 changes: 7 additions & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -56,6 +56,7 @@ enum class TargetType
SourceType,
UInt32,
Int64,
DenseI64ArrayAttr,
};

struct AttributeRemap
@@ -106,6 +107,7 @@ class AttributeMapper
add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));
add_op_mapping("repeat", "repeats", AttributeRemap("repeat_dimensions", TargetType::DenseI64ArrayAttr));

// Add more default mappings here
}
@@ -237,6 +239,10 @@ class MLIRGenerator
TT_ASSERT(std::get<int>(value) >= 0, "Value must be >= 0 for conversion to uint32");
return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));

case TargetType::DenseI64ArrayAttr:
return builder_.getDenseI64ArrayAttr(std::vector<int64_t>(
std::get<std::vector<int>>(value).begin(), std::get<std::vector<int>>(value).end()));
default:
// If type not handled, throw an exception
throw std::runtime_error("Unhandled target type conversion");
@@ -636,6 +642,7 @@ class MLIRGenerator
lowering_handler_map["remainder"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RemainderOp>;
lowering_handler_map["repeat_interleave"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatInterleaveOp>;
lowering_handler_map["repeat"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatOp>;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["select"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SelectOp>;
lowering_handler_map["sigmoid"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SigmoidOp>;
24 changes: 22 additions & 2 deletions forge/forge/op/eval/forge/tm.py
@@ -139,7 +139,6 @@ def eval(type, attr, ops):

if type == "repeat":
sizes = attr
assert len(t_ops[0].shape) == len(sizes)
return t_ops[0].repeat(*sizes)

if type == "repeat_interleave":
@@ -486,7 +485,15 @@ def shape(type, attr, ops):

if type == "repeat":
sizes = attr
return tuple(dim * size for dim, size in zip(list(ops[0]), sizes)), []
if len(ops[0]) < len(sizes):
# Scenario: the input has fewer dimensions than the repeat attribute
# (e.g. a 1D tensor repeated in 2D). `ttir.repeat` does not currently
# support this directly, so we expand the input shape with leading
# singleton dimensions to match the repeat attribute before computing
# the output shape.
shape = (1,) * (len(sizes) - len(ops[0])) + tuple(ops[0])
else:
shape = ops[0]
return tuple(dim * size for dim, size in zip(list(shape), sizes)), []

if type == "repeat_interleave":
assert len(attr) <= 3, "repeat_interleave should have two attributes - repeats and dim"
@@ -1379,6 +1386,19 @@ def decompose(type, attr, dc, inputs):
rank -= 1
dc.fuse(result)
return
if type == "repeat":
input_shape = inputs[0].shape.as_list()
target_shape = attr
result = inputs[0]

if len(input_shape) < len(target_shape):
# Scenario: the input has fewer dimensions than the repeats attribute
# (e.g. a 1D tensor repeated in 2D). `ttir.repeat` does not currently
# support this directly, so we first reshape the input with leading
# singleton dimensions so that its rank matches the number of repeats.
new_shape = (1,) * (len(target_shape) - len(input_shape)) + tuple(input_shape)
result = dc.op("reshape", [result], new_shape)
result = dc.op_with_named_attrs("repeat", [result], {"repeats": target_shape}, target_shape)
dc.fuse(result)


def create_row_picker_matrix(col_indices, lhs_num_cols, lhs_num_channels=None, lhs_batch_size=None):
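To make the rank promotion above concrete: when the input has fewer dimensions than the repeats attribute, the shape is padded with leading singleton dimensions before repeating, which mirrors PyTorch's own `repeat` semantics. A small sketch in plain PyTorch (illustrative only; the corresponding `(3,)` / `(4, 2)` case is still xfail'd on the lowering side in this commit):

```python
import torch

input_shape = (3,)   # rank-1 input
repeats = (4, 2)     # repeat attribute with higher rank

# Pad the input shape with leading 1s so its rank matches len(repeats),
# matching the logic added to shape() and decompose() above.
padded = (1,) * (len(repeats) - len(input_shape)) + tuple(input_shape)  # (1, 3)
expected = tuple(d * r for d, r in zip(padded, repeats))                # (4, 6)

x = torch.rand(input_shape)
out = x.reshape(padded).repeat(*repeats)  # reshape first, then repeat
assert out.shape == torch.Size(expected)
print(out.shape)  # torch.Size([4, 6])
```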
1 change: 0 additions & 1 deletion forge/forge/op/tm.py
@@ -367,7 +367,6 @@ def Repeat(name: str, operandA: Tensor, repeats: List[int]) -> Tensor:
Tensor
Forge tensor
"""
assert len(operandA.shape) == len(repeats)
return op("repeat", name, operandA, attrs=repeats, repeats=repeats).get_tensor()


27 changes: 22 additions & 5 deletions forge/test/mlir/operators/tm/test_tm.py
@@ -562,11 +562,28 @@ def forward(self, *tensors):
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out


@pytest.mark.xfail(
reason="RuntimeError: Found Unsupported operations while lowering from TTForge to TTIR in forward graph - repeat"
@pytest.mark.parametrize(
["input_shape", "repeats"],
[
pytest.param((1, 2), (10, 1)),
pytest.param((1, 99), (100, 1)),
pytest.param(
(1, 100),
(50, 2),
marks=pytest.mark.xfail(reason="info:Incompatible dimensions 200 and 100"),
),
pytest.param(
(3,),
(4, 2),
marks=pytest.mark.xfail(reason="info:Incompatible dimensions 6 and 3"),
),
pytest.param((4, 1, 4), (1, 10, 1)),
pytest.param((2, 2, 1, 2), (1, 1, 4, 1)),
pytest.param((1, 4, 1, 4, 4), (1, 1, 3, 1, 1)),
],
)
@pytest.mark.push
def test_repeat():
def test_repeat(input_shape, repeats):
class Repeat(nn.Module):
def __init__(self, repeats):
super().__init__()
@@ -575,9 +592,9 @@ def __init__(self, repeats):
def forward(self, x):
return x.repeat(*self.repeats)

inputs = [torch.rand(1, 2, 1, 4, 4)]
inputs = [torch.rand(input_shape)]

framework_model = Repeat(repeats=(1, 1, 4, 1, 1))
framework_model = Repeat(repeats=repeats)
compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model)