From aa8b980a956588e8e69d72d85c9405bc35d1362f Mon Sep 17 00:00:00 2001 From: Sanja Djukic Date: Tue, 4 Mar 2025 11:24:45 +0000 Subject: [PATCH] conv1d op a different case --- .../TTIRToTTIRDecomposition.cpp | 5 ----- .../TTNN/convolution/simple_conv1d.mlir | 22 ++++++++++++++++++- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 9d54682dfd..4c5878c768 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -213,11 +213,6 @@ struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { return failure(); } - // Not currently supporting spatial dims other than 2 for the 1D case. - if (op.getConvolutionLayout().getInputSpatialDimensions()[0] != 2) { - return failure(); - } - // The shapes that the convolution currently operates with have are 3D, and // we need to add another dimension for it to match the conv2d signature, so // adding a dimension of size 1 to the end of input and output shapes. diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir index 8767a0409a..71b4cb8808 100644 --- a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir +++ b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir @@ -1,6 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module { - func.func @main(%arg0: tensor<1x256x512xf32>, %arg1: tensor<1024x256x1xf32>, %arg2: tensor<1024xf32>) -> tensor<1x1024x512xf32> { + func.func @conv1d_test1(%arg0: tensor<1x256x512xf32>, %arg1: tensor<1024x256x1xf32>, %arg2: tensor<1024xf32>) -> tensor<1x1024x512xf32> { %0 = tensor.empty() : tensor<1x1024x512xf32> // CHECK: "ttnn.reshape" // CHECK-SAME: shape = [1 : i32, 256 : i32, 512 : i32, 1 : i32] @@ -17,4 +17,24 @@ module { // CHECK: return %{{.*}} : tensor<1x1024x512xf32, #ttnn_layout3> return %1 : tensor<1x1024x512xf32> } + + func.func public @conv1d_test2(%arg0: tensor<1x7x768xbf16>, %arg1: tensor<1x192x768xbf16>) -> (tensor<1x7x768xbf16>) { + %0 = tensor.empty() : tensor<1x7x768xbf16> + // CHECK: "ttnn.reshape" + // CHECK-SAME: shape = [1 : i32, 7 : i32, 768 : i32, 1 : i32] + // CHECK: "ttnn.reshape" + // CHECK-SAME: shape = [1 : i32, 192 : i32, 768 : i32, 1 : i32] + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.conv2d" + // CHECK: "ttnn.permute" + // CHECK-SAME: permutation = array + // CHECK: "ttnn.reshape" + // CHECK-SAME: shape = [1 : i32, 7 : i32, 768 : i32] + %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir, feature_group_count = 4 : i64, input_dilation = array, padding = array, weight_dilation = array, window_reversal = array, window_strides = array}> : (tensor<1x7x768xbf16>, tensor<1x192x768xbf16>, tensor<1x7x768xbf16>) -> tensor<1x7x768xbf16> + // CHECK: return %{{.*}} : tensor<1x7x768xbf16, #ttnn_layout{{.*}}> + return %1 : tensor<1x7x768xbf16> + } }