Fix maxpool2d signature + add emitc conversion&test (#2385)
### Ticket
#2328 

### Problem description
MaxPool2d is not supported through EmitC.

### What's changed
- Changed the MaxPool2d signature in the TTNN dialect to match the signature of the ttnn library call (a hedged sketch of that call shape follows below).
- Added the TTNN->EmitC conversion.
- Added a test.
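For reference, a minimal sketch of the library call shape the dialect op now mirrors. The argument order follows the runtime invocation in this change and the concrete values come from the added EmitC test, so treat the defaulted parameters of `::ttnn::max_pool2d` as approximate rather than authoritative:

```cpp
// Hedged sketch only -- argument order mirrors the runtime call in
// runtime/lib/ttnn/operations/pool/maxpool2d.cpp below; values match the
// added EmitC test (2x2 kernel, stride 2, no padding, dilation 1).
::ttnn::Tensor out = ::ttnn::max_pool2d(
    input, /*batch_size=*/1, /*input_height=*/128, /*input_width=*/128,
    /*channels=*/32,
    /*kernel_size=*/std::array<uint32_t, 2>{2, 2},
    /*stride=*/std::array<uint32_t, 2>{2, 2},
    /*padding=*/std::array<uint32_t, 2>{0, 0},
    /*dilation=*/std::array<uint32_t, 2>{1, 1},
    /*memory_config=*/std::nullopt,
    /*applied_shard_scheme=*/std::nullopt,
    /*ceil_mode=*/false);
```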

### Checklist
- [x] New/Existing tests provide coverage for changes
svuckovicTT authored and odjuricicTT committed Mar 8, 2025
1 parent 995aad9 commit e268581
Showing 12 changed files with 122 additions and 55 deletions.
14 changes: 5 additions & 9 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -1278,15 +1278,11 @@ def TTNN_MaxPool2dOp : TTNN_Op<"max_pool2d"> {
SI32Attr:$input_height,
SI32Attr:$input_width,
SI32Attr:$channels,
-SI32Attr:$kernel_height,
-SI32Attr:$kernel_width,
-SI32Attr:$stride_height,
-SI32Attr:$stride_width,
-SI32Attr:$dilation_height,
-SI32Attr:$dilation_width,
-BoolAttr:$ceil_mode,
-SI32Attr:$padding_height,
-SI32Attr:$padding_width);
+DenseI32ArrayAttr:$kernel_size,
+DenseI32ArrayAttr:$stride,
+DenseI32ArrayAttr:$padding,
+DenseI32ArrayAttr:$dilation,
+BoolAttr:$ceil_mode);

let results = (outs AnyRankedTensor:$result);

12 changes: 4 additions & 8 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -359,15 +359,11 @@ table MaxPool2dOp {
input_height: uint32;
input_width: uint32;
channels: uint32;
-kernel_height: uint32;
-kernel_width: uint32;
-stride_height: uint32;
-stride_width: uint32;
-dilation_height: uint32;
-dilation_width: uint32;
+kernel_size: [int32];
+stride: [int32];
+padding: [int32];
+dilation: [int32];
ceil_mode: bool;
-padding_height: uint32;
-padding_width: uint32;
}

table DeallocateOp {
21 changes: 16 additions & 5 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -1246,16 +1246,27 @@ class MaxPool2dOpConversionPattern
outputType.getElementType(),
outputType.getEncoding());

+DenseI32ArrayAttr kernelSizeAttr = rewriter.getDenseI32ArrayAttr(
+{adaptor.getKernelHeight(), adaptor.getKernelWidth()});
+
+DenseI32ArrayAttr strideAttr = rewriter.getDenseI32ArrayAttr(
+{adaptor.getStrideHeight(), adaptor.getStrideWidth()});
+
+assert(adaptor.getPaddingTop() == adaptor.getPaddingBottom());
+assert(adaptor.getPaddingLeft() == adaptor.getPaddingRight());
+DenseI32ArrayAttr paddingAttr = rewriter.getDenseI32ArrayAttr(
+{adaptor.getPaddingTop(), adaptor.getPaddingLeft()});
+
+DenseI32ArrayAttr dilationAttr = rewriter.getDenseI32ArrayAttr(
+{adaptor.getDilationHeight(), adaptor.getDilationWidth()});
+
auto newPool = rewriter.create<ttnn::MaxPool2dOp>(
op.getLoc(), this->getTypeConverter()->convertType(outputType),
flattenedInput, device, batchSize,
static_cast<int32_t>(inputShape[inputShape.size() - 3]),
static_cast<int32_t>(inputShape[inputShape.size() - 2]), channels,
-adaptor.getKernelHeight(), adaptor.getKernelWidth(),
-adaptor.getStrideHeight(), adaptor.getStrideWidth(),
-adaptor.getDilationHeight(), adaptor.getDilationWidth(),
-adaptor.getCeilMode(), adaptor.getPaddingTop(),
-adaptor.getPaddingRight());
+kernelSizeAttr, strideAttr, paddingAttr, dilationAttr,
+adaptor.getCeilMode());

Value output =
ttir_to_ttnn::utils::generateReshape(newPool, outputShape, rewriter);
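The asserts above exist because TTIR carries four padding values (top, left, bottom, right) while the reworked TTNN op takes a single {height, width} pair. A minimal, self-contained sketch of that collapse, with a hypothetical helper name, assuming symmetric padding is the only supported case for now:

```cpp
#include <array>
#include <cstdint>
#include <optional>

// Illustrative only: returns std::nullopt when the four TTIR padding values
// cannot be represented as the TTNN {height, width} pair.
std::optional<std::array<int32_t, 2>>
collapseSymmetricPadding(int32_t top, int32_t left, int32_t bottom,
                         int32_t right) {
  if (top != bottom || left != right) {
    return std::nullopt; // asymmetric padding is not representable
  }
  return std::array<int32_t, 2>{top, left};
}
```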
43 changes: 40 additions & 3 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -313,6 +313,44 @@ class MatmulOpConversionPattern
} // namespace
// ANCHOR_END: adding_an_op_matmul_op_rewriter_emitc

// MaxPool2d op conversion pattern
//
class MaxPool2dOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<tt::ttnn::MaxPool2dOp> {

public:
using TTNNToEmitCBaseOpConversionPattern<
tt::ttnn::MaxPool2dOp>::TTNNToEmitCBaseOpConversionPattern;

LogicalResult
matchAndRewrite(tt::ttnn::MaxPool2dOp srcOp,
tt::ttnn::MaxPool2dOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

ttnn_to_emitc::EmitCTTNNEmitter<tt::ttnn::MaxPool2dOp> emitter(
srcOp, adaptor, rewriter);

llvm::SmallVector<mlir::Attribute> args{
emitter.emit(srcOp.getInput()),
emitter.emit(srcOp.getBatchSize()),
emitter.emit(srcOp.getInputHeight()),
emitter.emit(srcOp.getInputWidth()),
emitter.emit(srcOp.getChannels()),
emitter.emit<std::array<uint32_t, 2>>(srcOp.getKernelSizeAttr()),
emitter.emit<std::array<uint32_t, 2>>(srcOp.getStrideAttr()),
emitter.emit<std::array<uint32_t, 2>>(srcOp.getPaddingAttr()),
emitter.emit<std::array<uint32_t, 2>>(srcOp.getDilationAttr()),
/*memory_config=*/emitter.emit(std::nullopt),
/*applied_shard_scheme=*/emitter.emit(std::nullopt),
emitter.emit(srcOp.getCeilMode()),
};

emitter.replaceOp(*this, args);

return success();
}
};

// Softmax op conversion pattern
//
class SoftmaxOpConversionPattern
@@ -1156,14 +1194,13 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<tt::ttnn::ProdOp>,
ArgMaxOpConversionPattern>(typeConverter, ctx);

-// Conv ops
+// Pooling ops
//
patterns.add<DefaultOpConversionPattern<tt::ttnn::Conv2dOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<tt::ttnn::ConvTranspose2dOp>>(
typeConverter, ctx);
-patterns.add<DefaultOpConversionPattern<tt::ttnn::MaxPool2dOp>>(typeConverter,
-ctx);
+patterns.add<MaxPool2dOpConversionPattern>(typeConverter, ctx);

// Other ops
//
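Conceptually, the templated `emit<std::array<uint32_t, 2>>(...)` calls in the MaxPool2dOpConversionPattern above have to turn a two-element `DenseI32ArrayAttr` into a C++ `std::array` expression in the generated code. A rough, illustrative sketch of that mapping (the helper name is hypothetical; the real emitter plumbing differs):

```cpp
#include <cassert>
#include <cstdint>
#include <string>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/BuiltinAttributes.h"

// Illustrative only: render a {height, width} DenseI32ArrayAttr as the C++
// expression that would appear in the EmitC output.
static std::string makeStdArrayLiteral(mlir::DenseI32ArrayAttr attr) {
  llvm::ArrayRef<int32_t> values = attr.asArrayRef();
  assert(values.size() == 2 && "expected a {height, width} pair");
  return "std::array<uint32_t, 2>{" + std::to_string(values[0]) + ", " +
         std::to_string(values[1]) + "}";
}
```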
8 changes: 4 additions & 4 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -287,14 +287,14 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::llvm::ArrayRef<int64_t> inputShape = getInput().getType().getShape();

-if (getKernelHeight() > getInputHeight()) {
-return emitOpError() << "Kernel height " << getKernelHeight()
+if (getKernelSize()[0] > getInputHeight()) {
+return emitOpError() << "Kernel height " << getKernelSize()[0]
<< " is greater than input height " << getInputHeight()
<< ". This MaxPool2d configuration is invalid.";
}

-if (getKernelWidth() > getInputWidth()) {
-return emitOpError() << "Kernel width " << getKernelWidth()
+if (getKernelSize()[1] > getInputWidth()) {
+return emitOpError() << "Kernel width " << getKernelSize()[1]
<< " is greater than input width " << getInputWidth()
<< ". This MaxPool2d configuration is invalid.";
}
15 changes: 11 additions & 4 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -1344,14 +1344,21 @@ createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) {
auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedSize);

+::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> kernelSize =
+toFlatbuffer(cache, op.getKernelSize());
+::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride =
+toFlatbuffer(cache, op.getStride());
+::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> padding =
+toFlatbuffer(cache, op.getPadding());
+::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> dilation =
+toFlatbuffer(cache, op.getDilation());
+
auto device = getOperandThroughDPSOps(op.getDevice());
return ::tt::target::ttnn::CreateMaxPool2dOp(
*cache.fbb, in, out, cache.at<::tt::target::DeviceRef>(device),
op.getBatchSize(), op.getInputHeight(), op.getInputWidth(),
-op.getChannels(), op.getKernelHeight(), op.getKernelWidth(),
-op.getStrideHeight(), op.getStrideWidth(), op.getDilationHeight(),
-op.getDilationWidth(), op.getCeilMode(), op.getPaddingHeight(),
-op.getPaddingWidth());
+op.getChannels(), kernelSize, stride, padding, dilation,
+op.getCeilMode());
}

::flatbuffers::Offset<::tt::target::ttnn::RepeatInterleaveOp>
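For the new `[int32]` fields, `toFlatbuffer(cache, <DenseI32ArrayAttr>)` plausibly reduces to serializing the attribute's values as a flatbuffers int32 vector; a hedged sketch with a hypothetical helper name (the real helper in the repo may be shaped differently):

```cpp
#include <cstdint>

#include "flatbuffers/flatbuffers.h"
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/BuiltinAttributes.h"

// Illustrative only: copy the DenseI32ArrayAttr values into an [int32]
// flatbuffer vector, as consumed by the MaxPool2dOp table above.
static ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>>
i32ArrayToFlatbuffer(::flatbuffers::FlatBufferBuilder &fbb,
                     mlir::DenseI32ArrayAttr attr) {
  llvm::ArrayRef<int32_t> values = attr.asArrayRef();
  return fbb.CreateVector(values.data(), values.size());
}
```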
13 changes: 8 additions & 5 deletions runtime/lib/ttnn/operations/pool/maxpool2d.cpp
@@ -25,13 +25,16 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) {
outputMemoryConfig.has_value(),
"Memory config must exist for device tensors");

+std::array<uint32_t, 2> kernelSize, stride, padding, dilation;
+std::copy_n(op->kernel_size()->begin(), 2, kernelSize.begin());
+std::copy_n(op->stride()->begin(), 2, stride.begin());
+std::copy_n(op->padding()->begin(), 2, padding.begin());
+std::copy_n(op->dilation()->begin(), 2, dilation.begin());
+
::ttnn::Tensor out = ::ttnn::max_pool2d(
input, op->batch_size(), op->input_height(), op->input_width(),
-op->channels(), std::array{op->kernel_height(), op->kernel_width()},
-std::array{op->stride_height(), op->stride_width()},
-std::array{op->padding_height(), op->padding_width()},
-std::array{op->dilation_height(), op->dilation_width()},
-outputMemoryConfig, std::nullopt);
+op->channels(), kernelSize, stride, padding, dilation, outputMemoryConfig,
+/*applied_shard_scheme=*/std::nullopt, op->ceil_mode());

tensorPool.insertAndValidate(op->out(), out);
}
@@ -20,7 +20,8 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16384x32>>, <interleaved>>
// CHECK-SAME: -> tensor<1x1x16384x32xbf16,
%3 = "ttnn.max_pool2d"(%2, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
%3 = "ttnn.max_pool2d"(%2, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_size = array<i32: 2, 2>, stride = array<i32: 2, 2>, dilation = array<i32: 1, 1>, padding = array<i32: 0, 0>}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
// CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[DEVICE_OP]])
// Check that the output operand is transformed back into the tile and f32 data type.
// CHECK-NEXT: %[[TO_LAYOUT_OUTPUT:.*]] = "ttnn.to_layout"(%[[MAX_POOL_2D_OP]], %[[DEVICE_OP]])
2 changes: 1 addition & 1 deletion test/ttmlir/EmitC/TTNN/other/pad.mlir
@@ -6,7 +6,7 @@
// UNSUPPORTED: true
// Outstanding bug: https://github.com/tenstorrent/tt-mlir/issues/2072
module {
-func.func @main(%arg0: tensor<1x1x5x5xbf16>) -> tensor<1x1x7x7xbf16> {
+func.func @pad(%arg0: tensor<1x1x5x5xbf16>) -> tensor<1x1x7x7xbf16> {
// CHECK: ttnn.pad
%1 = "ttir.pad"(%arg0) <{padding = array<i32: 0, 0, 0, 0, 1, 1, 1, 1>, value = 0.000000e+00 : f32}> : (tensor<1x1x5x5xbf16>) -> tensor<1x1x7x7xbf16>
return %1 : tensor<1x1x7x7xbf16>
12 changes: 12 additions & 0 deletions test/ttmlir/EmitC/TTNN/pooling/max_pool2d.mlir
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn
// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir
// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp

module attributes {} {
func.func @max_pool2d(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> {
%0 = tensor.empty() : tensor<1x64x64x32xbf16>
%1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16>
return %1 : tensor<1x64x64x32xbf16>
}
}
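The shapes in this test follow the standard pooling output-size arithmetic; a small illustrative helper (not part of the change) showing why a 128x128 input with a 2x2 kernel, stride 2, no padding, and dilation 1 produces the 64x64 output above:

```cpp
#include <cstdint>

// Illustrative only: standard max-pool output-dimension formula.
int32_t pooledDim(int32_t in, int32_t kernel, int32_t stride, int32_t pad,
                  int32_t dilation, bool ceilMode) {
  int32_t effectiveKernel = dilation * (kernel - 1) + 1;
  int32_t numerator = in + 2 * pad - effectiveKernel;
  int32_t out = numerator / stride + 1; // floor division for non-negative ints
  if (ceilMode && numerator % stride != 0) {
    ++out; // ceil_mode rounds a partial trailing window up
  }
  return out;
}
// pooledDim(128, 2, 2, 0, 1, false) == 64, matching tensor<1x64x64x32xbf16>.
```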
33 changes: 18 additions & 15 deletions test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir
@@ -23,11 +23,12 @@ module @max_pool2d attributes {} {
// CHECK-SAME: batch_size = 1 : si32,
// CHECK-SAME: ceil_mode = false,
// CHECK-SAME: channels = 128 : si32,
-// CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
-// CHECK-SAME: input_height = 32 : si32, input_width = 32 : si32,
-// CHECK-SAME: kernel_height = 3 : si32, kernel_width = 3 : si32,
-// CHECK-SAME: padding_height = 1 : si32, padding_width = 1 : si32,
-// CHECK-SAME: stride_height = 3 : si32, stride_width = 3 : si32}
+// CHECK-SAME: dilation = array<i32: 1, 1>,
+// CHECK-SAME: input_height = 32 : si32,
+// CHECK-SAME: input_width = 32 : si32,
+// CHECK-SAME: kernel_size = array<i32: 3, 3>,
+// CHECK-SAME: padding = array<i32: 1, 1>,
+// CHECK-SAME: stride = array<i32: 3, 3>
// CHECK-SAME: tensor<1x1x1024x128xbf16
// CHECK-SAME: -> tensor<1x1x121x128xbf16
%0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 3, 3>}> ({
@@ -63,11 +64,12 @@ module @max_pool2d attributes {} {
// CHECK-SAME: batch_size = 1 : si32,
// CHECK-SAME: ceil_mode = false,
// CHECK-SAME: channels = 192 : si32,
-// CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
-// CHECK-SAME: input_height = 28 : si32, input_width = 28 : si32,
-// CHECK-SAME: kernel_height = 1 : si32, kernel_width = 1 : si32,
-// CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32,
-// CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32}
+// CHECK-SAME: dilation = array<i32: 1, 1>,
+// CHECK-SAME: input_height = 28 : si32,
+// CHECK-SAME: input_width = 28 : si32,
+// CHECK-SAME: kernel_size = array<i32: 1, 1>,
+// CHECK-SAME: padding = array<i32: 0, 0>,
+// CHECK-SAME: stride = array<i32: 1, 1>
// CHECK-SAME: tensor<1x1x784x192xbf16
// CHECK-SAME: -> tensor<1x1x784x192xbf16
%0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 1>, window_strides = array<i64: 1, 1, 1, 1>}> ({
@@ -103,11 +105,12 @@ module @max_pool2d attributes {} {
// CHECK-SAME: batch_size = 1 : si32,
// CHECK-SAME: ceil_mode = false,
// CHECK-SAME: channels = 192 : si32,
-// CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
-// CHECK-SAME: input_height = 28 : si32, input_width = 28 : si32,
-// CHECK-SAME: kernel_height = 1 : si32, kernel_width = 2 : si32,
-// CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32,
-// CHECK-SAME: stride_height = 3 : si32, stride_width = 1 : si32}
+// CHECK-SAME: dilation = array<i32: 1, 1>,
+// CHECK-SAME: input_height = 28 : si32,
+// CHECK-SAME: input_width = 28 : si32,
+// CHECK-SAME: kernel_size = array<i32: 1, 2>,
+// CHECK-SAME: padding = array<i32: 0, 0>,
+// CHECK-SAME: stride = array<i32: 3, 1>
// CHECK-SAME: tensor<1x1x784x192xbf16
// CHECK-SAME: -> tensor<1x1x270x192xbf16
%0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 2>, window_strides = array<i64: 1, 1, 3, 1>}> ({
1 change: 1 addition & 0 deletions tools/ttnn-standalone/ttnn-precompiled.hpp
@@ -23,6 +23,7 @@
#include "operations/matmul/matmul.hpp"
#include "operations/moreh/moreh_cumsum/moreh_cumsum.hpp"
#include "operations/normalization/softmax/softmax.hpp"
#include "operations/pool/generic/generic_pools.hpp"
#include "operations/reduction/generic/generic_reductions.hpp"
#include "tensor/tensor.hpp"
#include "tensor/types.hpp"
