Uplift third_party/tt-mlir to origin/main 2025-03-02 (#1354)
This PR uplifts the third_party/tt-mlir submodule to origin/main and
fixes problems that occurred during the uplift.
mstojkovicTT authored Mar 4, 2025
1 parent 9658e75 commit 42065c6
Showing 7 changed files with 86 additions and 117 deletions.
23 changes: 17 additions & 6 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -54,9 +54,11 @@ using namespace tt;
enum class TargetType
{
SourceType,
UInt32,
Int64,
UI32Attr,
I64Attr,
I32Attr,
DenseI64ArrayAttr,
DenseI32ArrayAttr,
};

struct AttributeRemap
@@ -104,9 +106,14 @@ class AttributeMapper

void initialize_default_mappings()
{
add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
// Sort the mappings in lexicographical order
add_op_mapping("conv2d", "dilation", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d", "groups", AttributeRemap(std::nullopt, TargetType::I32Attr));
add_op_mapping("conv2d", "padding", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d", "stride", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::I64Attr));
add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));
add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UI32Attr));
add_op_mapping("repeat", "repeats", AttributeRemap("repeat_dimensions", TargetType::DenseI64ArrayAttr));

// Add more default mappings here
@@ -235,14 +242,18 @@ class MLIRGenerator
// Convert the attribute to the target type
switch (target_type)
{
case TargetType::UInt32:
case TargetType::UI32Attr:
TT_ASSERT(std::get<int>(value) >= 0, "Value must be >= 0 for conversion to uint32");
return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));
case TargetType::I32Attr: return builder_.getI32IntegerAttr(static_cast<int32_t>(std::get<int>(value)));
case TargetType::I64Attr: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));

case TargetType::DenseI64ArrayAttr:
return builder_.getDenseI64ArrayAttr(std::vector<int64_t>(
std::get<std::vector<int>>(value).begin(), std::get<std::vector<int>>(value).end()));
case TargetType::DenseI32ArrayAttr:
return builder_.getDenseI32ArrayAttr(std::vector<int32_t>(
std::get<std::vector<int>>(value).begin(), std::get<std::vector<int>>(value).end()));
default:
// If type not handled, throw an exception
throw std::runtime_error("Unhandled target type conversion");
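For orientation, the new TargetType values correspond to standard MLIR attribute kinds. Below is a minimal sketch of the same conversions through upstream MLIR's Python bindings; this is an illustration only (lower_to_mlir.cpp builds these in C++ via builder_, and the mlir import path assumes the upstream bindings are installed):

# Sketch: what the remapped conv2d attributes become as MLIR attributes.
from mlir.ir import Context, DenseI32ArrayAttr, IntegerAttr, IntegerType

with Context():
    stride = DenseI32ArrayAttr.get([1, 1])  # TargetType::DenseI32ArrayAttr
    groups = IntegerAttr.get(IntegerType.get_signless(32), 1)  # TargetType::I32Attr
    print(stride)  # array<i32: 1, 1>
    print(groups)  # 1 : i32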
3 changes: 2 additions & 1 deletion forge/csrc/runtime/runtime.cpp
@@ -21,7 +21,7 @@ static target::DataType torch_scalar_type_to_dt(torch::ScalarType st)
case torch::ScalarType::Byte: return target::DataType::UInt8;
case torch::ScalarType::Char: return target::DataType::UInt8;
case torch::ScalarType::Short: return target::DataType::UInt16;
case torch::ScalarType::Int: return target::DataType::UInt32;
case torch::ScalarType::Int: return target::DataType::Int32;
case torch::ScalarType::Long: return target::DataType::UInt32;
case torch::ScalarType::Half: return target::DataType::Float16;
case torch::ScalarType::Float: return target::DataType::Float32;
@@ -44,6 +44,7 @@ static torch::ScalarType dt_to_torch_scalar_type(target::DataType df)
case target::DataType::UInt8: return torch::ScalarType::Byte;
case target::DataType::UInt16: return torch::ScalarType::Short;
case target::DataType::UInt32: return torch::ScalarType::Int;
case target::DataType::Int32: return torch::ScalarType::Int;
case target::DataType::Float16: return torch::ScalarType::Half;
case target::DataType::Float32: return torch::ScalarType::Float;
case target::DataType::BFloat16: return torch::ScalarType::BFloat16;
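The Int mapping fix matters because torch::ScalarType::Int is a signed 32-bit type, so routing it through UInt32 reinterprets negative values. A small PyTorch illustration of the reinterpretation the old mapping risked (numpy is used only to view the raw bits):

import torch

# torch's Int dtype is signed 32-bit; viewing its bits as uint32
# turns -1 into 4294967295, hence the new Int -> Int32 mapping.
x = torch.tensor([-1], dtype=torch.int32)
print(x.numpy().view("uint32"))  # [4294967295]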
11 changes: 3 additions & 8 deletions forge/csrc/test/passes/test_erase_inverse_ops.cpp
@@ -503,14 +503,9 @@ struct UpdateConvAttrsTest : testing::Test
*graph, "conv2d", "conv2d", {3, 3, 1, 1, 0, 0, 1, 1, 1}, {input_node_0, weight_node});
auto named_attrs = conv_node->named_attrs();
named_attrs["channel_last"] = false;
named_attrs["padding_top"] = 1;
named_attrs["padding_bottom"] = 1;
named_attrs["padding_left"] = 1;
named_attrs["padding_right"] = 1;
named_attrs["stride_height"] = 1;
named_attrs["stride_width"] = 1;
named_attrs["dilation_height"] = 1;
named_attrs["dilation_width"] = 1;
named_attrs["padding"] = std::vector<int>{1, 1, 1, 1};
named_attrs["stride"] = std::vector<int>{1, 1};
named_attrs["dilation"] = std::vector<int>{1, 1};

conv_node->overwrite_named_attrs(named_attrs);
create_output(*graph, "out", conv_node);
22 changes: 6 additions & 16 deletions forge/forge/op/convolution.py
@@ -58,15 +58,10 @@ def Conv2d(
"conv2d",
name,
*inputs,
stride_height=stride[0],
stride_width=stride[1],
dilation_height=dilation[0],
dilation_width=dilation[1],
stride=stride,
dilation=dilation,
groups=groups,
padding_left=padding[0],
padding_right=padding[1],
padding_top=padding[2],
padding_bottom=padding[3],
padding=padding,
channel_last=channel_last,
).get_tensor()

@@ -138,15 +133,10 @@ def Conv2dTranspose(
name,
*inputs,
attrs=attrs,
stride_height=stride[0],
stride_width=stride[1],
dilation_height=dilation,
dilation_width=dilation,
stride=stride, # [sH, sW]
dilation=dilation, # [dH, dW]
groups=groups,
padding_top=padding[0],
padding_left=padding[1],
padding_bottom=padding[2],
padding_right=padding[3],
padding=padding, # [pT, pL, pB, pR]
channel_last=channel_last,
).get_tensor()

120 changes: 51 additions & 69 deletions forge/forge/op/eval/forge/convolution.py
@@ -22,27 +22,18 @@ class Conv2d(PyOp):
@classmethod
def create(
cls,
stride_height,
stride_width,
dilation_height,
dilation_width,
stride, # Input format: [sH, sW]
dilation, # Input format: [dH, dW]
groups,
padding_left,
padding_right,
padding_top,
padding_bottom,
padding, # Input format: [pL, pR, pT, pB]
channel_last,
):
self = cls("conv2d")
self.stride_height = stride_height
self.stride_width = stride_width
self.dilation_height = dilation_height
self.dilation_width = dilation_width
self.stride = stride
self.groups = groups
self.padding_left = padding_left
self.padding_right = padding_right
self.padding_top = padding_top
self.padding_bottom = padding_bottom
# Transform padding from [pL, pR, pT, pB] to [pT, pL, pB, pR]
self.padding = [padding[2], padding[0], padding[3], padding[1]]
self.dilation = dilation
self.channel_last = int(channel_last)
return self
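The [pL, pR, pT, pB] to [pT, pL, pB, pR] shuffle above is easy to get wrong, so here is a tiny self-check of the index arithmetic (the strings are placeholders, not real values):

# Self-check of the reorder in Conv2d.create.
padding_in = ["pL", "pR", "pT", "pB"]
padding_stored = [padding_in[2], padding_in[0], padding_in[3], padding_in[1]]
assert padding_stored == ["pT", "pL", "pB", "pR"]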

@@ -55,15 +46,10 @@ def eval(self, tensors):
weights = t_ops[1]
bias = t_ops[2] if len(t_ops) == 3 else None

stride = [self.stride_height, self.stride_width]
dilation = [self.dilation_height, self.dilation_width]
groups = self.groups
padding = [
self.padding_left,
self.padding_right,
self.padding_top,
self.padding_bottom,
]
padding = self.padding
stride = self.stride
dilation = self.dilation

channel_last = self.channel_last
if channel_last:
@@ -105,11 +91,15 @@ def shape(self, tensor_shapes):
h_in = act[-3] if self.channel_last else act[-2]
w_in = act[-2] if self.channel_last else act[-1]

h_numerator = h_in + (self.padding_top + self.padding_bottom) - self.dilation_height * (weight[-2] - 1) - 1
h_out = math.floor(1 + (h_numerator / self.stride_height))
padding_top, padding_left, padding_bottom, padding_right = self.padding
dilation_height, dilation_width = self.dilation
stride_height, stride_width = self.stride

w_numerator = w_in + (self.padding_left + self.padding_right) - self.dilation_width * (weight[-1] - 1) - 1
w_out = math.floor(1 + (w_numerator / self.stride_width))
h_numerator = h_in + (padding_top + padding_bottom) - dilation_height * (weight[-2] - 1) - 1
h_out = math.floor(1 + (h_numerator / stride_height))

w_numerator = w_in + (padding_left + padding_right) - dilation_width * (weight[-1] - 1) - 1
w_out = math.floor(1 + (w_numerator / stride_width))

out_shape = [batch_size, h_out, w_out, cout] if self.channel_last else [batch_size, cout, h_out, w_out]
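The computation above is the standard convolution output-size formula; a worked instance with assumed sizes:

import math

# 3x3 kernel, stride 1, dilation 1, padding 1 per side keeps height unchanged.
h_in, k_h = 32, 3
p_t, p_b, d_h, s_h = 1, 1, 1, 1
h_numerator = h_in + (p_t + p_b) - d_h * (k_h - 1) - 1
h_out = math.floor(1 + h_numerator / s_h)
assert h_out == 32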

@@ -149,15 +139,10 @@ def decompose(self, dc, inputs):
new_inputs = [activations, weight] if bias is None else [activations, weight, bias]
result = dc.op(
Conv2d.create(
self.stride_height,
self.stride_width,
self.dilation_height,
self.dilation_width,
self.stride,
self.dilation,
self.groups,
self.padding_left,
self.padding_right,
self.padding_top,
self.padding_bottom,
self.padding,
True, # If the original Conv2d was channel-last, that will not change.
# If it was channel-first, the input will have been permuted by this point.
# So, the Conv2d op being created here is certainly channel-last.
@@ -197,27 +182,17 @@ class Conv2dTranspose(PyOp):
@classmethod
def create(
cls,
stride_height,
stride_width,
dilation_height,
dilation_width,
stride, # Input format: [sH, sW]
dilation, # Input format: [dH, dW]
groups,
padding_left,
padding_right,
padding_top,
padding_bottom,
padding, # Input format: [pL, pR, pT, pB]
channel_last,
):
self = cls("conv2d_transpose")
self.stride_height = stride_height
self.stride_width = stride_width
self.dilation_height = dilation_height
self.dilation_width = dilation_width
self.stride = stride
self.dilation = dilation
self.groups = groups
self.padding_left = padding_left
self.padding_right = padding_right
self.padding_top = padding_top
self.padding_bottom = padding_bottom
self.padding = padding
self.channel_last = int(channel_last)
return self

@@ -229,10 +204,13 @@ def eval(self, tensors):
weights = t_ops[1]
bias = t_ops[2] if len(t_ops) == 3 else None

stride = [self.stride_height, self.stride_width]
dilation = [self.dilation_height, self.dilation_width]
stride = self.stride
dilation = self.dilation
groups = self.groups
padding = (self.padding_top, self.padding_left)
padding = (
self.padding[2],
self.padding[0],
) # [pT, pL]; unclear why padding has only two elements (meenakshiramanathan1, PR #826)

channel_last = self.channel_last
if channel_last:
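On the two-element padding above: torch.nn.functional.conv_transpose2d accepts only symmetric per-dimension padding (padH, padW), which would explain why just two of the four values are forwarded. A quick check with assumed shapes:

import torch
import torch.nn.functional as F

# conv_transpose2d pads symmetrically per dimension, so asymmetric
# [pT, pL, pB, pR] has no direct equivalent in the call.
x = torch.randn(1, 1, 8, 8)
w = torch.randn(1, 1, 3, 3)
y = F.conv_transpose2d(x, w, stride=2, padding=(1, 1))
print(y.shape)  # torch.Size([1, 1, 15, 15]) = (8-1)*2 - 2*1 + (3-1) + 1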
@@ -262,6 +240,15 @@ def eval(self, tensors):
return result

def shape(self, tensor_shapes):
stride_height = self.stride[0]
stride_width = self.stride[1]
dilation_height = self.dilation[0]
dilation_width = self.dilation[1]
padding_left = self.padding[0]
padding_right = self.padding[1]
padding_top = self.padding[2]
padding_bottom = self.padding[3]

act, weight = tensor_shapes[:2]
batch_size = act[0]
cout = weight[1] * self.groups
@@ -273,16 +260,16 @@
output_padding_width = 0

h_out = (
(h_in - 1) * self.stride_height
- (self.padding_top + self.padding_bottom)
+ self.dilation_height * (weight[-2] - 1)
(h_in - 1) * stride_height
- (padding_top + padding_bottom)
+ dilation_height * (weight[-2] - 1)
+ output_padding_height
+ 1
)
w_out = (
(w_in - 1) * self.stride_width
- (self.padding_left + self.padding_right)
+ self.dilation_width * (weight[-1] - 1)
(w_in - 1) * stride_width
- (padding_left + padding_right)
+ dilation_width * (weight[-1] - 1)
+ output_padding_width
+ 1
)
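A worked instance of the transpose-convolution formula above, with assumed sizes:

# Stride-2 upsampling with a 3x3 kernel and padding 1 takes H=32 to H=63,
# matching torch.nn.ConvTranspose2d with the same hyperparameters.
h_in, k_h = 32, 3
s_h, d_h, p_t, p_b, out_pad_h = 2, 1, 1, 1, 0
h_out = (h_in - 1) * s_h - (p_t + p_b) + d_h * (k_h - 1) + out_pad_h + 1
assert h_out == 63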
@@ -323,15 +310,10 @@ def decompose(self, dc, inputs):
new_inputs = [activations, weight] if bias is None else [activations, weight, bias]
result = dc.op(
Conv2dTranspose.create(
self.stride_height,
self.stride_width,
self.dilation_height,
self.dilation_width,
self.stride,
self.dilation,
self.groups,
self.padding_left,
self.padding_right,
self.padding_top,
self.padding_bottom,
self.padding,
True, # If the original Conv2dTranspose was channel-last, that will not change.
# If it was channel-first, the input will have been permuted by this point.
# So, the Conv2dTranspose op being created here is certainly channel-last.
22 changes: 6 additions & 16 deletions forge/forge/op/eval/forge/pooling.py
@@ -587,28 +587,18 @@ def decompose(type, attr, dc, inputs):
weight = dc.tensor(weight_tensor)
result = dc.op_with_named_attrs(
Conv2d.create(
stride_height=stride[0],
stride_width=stride[1],
dilation_height=dilation,
dilation_width=dilation,
stride=stride,
dilation=[dilation, dilation],
groups=cin,
padding_left=padding[0],
padding_right=padding[1],
padding_top=padding[2],
padding_bottom=padding[3],
padding=padding,
channel_last=channel_last,
),
[activations, weight],
{
"stride_height": stride[0],
"stride_width": stride[1],
"dilation_height": dilation,
"dilation_width": dilation,
"stride": stride,
"dilation": [dilation, dilation],
"groups": cin,
"padding_left": padding[0],
"padding_right": padding[1],
"padding_top": padding[2],
"padding_bottom": padding[3],
"padding": padding,
"channel_last": channel_last,
},
)
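The decomposition above lowers average pooling to a depthwise conv2d: uniform weights of 1/(kH*kW) and groups equal to the input channel count. A PyTorch sanity check of that equivalence, with assumed shapes:

import torch
import torch.nn.functional as F

# Average pooling == depthwise convolution with uniform weights.
x = torch.randn(1, 4, 8, 8)
k = 2
weight = torch.full((4, 1, k, k), 1.0 / (k * k))  # one uniform filter per channel
pooled = F.avg_pool2d(x, kernel_size=k, stride=k)
conv = F.conv2d(x, weight, stride=k, groups=4)
assert torch.allclose(pooled, conv, atol=1e-6)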
2 changes: 1 addition & 1 deletion third_party/tt-mlir
Submodule tt-mlir updated 450 files
