diff --git a/forge/csrc/passes/lower_to_mlir.cpp b/forge/csrc/passes/lower_to_mlir.cpp
index 9f3304933..501e5f6a7 100644
--- a/forge/csrc/passes/lower_to_mlir.cpp
+++ b/forge/csrc/passes/lower_to_mlir.cpp
@@ -54,9 +54,11 @@ using namespace tt;
 enum class TargetType
 {
     SourceType,
-    UInt32,
-    Int64,
+    UI32Attr,
+    I64Attr,
+    I32Attr,
     DenseI64ArrayAttr,
+    DenseI32ArrayAttr,
 };

 struct AttributeRemap
@@ -104,9 +106,14 @@ class AttributeMapper
     void initialize_default_mappings()
     {
-        add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
+        // Keep these mappings sorted in lexicographical order
+        add_op_mapping("conv2d", "dilation", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
+        add_op_mapping("conv2d", "groups", AttributeRemap(std::nullopt, TargetType::I32Attr));
+        add_op_mapping("conv2d", "padding", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
+        add_op_mapping("conv2d", "stride", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
+        add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::I64Attr));
         add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
-        add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));
+        add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UI32Attr));
         add_op_mapping("repeat", "repeats", AttributeRemap("repeat_dimensions", TargetType::DenseI64ArrayAttr));

         // Add more default mappings here
@@ -235,14 +242,18 @@ class MLIRGenerator
         // Convert the attribute to the target type
         switch (target_type)
         {
-            case TargetType::UInt32:
+            case TargetType::UI32Attr:
                 TT_ASSERT(std::get<int>(value) >= 0, "Value must be >= 0 for conversion to uint32");
                 return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
-            case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));
+            case TargetType::I32Attr: return builder_.getI32IntegerAttr(static_cast<int32_t>(std::get<int>(value)));
+            case TargetType::I64Attr: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));
             case TargetType::DenseI64ArrayAttr:
                 return builder_.getDenseI64ArrayAttr(std::vector<int64_t>(
                     std::get<std::vector<int>>(value).begin(), std::get<std::vector<int>>(value).end()));
+            case TargetType::DenseI32ArrayAttr:
+                return builder_.getDenseI32ArrayAttr(std::vector<int32_t>(
+                    std::get<std::vector<int>>(value).begin(), std::get<std::vector<int>>(value).end()));
             default:
                 // If type not handled, throw an exception
                 throw std::runtime_error("Unhandled target type conversion");
diff --git a/forge/csrc/runtime/runtime.cpp b/forge/csrc/runtime/runtime.cpp
index ffaa6b84d..39c8b20bb 100644
--- a/forge/csrc/runtime/runtime.cpp
+++ b/forge/csrc/runtime/runtime.cpp
@@ -21,7 +21,7 @@ static target::DataType torch_scalar_type_to_dt(torch::ScalarType st)
         case torch::ScalarType::Byte: return target::DataType::UInt8;
         case torch::ScalarType::Char: return target::DataType::UInt8;
         case torch::ScalarType::Short: return target::DataType::UInt16;
-        case torch::ScalarType::Int: return target::DataType::UInt32;
+        case torch::ScalarType::Int: return target::DataType::Int32;
         case torch::ScalarType::Long: return target::DataType::UInt32;
         case torch::ScalarType::Half: return target::DataType::Float16;
         case torch::ScalarType::Float: return target::DataType::Float32;
@@ -44,6 +44,7 @@ static torch::ScalarType dt_to_torch_scalar_type(target::DataType df)
         case target::DataType::UInt8: return torch::ScalarType::Byte;
         case target::DataType::UInt16: return torch::ScalarType::Short;
         case target::DataType::UInt32: return torch::ScalarType::Int;
+        case target::DataType::Int32: return torch::ScalarType::Int;
         case target::DataType::Float16: return torch::ScalarType::Half;
         case target::DataType::Float32: return torch::ScalarType::Float;
         case target::DataType::BFloat16: return torch::ScalarType::BFloat16;
diff --git a/forge/csrc/test/passes/test_erase_inverse_ops.cpp b/forge/csrc/test/passes/test_erase_inverse_ops.cpp
index b0fde2958..a2abb0def 100644
--- a/forge/csrc/test/passes/test_erase_inverse_ops.cpp
+++ b/forge/csrc/test/passes/test_erase_inverse_ops.cpp
@@ -503,14 +503,9 @@ struct UpdateConvAttrsTest : testing::Test
             *graph, "conv2d", "conv2d", {3, 3, 1, 1, 0, 0, 1, 1, 1}, {input_node_0, weight_node});
         auto named_attrs = conv_node->named_attrs();
         named_attrs["channel_last"] = false;
-        named_attrs["padding_top"] = 1;
-        named_attrs["padding_bottom"] = 1;
-        named_attrs["padding_left"] = 1;
-        named_attrs["padding_right"] = 1;
-        named_attrs["stride_height"] = 1;
-        named_attrs["stride_width"] = 1;
-        named_attrs["dilation_height"] = 1;
-        named_attrs["dilation_width"] = 1;
+        named_attrs["padding"] = std::vector<int>{1, 1, 1, 1};
+        named_attrs["stride"] = std::vector<int>{1, 1};
+        named_attrs["dilation"] = std::vector<int>{1, 1};
         conv_node->overwrite_named_attrs(named_attrs);

         create_output(*graph, "out", conv_node);
diff --git a/forge/forge/op/convolution.py b/forge/forge/op/convolution.py
index 901766ed2..5fd26b48e 100644
--- a/forge/forge/op/convolution.py
+++ b/forge/forge/op/convolution.py
@@ -58,15 +58,10 @@ def Conv2d(
         "conv2d",
         name,
         *inputs,
-        stride_height=stride[0],
-        stride_width=stride[1],
-        dilation_height=dilation[0],
-        dilation_width=dilation[1],
+        stride=stride,
+        dilation=dilation,
         groups=groups,
-        padding_left=padding[0],
-        padding_right=padding[1],
-        padding_top=padding[2],
-        padding_bottom=padding[3],
+        padding=padding,
         channel_last=channel_last,
     ).get_tensor()

@@ -138,15 +133,10 @@ def Conv2dTranspose(
         name,
         *inputs,
         attrs=attrs,
-        stride_height=stride[0],
-        stride_width=stride[1],
-        dilation_height=dilation,
-        dilation_width=dilation,
+        stride=stride,  # [sH, sW]
+        dilation=dilation,  # [dH, dW]
         groups=groups,
-        padding_top=padding[0],
-        padding_left=padding[1],
-        padding_bottom=padding[2],
-        padding_right=padding[3],
+        padding=padding,  # [pT, pL, pB, pR]
         channel_last=channel_last,
     ).get_tensor()

diff --git a/forge/forge/op/eval/forge/convolution.py b/forge/forge/op/eval/forge/convolution.py
index 8f417b52d..f7e04435b 100644
--- a/forge/forge/op/eval/forge/convolution.py
+++ b/forge/forge/op/eval/forge/convolution.py
@@ -22,27 +22,18 @@ class Conv2d(PyOp):
     @classmethod
     def create(
         cls,
-        stride_height,
-        stride_width,
-        dilation_height,
-        dilation_width,
+        stride,  # Input format: [sH, sW]
+        dilation,  # Input format: [dH, dW]
         groups,
-        padding_left,
-        padding_right,
-        padding_top,
-        padding_bottom,
+        padding,  # Input format: [pL, pR, pT, pB]
         channel_last,
     ):
         self = cls("conv2d")
-        self.stride_height = stride_height
-        self.stride_width = stride_width
-        self.dilation_height = dilation_height
-        self.dilation_width = dilation_width
+        self.stride = stride
         self.groups = groups
-        self.padding_left = padding_left
-        self.padding_right = padding_right
-        self.padding_top = padding_top
-        self.padding_bottom = padding_bottom
+        # Transform padding from [pL, pR, pT, pB] to [pT, pL, pB, pR]
+        self.padding = [padding[2], padding[0], padding[3], padding[1]]
+        self.dilation = dilation
         self.channel_last = int(channel_last)

         return self
@@ -55,15 +46,10 @@ def eval(self, tensors):
         weights = t_ops[1]
         bias = t_ops[2] if len(t_ops) == 3 else None

-        stride = [self.stride_height, self.stride_width]
-        dilation = [self.dilation_height, self.dilation_width]
         groups = self.groups
-        padding = [
-            self.padding_left,
-            self.padding_right,
-            self.padding_top,
-            self.padding_bottom,
-        ]
+        padding = self.padding
+        stride = self.stride
+        dilation = self.dilation
         channel_last = self.channel_last

         if channel_last:
@@ -105,11 +91,15 @@ def shape(self, tensor_shapes):
         h_in = act[-3] if self.channel_last else act[-2]
         w_in = act[-2] if self.channel_last else act[-1]

-        h_numerator = h_in + (self.padding_top + self.padding_bottom) - self.dilation_height * (weight[-2] - 1) - 1
-        h_out = math.floor(1 + (h_numerator / self.stride_height))
+        padding_top, padding_left, padding_bottom, padding_right = self.padding
+        dilation_height, dilation_width = self.dilation
+        stride_height, stride_width = self.stride

-        w_numerator = w_in + (self.padding_left + self.padding_right) - self.dilation_width * (weight[-1] - 1) - 1
-        w_out = math.floor(1 + (w_numerator / self.stride_width))
+        h_numerator = h_in + (padding_top + padding_bottom) - dilation_height * (weight[-2] - 1) - 1
+        h_out = math.floor(1 + (h_numerator / stride_height))
+
+        w_numerator = w_in + (padding_left + padding_right) - dilation_width * (weight[-1] - 1) - 1
+        w_out = math.floor(1 + (w_numerator / stride_width))

         out_shape = [batch_size, h_out, w_out, cout] if self.channel_last else [batch_size, cout, h_out, w_out]

@@ -149,15 +139,10 @@ def decompose(self, dc, inputs):
         new_inputs = [activations, weight] if bias is None else [activations, weight, bias]

         result = dc.op(
             Conv2d.create(
-                self.stride_height,
-                self.stride_width,
-                self.dilation_height,
-                self.dilation_width,
+                self.stride,
+                self.dilation,
                 self.groups,
-                self.padding_left,
-                self.padding_right,
-                self.padding_top,
-                self.padding_bottom,
+                self.padding,
                 True,  # If the original Conv2d was channel-last, that will not change.
                 # If it was channel-first, the input will have been permuted by this point.
                 # So, the Conv2d op being created here is certainly channel-last.
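
Note (not part of the patch): a minimal sketch of the Conv2d shape math above, checked against torch.nn.functional.conv2d. It assumes stride=[sH, sW], dilation=[dH, dW], and padding stored as [pT, pL, pB, pR] after the reorder in create(); the helper name conv2d_out_hw and the concrete sizes are illustrative.

```python
# Hypothetical standalone check of Conv2d.shape(), assuming the stored
# attribute layouts described above.
import math

import torch


def conv2d_out_hw(h_in, w_in, kernel, stride, dilation, padding):
    p_top, p_left, p_bottom, p_right = padding  # [pT, pL, pB, pR]
    h_num = h_in + (p_top + p_bottom) - dilation[0] * (kernel[0] - 1) - 1
    w_num = w_in + (p_left + p_right) - dilation[1] * (kernel[1] - 1) - 1
    return math.floor(1 + h_num / stride[0]), math.floor(1 + w_num / stride[1])


x = torch.randn(1, 3, 32, 32)
w = torch.randn(8, 3, 3, 3)
out = torch.nn.functional.conv2d(x, w, stride=[2, 2], padding=[1, 1], dilation=[1, 1])
# torch's symmetric (pH, pW) padding of [1, 1] equals [pT, pL, pB, pR] = [1, 1, 1, 1]
assert (out.shape[-2], out.shape[-1]) == conv2d_out_hw(32, 32, (3, 3), [2, 2], [1, 1], [1, 1, 1, 1])
```

The diff continues with the matching Conv2dTranspose changes in the same file.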
@@ -197,27 +182,17 @@ class Conv2dTranspose(PyOp):
     @classmethod
     def create(
         cls,
-        stride_height,
-        stride_width,
-        dilation_height,
-        dilation_width,
+        stride,  # Input format: [sH, sW]
+        dilation,  # Input format: [dH, dW]
         groups,
-        padding_left,
-        padding_right,
-        padding_top,
-        padding_bottom,
+        padding,  # Input format: [pL, pR, pT, pB]
         channel_last,
     ):
         self = cls("conv2d_transpose")
-        self.stride_height = stride_height
-        self.stride_width = stride_width
-        self.dilation_height = dilation_height
-        self.dilation_width = dilation_width
+        self.stride = stride
+        self.dilation = dilation
         self.groups = groups
-        self.padding_left = padding_left
-        self.padding_right = padding_right
-        self.padding_top = padding_top
-        self.padding_bottom = padding_bottom
+        self.padding = padding
         self.channel_last = int(channel_last)

         return self
@@ -229,10 +204,13 @@ def eval(self, tensors):
         weights = t_ops[1]
         bias = t_ops[2] if len(t_ops) == 3 else None

-        stride = [self.stride_height, self.stride_width]
-        dilation = [self.dilation_height, self.dilation_width]
+        stride = self.stride
+        dilation = self.dilation
         groups = self.groups
-        padding = (self.padding_top, self.padding_left)
+        padding = (
+            self.padding[2],
+            self.padding[0],
+        )  # [pT, pL]; not sure why padding only has two elements here (meenakshiramanathan1, PR #826)
         channel_last = self.channel_last

         if channel_last:
@@ -262,6 +240,15 @@ def eval(self, tensors):
         return result

     def shape(self, tensor_shapes):
+        stride_height = self.stride[0]
+        stride_width = self.stride[1]
+        dilation_height = self.dilation[0]
+        dilation_width = self.dilation[1]
+        padding_left = self.padding[0]
+        padding_right = self.padding[1]
+        padding_top = self.padding[2]
+        padding_bottom = self.padding[3]
+
         act, weight = tensor_shapes[:2]
         batch_size = act[0]
         cout = weight[1] * self.groups
@@ -273,16 +260,16 @@ def shape(self, tensor_shapes):
         output_padding_width = 0

         h_out = (
-            (h_in - 1) * self.stride_height
-            - (self.padding_top + self.padding_bottom)
-            + self.dilation_height * (weight[-2] - 1)
+            (h_in - 1) * stride_height
+            - (padding_top + padding_bottom)
+            + dilation_height * (weight[-2] - 1)
             + output_padding_height
             + 1
         )
         w_out = (
-            (w_in - 1) * self.stride_width
-            - (self.padding_left + self.padding_right)
-            + self.dilation_width * (weight[-1] - 1)
+            (w_in - 1) * stride_width
+            - (padding_left + padding_right)
+            + dilation_width * (weight[-1] - 1)
             + output_padding_width
             + 1
         )
@@ -323,15 +310,10 @@ def decompose(self, dc, inputs):
         new_inputs = [activations, weight] if bias is None else [activations, weight, bias]

         result = dc.op(
             Conv2dTranspose.create(
-                self.stride_height,
-                self.stride_width,
-                self.dilation_height,
-                self.dilation_width,
+                self.stride,
+                self.dilation,
                 self.groups,
-                self.padding_left,
-                self.padding_right,
-                self.padding_top,
-                self.padding_bottom,
+                self.padding,
                 True,  # If the original Conv2dTranspose was channel-last, that will not change.
                 # If it was channel-first, the input will have been permuted by this point.
                 # So, the Conv2dTranspose op being created here is certainly channel-last.
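
Note (not part of the patch): a hypothetical check of the Conv2dTranspose.shape() formula above. It assumes the stored padding layout [pL, pR, pT, pB] and symmetric padding, so the result can be compared with torch's single (pH, pW) pair; the helper name and sizes are illustrative.

```python
# Hypothetical standalone check of Conv2dTranspose.shape(), assuming the
# stored attribute layouts described above.
import torch


def conv_transpose2d_out_hw(h_in, w_in, kernel, stride, dilation, padding, output_padding=(0, 0)):
    p_left, p_right, p_top, p_bottom = padding  # [pL, pR, pT, pB], as stored by create()
    h_out = (h_in - 1) * stride[0] - (p_top + p_bottom) + dilation[0] * (kernel[0] - 1) + output_padding[0] + 1
    w_out = (w_in - 1) * stride[1] - (p_left + p_right) + dilation[1] * (kernel[1] - 1) + output_padding[1] + 1
    return h_out, w_out


x = torch.randn(1, 8, 16, 16)
w = torch.randn(8, 3, 3, 3)  # conv_transpose2d weight layout: (cin, cout/groups, kH, kW)
out = torch.nn.functional.conv_transpose2d(x, w, stride=[2, 2], padding=[1, 1])
assert (out.shape[-2], out.shape[-1]) == conv_transpose2d_out_hw(16, 16, (3, 3), [2, 2], [1, 1], [1, 1, 1, 1])
```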
diff --git a/forge/forge/op/eval/forge/pooling.py b/forge/forge/op/eval/forge/pooling.py
index fa50bee3d..6a471b37d 100644
--- a/forge/forge/op/eval/forge/pooling.py
+++ b/forge/forge/op/eval/forge/pooling.py
@@ -587,28 +587,18 @@ def decompose(type, attr, dc, inputs):
         weight = dc.tensor(weight_tensor)
         result = dc.op_with_named_attrs(
             Conv2d.create(
-                stride_height=stride[0],
-                stride_width=stride[1],
-                dilation_height=dilation,
-                dilation_width=dilation,
+                stride=stride,
+                dilation=[dilation, dilation],
                 groups=cin,
-                padding_left=padding[0],
-                padding_right=padding[1],
-                padding_top=padding[2],
-                padding_bottom=padding[3],
+                padding=padding,
                 channel_last=channel_last,
             ),
             [activations, weight],
             {
-                "stride_height": stride[0],
-                "stride_width": stride[1],
-                "dilation_height": dilation,
-                "dilation_width": dilation,
+                "stride": stride,
+                "dilation": [dilation, dilation],
                 "groups": cin,
-                "padding_left": padding[0],
-                "padding_right": padding[1],
-                "padding_top": padding[2],
-                "padding_bottom": padding[3],
+                "padding": padding,
                 "channel_last": channel_last,
             },
         )
diff --git a/third_party/tt-mlir b/third_party/tt-mlir
index 1d937da29..cc6f36d02 160000
--- a/third_party/tt-mlir
+++ b/third_party/tt-mlir
@@ -1 +1 @@
-Subproject commit 1d937da293bcd621cd34a204465f185c4ab4996d
+Subproject commit cc6f36d02687d60a40b0e8712d81c4792becd1f0
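
Note (not part of the patch): the pooling.py hunk above relies on average pooling being expressible as a depthwise (groups=cin) Conv2d with a constant 1/(kH*kW) kernel. A minimal sketch with assumed sizes:

```python
# Hypothetical illustration of the AvgPool2d-as-depthwise-Conv2d decomposition.
import torch

x = torch.randn(1, 4, 8, 8)
k_h, k_w = 2, 2
cin = x.shape[1]
weight = torch.full((cin, 1, k_h, k_w), 1.0 / (k_h * k_w))  # one averaging filter per channel
as_conv = torch.nn.functional.conv2d(x, weight, stride=[k_h, k_w], groups=cin)
reference = torch.nn.functional.avg_pool2d(x, kernel_size=[k_h, k_w])
assert torch.allclose(as_conv, reference, atol=1e-6)
```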