diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 381a3f2811..6e4e0c576f 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -1258,15 +1258,11 @@ def TTNN_MaxPool2dOp : TTNN_Op<"max_pool2d"> { SI32Attr:$input_height, SI32Attr:$input_width, SI32Attr:$channels, - SI32Attr:$kernel_height, - SI32Attr:$kernel_width, - SI32Attr:$stride_height, - SI32Attr:$stride_width, - SI32Attr:$dilation_height, - SI32Attr:$dilation_width, - BoolAttr:$ceil_mode, - SI32Attr:$padding_height, - SI32Attr:$padding_width); + DenseI32ArrayAttr:$kernel_size, + DenseI32ArrayAttr:$stride, + DenseI32ArrayAttr:$padding, + DenseI32ArrayAttr:$dilation, + BoolAttr:$ceil_mode); let results = (outs AnyRankedTensor:$result); diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index a074610c22..64d5104120 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -358,15 +358,11 @@ table MaxPool2dOp { input_height: uint32; input_width: uint32; channels: uint32; - kernel_height: uint32; - kernel_width: uint32; - stride_height: uint32; - stride_width: uint32; - dilation_height: uint32; - dilation_width: uint32; + kernel_size: [int32]; + stride: [int32]; + padding: [int32]; + dilation: [int32]; ceil_mode: bool; - padding_height: uint32; - padding_width: uint32; } table DeallocateOp { diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 8cab44949f..888569f2dc 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -1246,16 +1246,27 @@ class MaxPool2dOpConversionPattern outputType.getElementType(), outputType.getEncoding()); + DenseI32ArrayAttr kernelSizeAttr = rewriter.getDenseI32ArrayAttr( + {adaptor.getKernelHeight(), adaptor.getKernelWidth()}); + + DenseI32ArrayAttr strideAttr = rewriter.getDenseI32ArrayAttr( + {adaptor.getStrideHeight(), adaptor.getStrideWidth()}); + + assert(adaptor.getPaddingTop() == adaptor.getPaddingBottom()); + assert(adaptor.getPaddingLeft() == adaptor.getPaddingRight()); + DenseI32ArrayAttr paddingAttr = rewriter.getDenseI32ArrayAttr( + {adaptor.getPaddingTop(), adaptor.getPaddingLeft()}); + + DenseI32ArrayAttr dilationAttr = rewriter.getDenseI32ArrayAttr( + {adaptor.getDilationHeight(), adaptor.getDilationWidth()}); + auto newPool = rewriter.create( op.getLoc(), this->getTypeConverter()->convertType(outputType), flattenedInput, device, batchSize, static_cast(inputShape[inputShape.size() - 3]), static_cast(inputShape[inputShape.size() - 2]), channels, - adaptor.getKernelHeight(), adaptor.getKernelWidth(), - adaptor.getStrideHeight(), adaptor.getStrideWidth(), - adaptor.getDilationHeight(), adaptor.getDilationWidth(), - adaptor.getCeilMode(), adaptor.getPaddingTop(), - adaptor.getPaddingRight()); + kernelSizeAttr, strideAttr, paddingAttr, dilationAttr, + adaptor.getCeilMode()); Value output = ttir_to_ttnn::utils::generateReshape(newPool, outputShape, rewriter); diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index af23a6a7b2..f43621b227 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -313,6 +313,44 @@ class MatmulOpConversionPattern } // namespace // ANCHOR_END: adding_an_op_matmul_op_rewriter_emitc +// MaxPool2d op conversion pattern +// +class MaxPool2dOpConversionPattern + : public TTNNToEmitCBaseOpConversionPattern { + +public: + using TTNNToEmitCBaseOpConversionPattern< + tt::ttnn::MaxPool2dOp>::TTNNToEmitCBaseOpConversionPattern; + + LogicalResult + matchAndRewrite(tt::ttnn::MaxPool2dOp srcOp, + tt::ttnn::MaxPool2dOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); + + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getBatchSize()), + emitter.emit(srcOp.getInputHeight()), + emitter.emit(srcOp.getInputWidth()), + emitter.emit(srcOp.getChannels()), + emitter.emit>(srcOp.getKernelSizeAttr()), + emitter.emit>(srcOp.getStrideAttr()), + emitter.emit>(srcOp.getPaddingAttr()), + emitter.emit>(srcOp.getDilationAttr()), + /*memory_config=*/emitter.emit(std::nullopt), + /*applied_shard_scheme=*/emitter.emit(std::nullopt), + emitter.emit(srcOp.getCeilMode()), + }; + + emitter.replaceOp(*this, args); + + return success(); + } +}; + // Softmax op conversion pattern // class SoftmaxOpConversionPattern @@ -1155,14 +1193,13 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, ArgMaxOpConversionPattern>(typeConverter, ctx); - // Conv ops + // Pooling ops // patterns.add>(typeConverter, ctx); patterns.add>( typeConverter, ctx); - patterns.add>(typeConverter, - ctx); + patterns.add(typeConverter, ctx); // Other ops // diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 07d9ac3825..7daf0dcb07 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -273,14 +273,14 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() { ::mlir::RankedTensorType inputType = getInput().getType(); ::llvm::ArrayRef inputShape = getInput().getType().getShape(); - if (getKernelHeight() > getInputHeight()) { - return emitOpError() << "Kernel height " << getKernelHeight() + if (getKernelSize()[0] > getInputHeight()) { + return emitOpError() << "Kernel height " << getKernelSize()[0] << " is greater than input height " << getInputHeight() << ". This MaxPool2d configuration is invalid."; } - if (getKernelWidth() > getInputWidth()) { - return emitOpError() << "Kernel width " << getKernelWidth() + if (getKernelSize()[1] > getInputWidth()) { + return emitOpError() << "Kernel width " << getKernelSize()[1] << " is greater than input width " << getInputWidth() << ". This MaxPool2d configuration is invalid."; } diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index b4e95518ba..32c2ee15c4 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -1333,14 +1333,21 @@ createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) { auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedSize); + ::flatbuffers::Offset<::flatbuffers::Vector> kernelSize = + toFlatbuffer(cache, op.getKernelSize()); + ::flatbuffers::Offset<::flatbuffers::Vector> stride = + toFlatbuffer(cache, op.getStride()); + ::flatbuffers::Offset<::flatbuffers::Vector> padding = + toFlatbuffer(cache, op.getPadding()); + ::flatbuffers::Offset<::flatbuffers::Vector> dilation = + toFlatbuffer(cache, op.getDilation()); + auto device = getOperandThroughDPSOps(op.getDevice()); return ::tt::target::ttnn::CreateMaxPool2dOp( *cache.fbb, in, out, cache.at<::tt::target::DeviceRef>(device), op.getBatchSize(), op.getInputHeight(), op.getInputWidth(), - op.getChannels(), op.getKernelHeight(), op.getKernelWidth(), - op.getStrideHeight(), op.getStrideWidth(), op.getDilationHeight(), - op.getDilationWidth(), op.getCeilMode(), op.getPaddingHeight(), - op.getPaddingWidth()); + op.getChannels(), kernelSize, stride, padding, dilation, + op.getCeilMode()); } ::flatbuffers::Offset<::tt::target::ttnn::RepeatInterleaveOp> diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index 9c3cd58148..7ea8de997c 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -25,13 +25,16 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) { outputMemoryConfig.has_value(), "Memory config must exist for device tensors"); + std::array kernelSize, stride, padding, dilation; + std::copy_n(op->kernel_size()->begin(), 2, kernelSize.begin()); + std::copy_n(op->stride()->begin(), 2, stride.begin()); + std::copy_n(op->padding()->begin(), 2, padding.begin()); + std::copy_n(op->dilation()->begin(), 2, dilation.begin()); + ::ttnn::Tensor out = ::ttnn::max_pool2d( input, op->batch_size(), op->input_height(), op->input_width(), - op->channels(), std::array{op->kernel_height(), op->kernel_width()}, - std::array{op->stride_height(), op->stride_width()}, - std::array{op->padding_height(), op->padding_width()}, - std::array{op->dilation_height(), op->dilation_width()}, - outputMemoryConfig, std::nullopt); + op->channels(), kernelSize, stride, padding, dilation, outputMemoryConfig, + /*applied_shard_scheme=*/std::nullopt, op->ceil_mode()); tensorPool.insertAndValidate(op->out(), out); } diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir index 259e912922..c44d71dd5f 100644 --- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir @@ -20,7 +20,8 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} { // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16384x32>>, > // CHECK-SAME: -> tensor<1x1x16384x32xbf16, - %3 = "ttnn.max_pool2d"(%2, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3> + // %3 = "ttnn.max_pool2d"(%2, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3> + %3 = "ttnn.max_pool2d"(%2, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_size = array, stride = array, dilation = array, padding = array}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3> // CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[DEVICE_OP]]) // Check that the output operand is transformed back into the tile and f32 data type. // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT:.*]] = "ttnn.to_layout"(%[[MAX_POOL_2D_OP]], %[[DEVICE_OP]]) diff --git a/test/ttmlir/EmitC/TTNN/other/pad.mlir b/test/ttmlir/EmitC/TTNN/other/pad.mlir index 18275a0d48..d6cded0fab 100644 --- a/test/ttmlir/EmitC/TTNN/other/pad.mlir +++ b/test/ttmlir/EmitC/TTNN/other/pad.mlir @@ -6,7 +6,7 @@ // UNSUPPORTED: true // Outstanding bug: https://github.com/tenstorrent/tt-mlir/issues/2072 module { - func.func @main(%arg0: tensor<1x1x5x5xbf16>) -> tensor<1x1x7x7xbf16> { + func.func @pad(%arg0: tensor<1x1x5x5xbf16>) -> tensor<1x1x7x7xbf16> { // CHECK: ttnn.pad %1 = "ttir.pad"(%arg0) <{padding = array, value = 0.000000e+00 : f32}> : (tensor<1x1x5x5xbf16>) -> tensor<1x1x7x7xbf16> return %1 : tensor<1x1x7x7xbf16> diff --git a/test/ttmlir/EmitC/TTNN/pooling/max_pool2d.mlir b/test/ttmlir/EmitC/TTNN/pooling/max_pool2d.mlir new file mode 100644 index 0000000000..3708be8139 --- /dev/null +++ b/test/ttmlir/EmitC/TTNN/pooling/max_pool2d.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn +// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir +// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp + +module attributes {} { + func.func @max_pool2d(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> { + %0 = tensor.empty() : tensor<1x64x64x32xbf16> + %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16> + return %1 : tensor<1x64x64x32xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir b/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir index 768d125922..bfd833ef46 100644 --- a/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir @@ -23,11 +23,12 @@ module @max_pool2d attributes {} { // CHECK-SAME: batch_size = 1 : si32, // CHECK-SAME: ceil_mode = false, // CHECK-SAME: channels = 128 : si32, - // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32, - // CHECK-SAME: input_height = 32 : si32, input_width = 32 : si32, - // CHECK-SAME: kernel_height = 3 : si32, kernel_width = 3 : si32, - // CHECK-SAME: padding_height = 1 : si32, padding_width = 1 : si32, - // CHECK-SAME: stride_height = 3 : si32, stride_width = 3 : si32} + // CHECK-SAME: dilation = array, + // CHECK-SAME: input_height = 32 : si32, + // CHECK-SAME: input_width = 32 : si32, + // CHECK-SAME: kernel_size = array, + // CHECK-SAME: padding = array, + // CHECK-SAME: stride = array // CHECK-SAME: tensor<1x1x1024x128xbf16 // CHECK-SAME: -> tensor<1x1x121x128xbf16 %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ @@ -63,11 +64,12 @@ module @max_pool2d attributes {} { // CHECK-SAME: batch_size = 1 : si32, // CHECK-SAME: ceil_mode = false, // CHECK-SAME: channels = 192 : si32, - // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32, - // CHECK-SAME: input_height = 28 : si32, input_width = 28 : si32, - // CHECK-SAME: kernel_height = 1 : si32, kernel_width = 1 : si32, - // CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32, - // CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32} + // CHECK-SAME: dilation = array, + // CHECK-SAME: input_height = 28 : si32, + // CHECK-SAME: input_width = 28 : si32, + // CHECK-SAME: kernel_size = array, + // CHECK-SAME: padding = array, + // CHECK-SAME: stride = array // CHECK-SAME: tensor<1x1x784x192xbf16 // CHECK-SAME: -> tensor<1x1x784x192xbf16 %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ @@ -103,11 +105,12 @@ module @max_pool2d attributes {} { // CHECK-SAME: batch_size = 1 : si32, // CHECK-SAME: ceil_mode = false, // CHECK-SAME: channels = 192 : si32, - // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32, - // CHECK-SAME: input_height = 28 : si32, input_width = 28 : si32, - // CHECK-SAME: kernel_height = 1 : si32, kernel_width = 2 : si32, - // CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32, - // CHECK-SAME: stride_height = 3 : si32, stride_width = 1 : si32} + // CHECK-SAME: dilation = array, + // CHECK-SAME: input_height = 28 : si32, + // CHECK-SAME: input_width = 28 : si32, + // CHECK-SAME: kernel_size = array, + // CHECK-SAME: padding = array, + // CHECK-SAME: stride = array // CHECK-SAME: tensor<1x1x784x192xbf16 // CHECK-SAME: -> tensor<1x1x270x192xbf16 %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ diff --git a/tools/ttnn-standalone/ttnn-precompiled.hpp b/tools/ttnn-standalone/ttnn-precompiled.hpp index 9dc5c08184..78f5561fdd 100644 --- a/tools/ttnn-standalone/ttnn-precompiled.hpp +++ b/tools/ttnn-standalone/ttnn-precompiled.hpp @@ -23,6 +23,7 @@ #include "operations/matmul/matmul.hpp" #include "operations/moreh/moreh_cumsum/moreh_cumsum.hpp" #include "operations/normalization/softmax/softmax.hpp" +#include "operations/pool/generic/generic_pools.hpp" #include "operations/reduction/generic/generic_reductions.hpp" #include "tensor/tensor.hpp" #include "tensor/types.hpp"