From 7943e1e9d5b4d9f3798de45d189ad28f45ce418f Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Fri, 31 Jan 2025 19:00:58 -0800
Subject: [PATCH 01/15] code re-factoring

---
 lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 188 ++++++++++------
 1 file changed, 89 insertions(+), 99 deletions(-)

diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
index 861800697..fb9f1709e 100644
--- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
+++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -43,108 +43,86 @@ using namespace mlir::tpp;
 namespace mlir {
 namespace tpp {
-struct LinalgOpTiling : OpRewritePattern {
-  using OpRewritePattern::OpRewritePattern;
+
+template
+struct LinalgOpTiling : OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
   LinalgOpTiling(MLIRContext *ctx, BrgemmLinalgTilingOptions tilingoptions)
-      : OpRewritePattern(ctx), options(tilingoptions) {}
+      : OpRewritePattern(ctx), options(tilingoptions) {}
-  LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp brgemmOp,
+  LogicalResult matchAndRewrite(BrgemmOp brgemmOp,
                                 PatternRewriter &rewriter) const override {
-
     if (!brgemmOp.hasPureBufferSemantics())
       return failure();
-    // Get the register blocking tile shape from the user input
-    SmallVector tileShapeM(options.registerTileShape.begin(),
-                           options.registerTileShape.end());
-    if (tileShapeM.size() != 2)
+    // Check whether the tile sizes are valid
+    if (options.registerTileShape.size() != 3 && options.registerTileShape.size() != 2)
       return failure();
-    SmallVector tileShapeN(2);
-    tileShapeN[0] = 1;
-    tileShapeN[1] = tileShapeM[1];
-    tileShapeM[1] = 1;
+    // Check the whether the operation is brmatmul fp32 or bf16 type using reduction count
+    SmallVector brgemmIteratorTypes =
+        brgemmOp.getIteratorTypesArray();
+    int reductionCount =
+        std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(),
+                   utils::IteratorType::reduction);
+    if (reductionCount != 2 && reductionCount != 3)
+      return failure();
-    // Stores the M, N, and K Tile Sizes
+    // Get the register blocking tile shape from the user input
     SmallVector mxnxkTile(3);
-    // Stores the M, and N Tile Sizes
-    SmallVector mxnTile(2);
+    for (size_t i = 0; i < options.registerTileShape.size(); i++) {
+      mxnxkTile[i] = options.registerTileShape[i];
+    }
+
+    // Set the K tile to 1, if the user not provided (it is fp32 target)
+    if (options.registerTileShape.size() == 2)
+      mxnxkTile[2] = 1;
+
+    // k-tile size adjusted based on the vnni layout for bf16 type
+    auto tensorShape = dyn_cast(brgemmOp.getOperand(0).getType()).getShape();
+    if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) {
+      mxnxkTile[2] = mxnxkTile[2] / tensorShape[3];
+    }
+
-    mxnxkTile[0] = tileShapeM[0];
-    mxnxkTile[1] = tileShapeN[1];
-    mxnxkTile[2] = tileShapeM[1];
-    mxnTile[0] = tileShapeM[0];
-    mxnTile[1] = tileShapeN[1];
-
-    // To assist in calculating the argument and step values for the tiled loop.
-    SmallVector boundariesOne{1,
-                              static_cast(tileShapeM.size() - 1),
-                              static_cast(mxnxkTile.size() - 1)};
-
-    SmallVector tileSizesIndex{static_cast(tileShapeM.size()),
-                               static_cast(tileShapeN.size()),
-                               static_cast(mxnTile.size())};
-    SmallVector> tileshapes{tileShapeM, tileShapeN, mxnTile};
     SmallVector swap_i = {0, 2, 1};
     size_t i = 0;
     std::map> inductionVars;
     // For M, N, and K loops
     scf::ForOp innermostForLoop;
-    // For brgemm reduction loop
-    scf::ForOp reductionForLoop;
     // Creating the tiled loops
-    for (auto itrShapeM = mxnxkTile.begin(); itrShapeM != mxnxkTile.end();
-         itrShapeM++, i++) {
-      int index = swap_i[i] / boundariesOne[swap_i[i]];
-      int offset = swap_i[i] / (mxnxkTile.size() - 1);
-
-      int operandSize =
-          dyn_cast(brgemmOp.getOperand(index).getType())
-              .getShape()
-              .size();
-      int effectiveOffset = operandSize - tileSizesIndex[index] + offset;
+    for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end();
+         itrShapeMNK++, i++) {
       auto upperBound =
-          dyn_cast(brgemmOp.getOperand(index).getType())
-              .getShape()[effectiveOffset];
+          dyn_cast(brgemmOp.getOperand(swap_i[i]).getType())
+              .getShape()[1];
+
+      //Tile size should not be greater than the upperBound
+      if ((*itrShapeMNK) > upperBound)
+        return failure();
+
       Location loc = brgemmOp.getLoc();
       Value zeroCst = rewriter.create(loc, 0);
       Value ubCstTiledLoop = rewriter.create(loc, upperBound);
-      //Tile size should not be greater than the upperBound
-      if ((*itrShapeM) > upperBound)
-        return failure();
-      Value stepCstTiledLoop = rewriter.create(loc, *itrShapeM);
+
+      Value stepCstTiledLoop = rewriter.create(loc, *itrShapeMNK);
       // Creates M, N, and K tile loops
       scf::ForOp loopOp = rewriter.create(brgemmOp.getLoc(),
                                zeroCst, ubCstTiledLoop, stepCstTiledLoop);
       rewriter.setInsertionPointToStart(loopOp.getBody());
-      int indexTwo = offset;
-      int operandSizeTwo =
-          dyn_cast(brgemmOp.getOperand(indexTwo).getType())
-              .getShape()
-              .size();
-      int effectiveOffsetTwo = operandSizeTwo - tileSizesIndex[index] + index;
-
-      inductionVars[index][effectiveOffset] = loopOp.getInductionVar();
-
-      inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar();
-      int indexThree = mxnTile.size();
-      int effectiveOffsetThree =
-          index +
-          dyn_cast(brgemmOp.getOperand(indexThree).getType())
-              .getShape()
-              .size() -
-          tileSizesIndex[indexThree];
-      if (inductionVars[indexThree][effectiveOffsetThree] == NULL) {
-        inductionVars[indexThree][effectiveOffsetThree] =
-            loopOp.getInductionVar();
-      }
-
       innermostForLoop = loopOp;
-      if ((mxnxkTile.size() - 1) == (i + 1)) {
-        //Creates the brgemm reduction loop
+
+      // Stores the induction variable with respect to the operands mapping it's subview.
+ if (i == 0) { + inductionVars[0][1] = loopOp.getInductionVar(); + inductionVars[2][0] = loopOp.getInductionVar(); + } else if(i == 1) { + inductionVars[1][2] = loopOp.getInductionVar(); + inductionVars[2][1] = loopOp.getInductionVar(); + //Creates reduction loop after the N loop Value ubCstReduction = rewriter.create( loc, dyn_cast(brgemmOp.getOperand(0).getType()) .getShape()[0]); @@ -152,46 +130,58 @@ struct LinalgOpTiling : OpRewritePattern { scf::ForOp redloopOp = rewriter.create( brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction); rewriter.setInsertionPointToStart(redloopOp.getBody()); - reductionForLoop = redloopOp; + inductionVars[0][0] = redloopOp.getInductionVar(); + inductionVars[1][0] = redloopOp.getInductionVar(); + + } else if(i == 2) { + inductionVars[0][2] = loopOp.getInductionVar(); + inductionVars[1][1] = loopOp.getInductionVar(); } } + // DS to assist while creating new subviews with correct indices and shapes + SmallVector mxkTile(2); + SmallVector kxnTile(2); + SmallVector mxnTile(2); + mxkTile[0] = mxnxkTile[0]; + mxkTile[1] = mxnxkTile[2]; + kxnTile[0] = mxnxkTile[2]; + kxnTile[1] = mxnxkTile[1]; + mxnTile[0] = mxnxkTile[0]; + mxnTile[1] = mxnxkTile[1]; + + SmallVector> tileshapes{mxkTile, kxnTile, mxnTile}; // Creating subviews - SmallVector> tiles = {tileShapeM, tileShapeN}; for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) { - SmallVector indices; - auto input = brgemmOp.getOperand(i); - auto operandType = input.getType(); SmallVector offsets; - size_t k = 0; - auto tileItr = tileshapes[i].begin(); - auto tensorShape = dyn_cast(operandType).getShape(); + SmallVector indices; SmallVector shape; SmallVector strides; + + auto input = brgemmOp.getOperand(i); + auto tensorShape = dyn_cast(input.getType()).getShape(); + auto tileItr = tileshapes[i].begin(); + + // Iterates over the shape of each tensor and update its offsets, indices, shapes, strides with respect to tile sizes for (size_t j = 0; j < tensorShape.size(); j++) { - if (j < tensorShape.size() - tileSizesIndex[i]) { - if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) && - i < (brgemmOp.getNumOperands() - 1)) { - offsets.push_back(reductionForLoop.getInductionVar()); - indices.push_back(tensorShape[j] / tensorShape[j]); - shape.push_back(rewriter.getIndexAttr(tensorShape[j] / tensorShape[j])); + if (j == 0 && (i < 2)) { // Updates the batch dimension + offsets.push_back(inductionVars[i][j]); + indices.push_back(1); + shape.push_back(rewriter.getIndexAttr(1)); strides.push_back(rewriter.getIndexAttr(1)); - - } else { + } else if (j < 3) { // Updates the M, N, and K dimensions + offsets.push_back(inductionVars[i][j]); + indices.push_back((*tileItr)); + shape.push_back(rewriter.getIndexAttr(*tileItr)); + strides.push_back(rewriter.getIndexAttr(1)); + tileItr++; + } else { // Just copies the vnni layout dimensions offsets.push_back(rewriter.getIndexAttr(0)); indices.push_back(tensorShape[j]); shape.push_back(rewriter.getIndexAttr(tensorShape[j])); strides.push_back(rewriter.getIndexAttr(1)); - } - } else { - shape.push_back(rewriter.getIndexAttr(*tileItr)); - indices.push_back((*tileItr)); - strides.push_back(rewriter.getIndexAttr(1)); - offsets.push_back( - inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]); - k++; - tileItr++; } + } auto subview = rewriter.create( @@ -215,7 +205,7 @@ struct LinalgOpTiling : OpRewritePattern { void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns, BrgemmLinalgTilingOptions options) { - patterns.add(patterns.getContext(), options); + 
patterns.add, LinalgOpTiling>(patterns.getContext(), options); } struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase { From 310a21a48e6cf93fad2b04651782c855ff324854 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Sun, 2 Feb 2025 19:12:55 -0800 Subject: [PATCH 02/15] code re-factoring and test cases for bf16 --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 89 ++--- .../pass-tile-brgemm-linalg-matmul.mlir | 315 ++++++++++++------ 2 files changed, 270 insertions(+), 134 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index fb9f1709e..fcdc645f6 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -1,4 +1,5 @@ -//===- BrgemmLinalgTiling.cpp -----------------------------------------*- C++-*-===// +//===- BrgemmLinalgTiling.cpp -----------------------------------------*- +//C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -57,22 +58,24 @@ struct LinalgOpTiling : OpRewritePattern { return failure(); // Check whether the tile sizes are valid - if (options.registerTileShape.size() != 3 && options.registerTileShape.size() != 2) - return failure(); + if (options.registerTileShape.size() != 3 && + options.registerTileShape.size() != 2) + return failure(); - // Check the whether the operation is brmatmul fp32 or bf16 type using reduction count + // Check the whether the operation is brmatmul fp32 or bf16 type using + // reduction count SmallVector brgemmIteratorTypes = brgemmOp.getIteratorTypesArray(); int reductionCount = std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(), utils::IteratorType::reduction); if (reductionCount != 2 && reductionCount != 3) - return failure(); + return failure(); // Get the register blocking tile shape from the user input SmallVector mxnxkTile(3); for (size_t i = 0; i < options.registerTileShape.size(); i++) { - mxnxkTile[i] = options.registerTileShape[i]; + mxnxkTile[i] = options.registerTileShape[i]; } // Set the K tile to 1, if the user not provided (it is fp32 target) @@ -80,12 +83,12 @@ struct LinalgOpTiling : OpRewritePattern { mxnxkTile[2] = 1; // k-tile size adjusted based on the vnni layout for bf16 type - auto tensorShape = dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); + auto tensorShape = + dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) { mxnxkTile[2] = mxnxkTile[2] / tensorShape[3]; } - SmallVector swap_i = {0, 2, 1}; size_t i = 0; std::map> inductionVars; @@ -100,42 +103,46 @@ struct LinalgOpTiling : OpRewritePattern { dyn_cast(brgemmOp.getOperand(swap_i[i]).getType()) .getShape()[1]; - //Tile size should not be greater than the upperBound + // Tile size should not be greater than the upperBound if ((*itrShapeMNK) > upperBound) - return failure(); + return failure(); Location loc = brgemmOp.getLoc(); Value zeroCst = rewriter.create(loc, 0); - Value ubCstTiledLoop = rewriter.create(loc, upperBound); + Value ubCstTiledLoop = + rewriter.create(loc, upperBound); - Value stepCstTiledLoop = rewriter.create(loc, *itrShapeMNK); + Value stepCstTiledLoop = + rewriter.create(loc, *itrShapeMNK); // Creates M, N, and K tile loops - scf::ForOp loopOp = rewriter.create(brgemmOp.getLoc(), - zeroCst, ubCstTiledLoop, stepCstTiledLoop); + scf::ForOp loopOp = rewriter.create( + brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, 
stepCstTiledLoop); rewriter.setInsertionPointToStart(loopOp.getBody()); innermostForLoop = loopOp; - // Stores the induction variable with respect to the operands mapping it's subview. + // Stores the induction variable with respect to the operands mapping it's + // subview. if (i == 0) { - inductionVars[0][1] = loopOp.getInductionVar(); - inductionVars[2][0] = loopOp.getInductionVar(); - } else if(i == 1) { - inductionVars[1][2] = loopOp.getInductionVar(); - inductionVars[2][1] = loopOp.getInductionVar(); - //Creates reduction loop after the N loop + inductionVars[0][1] = loopOp.getInductionVar(); + inductionVars[2][0] = loopOp.getInductionVar(); + } else if (i == 1) { + inductionVars[1][2] = loopOp.getInductionVar(); + inductionVars[2][1] = loopOp.getInductionVar(); + // Creates reduction loop after the N loop Value ubCstReduction = rewriter.create( loc, dyn_cast(brgemmOp.getOperand(0).getType()) .getShape()[0]); - Value stepCstReduction = rewriter.create(loc, 1); + Value stepCstReduction = + rewriter.create(loc, 1); scf::ForOp redloopOp = rewriter.create( brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction); rewriter.setInsertionPointToStart(redloopOp.getBody()); inductionVars[0][0] = redloopOp.getInductionVar(); inductionVars[1][0] = redloopOp.getInductionVar(); - } else if(i == 2) { - inductionVars[0][2] = loopOp.getInductionVar(); - inductionVars[1][1] = loopOp.getInductionVar(); + } else if (i == 2) { + inductionVars[0][2] = loopOp.getInductionVar(); + inductionVars[1][1] = loopOp.getInductionVar(); } } @@ -162,13 +169,14 @@ struct LinalgOpTiling : OpRewritePattern { auto tensorShape = dyn_cast(input.getType()).getShape(); auto tileItr = tileshapes[i].begin(); - // Iterates over the shape of each tensor and update its offsets, indices, shapes, strides with respect to tile sizes + // Iterates over the shape of each tensor and update its offsets, indices, + // shapes, strides with respect to tile sizes for (size_t j = 0; j < tensorShape.size(); j++) { if (j == 0 && (i < 2)) { // Updates the batch dimension - offsets.push_back(inductionVars[i][j]); - indices.push_back(1); - shape.push_back(rewriter.getIndexAttr(1)); - strides.push_back(rewriter.getIndexAttr(1)); + offsets.push_back(inductionVars[i][j]); + indices.push_back(1); + shape.push_back(rewriter.getIndexAttr(1)); + strides.push_back(rewriter.getIndexAttr(1)); } else if (j < 3) { // Updates the M, N, and K dimensions offsets.push_back(inductionVars[i][j]); indices.push_back((*tileItr)); @@ -176,17 +184,15 @@ struct LinalgOpTiling : OpRewritePattern { strides.push_back(rewriter.getIndexAttr(1)); tileItr++; } else { // Just copies the vnni layout dimensions - offsets.push_back(rewriter.getIndexAttr(0)); - indices.push_back(tensorShape[j]); - shape.push_back(rewriter.getIndexAttr(tensorShape[j])); - strides.push_back(rewriter.getIndexAttr(1)); + offsets.push_back(rewriter.getIndexAttr(0)); + indices.push_back(tensorShape[j]); + shape.push_back(rewriter.getIndexAttr(tensorShape[j])); + strides.push_back(rewriter.getIndexAttr(1)); } - } auto subview = rewriter.create( - brgemmOp.getLoc(), MemRefType(), - input, offsets, shape, strides); + brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides); brgemmOp.setOperand(i, subview); } @@ -204,11 +210,14 @@ struct LinalgOpTiling : OpRewritePattern { }; void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns, - BrgemmLinalgTilingOptions options) { - patterns.add, LinalgOpTiling>(patterns.getContext(), options); + BrgemmLinalgTilingOptions options) { + 
patterns.add, + LinalgOpTiling>( + patterns.getContext(), options); } -struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase { +struct BrgemmLinalgTiling + : public tpp::impl::BrgemmLinalgTilingBase { using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase; diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 2473bbf79..5d647a842 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -1,4 +1,5 @@ -// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=8,32" --split-input-file | FileCheck %s +// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=8,32" --split-input-file | FileCheck -check-prefix=CONF1 %s +// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" --canonicalize --split-input-file | FileCheck -check-prefix=CONF2 %s module { func.func @gemm_do_register_tiling(%arg0: memref<16x32x16x32xf32>, %arg1: memref<32x32x32x32xf32>, %arg2: memref<16x32x16x32xf32>) { @@ -12,34 +13,34 @@ module { } } -// CHECK-LABEL: func.func @gemm_do_register_tiling( -// CHECK-SAME: %[[VAL_0:.*]]: memref<16x32x16x32xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<32x32x32x32xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<16x32x16x32xf32>) { -// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 32 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 8 : index -// CHECK: %[[VAL_6:.*]] = arith.constant 16 : index -// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK: scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) in (16, 32) { -// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_8]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>> -// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] { -// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_4]] { -// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] { -// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] { -// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_13]], %[[VAL_16]]] [1, 8, 1] [1, 1, 1] : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>> -// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_14]]] [1, 1, 32] [1, 1, 1] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] [8, 32] [1, 1] : memref<16x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CHECK: linalg.batch_reduce_matmul ins(%[[VAL_17]], %[[VAL_18]] : memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_19]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: return -// CHECK: } +// CONF1-LABEL: func.func 
@gemm_do_register_tiling( +// CONF1-SAME: %[[VAL_0:.*]]: memref<16x32x16x32xf32>, +// CONF1-SAME: %[[VAL_1:.*]]: memref<32x32x32x32xf32>, +// CONF1-SAME: %[[VAL_2:.*]]: memref<16x32x16x32xf32>) { +// CONF1: %[[VAL_3:.*]] = arith.constant 1 : index +// CONF1: %[[VAL_4:.*]] = arith.constant 32 : index +// CONF1: %[[VAL_5:.*]] = arith.constant 8 : index +// CONF1: %[[VAL_6:.*]] = arith.constant 16 : index +// CONF1: %[[VAL_7:.*]] = arith.constant 0 : index +// CONF1: scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) in (16, 32) { +// CONF1: %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_8]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> +// CONF1: %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_12:.*]] = memref.subview %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>> +// CONF1: scf.for %[[VAL_13:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] { +// CONF1: scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_4]] { +// CONF1: scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] { +// CONF1: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] { +// CONF1: %[[VAL_17:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_13]], %[[VAL_16]]] [1, 8, 1] [1, 1, 1] : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>> +// CONF1: %[[VAL_18:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_14]]] [1, 1, 32] [1, 1, 1] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_19:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] [8, 32] [1, 1] : memref<16x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> +// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_17]], %[[VAL_18]] : memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_19]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: return +// CONF1: } // ----- @@ -72,68 +73,194 @@ module { } } -// CHECK-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64} -// CHECK-LABEL: func.func @chainned_gemm_do_register_tiling( -// CHECK-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> { -// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_2:.*]] = arith.constant 48 : index -// CHECK: %[[VAL_3:.*]] = arith.constant 8 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 32 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_7:.*]] = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32> -// CHECK: %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> -// CHECK: scf.forall (%[[VAL_9:.*]], %[[VAL_10:.*]]) in (8, 48) { -// CHECK: %[[VAL_11:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CHECK: linalg.fill 
ins(%[[VAL_6]] : f32) outs(%[[VAL_11]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) -// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { -// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { -// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { -// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_13]], %[[VAL_16]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_14]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_13]], %[[VAL_14]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CHECK: linalg.batch_reduce_matmul ins(%[[VAL_17]], %[[VAL_18]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_19]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: %[[VAL_20:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> -// CHECK: scf.forall (%[[VAL_21:.*]], %[[VAL_22:.*]]) in (8, 48) { -// CHECK: %[[VAL_23:.*]] = memref.subview %[[VAL_20]]{{\[}}%[[VAL_21]], %[[VAL_22]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CHECK: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_23]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) -// CHECK: %[[VAL_24:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_21]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { -// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { -// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { -// CHECK: %[[VAL_29:.*]] = memref.subview %[[VAL_24]]{{\[}}%[[VAL_27]], %[[VAL_25]], %[[VAL_28]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_30:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_27]], %[[VAL_28]], %[[VAL_26]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_31:.*]] = memref.subview %[[VAL_23]]{{\[}}%[[VAL_25]], %[[VAL_26]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CHECK: linalg.batch_reduce_matmul ins(%[[VAL_29]], %[[VAL_30]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_31]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: scf.forall (%[[VAL_32:.*]], %[[VAL_33:.*]]) in (8, 48) { -// CHECK: %[[VAL_34:.*]] = 
memref.subview %[[VAL_8]]{{\[}}%[[VAL_32]], %[[VAL_33]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CHECK: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_34]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) -// CHECK: %[[VAL_35:.*]] = memref.subview %[[VAL_20]]{{\[}}%[[VAL_32]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: scf.for %[[VAL_36:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CHECK: scf.for %[[VAL_37:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { -// CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { -// CHECK: scf.for %[[VAL_39:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { -// CHECK: %[[VAL_40:.*]] = memref.subview %[[VAL_35]]{{\[}}%[[VAL_38]], %[[VAL_36]], %[[VAL_39]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_41:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_38]], %[[VAL_39]], %[[VAL_37]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[VAL_42:.*]] = memref.subview %[[VAL_34]]{{\[}}%[[VAL_36]], %[[VAL_37]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CHECK: linalg.batch_reduce_matmul ins(%[[VAL_40]], %[[VAL_41]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_42]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: return %[[VAL_8]] : memref<8x48x32x32xf32> -// CHECK: } +// CONF1-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64} +// CONF1-LABEL: func.func @chainned_gemm_do_register_tiling( +// CONF1-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> { +// CONF1: %[[VAL_1:.*]] = arith.constant 1 : index +// CONF1: %[[VAL_2:.*]] = arith.constant 48 : index +// CONF1: %[[VAL_3:.*]] = arith.constant 8 : index +// CONF1: %[[VAL_4:.*]] = arith.constant 32 : index +// CONF1: %[[VAL_5:.*]] = arith.constant 0 : index +// CONF1: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 +// CONF1: %[[VAL_7:.*]] = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32> +// CONF1: %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> +// CONF1: scf.forall (%[[VAL_9:.*]], %[[VAL_10:.*]]) in (8, 48) { +// CONF1: %[[VAL_11:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> +// CONF1: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_11]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) +// CONF1: %[[VAL_12:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { +// CONF1: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { +// CONF1: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { +// CONF1: scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { +// CONF1: %[[VAL_17:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_15]], 
%[[VAL_13]], %[[VAL_16]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_18:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_14]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_19:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_13]], %[[VAL_14]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> +// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_17]], %[[VAL_18]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_19]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: %[[VAL_20:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> +// CONF1: scf.forall (%[[VAL_21:.*]], %[[VAL_22:.*]]) in (8, 48) { +// CONF1: %[[VAL_23:.*]] = memref.subview %[[VAL_20]]{{\[}}%[[VAL_21]], %[[VAL_22]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> +// CONF1: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_23]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) +// CONF1: %[[VAL_24:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_21]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: scf.for %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { +// CONF1: scf.for %[[VAL_26:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { +// CONF1: scf.for %[[VAL_27:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { +// CONF1: scf.for %[[VAL_28:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { +// CONF1: %[[VAL_29:.*]] = memref.subview %[[VAL_24]]{{\[}}%[[VAL_27]], %[[VAL_25]], %[[VAL_28]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_30:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_27]], %[[VAL_28]], %[[VAL_26]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_31:.*]] = memref.subview %[[VAL_23]]{{\[}}%[[VAL_25]], %[[VAL_26]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> +// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_29]], %[[VAL_30]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_31]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: scf.forall (%[[VAL_32:.*]], %[[VAL_33:.*]]) in (8, 48) { +// CONF1: %[[VAL_34:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_32]], %[[VAL_33]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> +// CONF1: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_34]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) +// CONF1: %[[VAL_35:.*]] = memref.subview %[[VAL_20]]{{\[}}%[[VAL_32]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: scf.for %[[VAL_36:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { +// CONF1: scf.for %[[VAL_37:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { +// CONF1: scf.for %[[VAL_38:.*]] = 
%[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { +// CONF1: scf.for %[[VAL_39:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { +// CONF1: %[[VAL_40:.*]] = memref.subview %[[VAL_35]]{{\[}}%[[VAL_38]], %[[VAL_36]], %[[VAL_39]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_41:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_38]], %[[VAL_39]], %[[VAL_37]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1: %[[VAL_42:.*]] = memref.subview %[[VAL_34]]{{\[}}%[[VAL_36]], %[[VAL_37]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> +// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_40]], %[[VAL_41]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_42]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: } +// CONF1: return %[[VAL_8]] : memref<8x48x32x32xf32> +// CONF1: } + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> +module { + memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} + func.func @gemm_32tiles_do_tiling(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> + %expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16> + scf.forall (%arg1, %arg2) in (8, 32) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> + linalg.fill ins(%cst : bf16) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>) + %subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_0, %0 : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<32x16x32x2xbf16>) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_1: bf16, %out: bf16): + %1 = arith.mulf %in, %in_1 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + } + return %alloc : memref<8x32x32x32xbf16> + } +} + +// CONF2: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> +// CONF2: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> +// CONF2: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> + +// CONF2-LABEL: memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} +// CONF2-LABEL: func.func @gemm_32tiles_do_tiling( +// CONF2-SAME: %[[VAL_0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { +// CONF2: %[[VAL_1:.*]] = arith.constant 1 : index +// CONF2: %[[VAL_2:.*]] = arith.constant 32 : index +// 
CONF2: %[[VAL_3:.*]] = arith.constant 0 : index +// CONF2: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : bf16 +// CONF2: %[[VAL_5:.*]] = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> +// CONF2: %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> +// CONF2: %[[VAL_7:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16> +// CONF2: scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) in (8, 32) { +// CONF2: %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> +// CONF2: linalg.fill ins(%[[VAL_4]] : bf16) outs(%[[VAL_10]] : memref<32x32xbf16, strided<[32, 1], offset: ?>>) +// CONF2: %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_8]], 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> +// CONF2: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_1]] { +// CONF2: %[[VAL_13:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> +// CONF2: %[[VAL_14:.*]] = memref.subview %[[VAL_5]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<32x16x32x2xbf16> to memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> +// CONF2: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_13]], %[[VAL_14]] : memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%[[VAL_10]] : memref<32x32xbf16, strided<[32, 1], offset: ?>>) { +// CONF2: ^bb0(%[[VAL_15:.*]]: bf16, %[[VAL_16:.*]]: bf16, %[[VAL_17:.*]]: bf16): +// CONF2: %[[VAL_18:.*]] = arith.mulf %[[VAL_15]], %[[VAL_16]] : bf16 +// CONF2: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : bf16 +// CONF2: linalg.yield %[[VAL_19]] : bf16 +// CONF2: } +// CONF2: } +// CONF2: } +// CONF2: return %[[VAL_6]] : memref<8x32x32x32xbf16> +// CONF2: } + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> +module { + memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} + func.func @gemm_64tiles_do_tiling(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16> + %expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16> + scf.forall (%arg1, %arg2) in (4, 16) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>> + linalg.fill ins(%cst : bf16) outs(%subview : memref<64x64xbf16, strided<[64, 1], offset: ?>>) + %subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : 
memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_0, %0 : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, memref<16x32x64x2xbf16>) outs(%subview : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_1: bf16, %out: bf16): + %1 = arith.mulf %in, %in_1 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + } + return %alloc : memref<4x16x64x64xbf16> + } +} + +// CONF2: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> +// CONF2: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> +// CONF2: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> +// CONF2-LABEL: memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} +// CONF2-LABEL: func.func @gemm_64tiles_do_tiling( +// CONF2-SAME: %[[VAL_0:.*]]: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> { +// CONF2: %[[VAL_1:.*]] = arith.constant 1 : index +// CONF2: %[[VAL_2:.*]] = arith.constant 16 : index +// CONF2: %[[VAL_3:.*]] = arith.constant 32 : index +// CONF2: %[[VAL_4:.*]] = arith.constant 64 : index +// CONF2: %[[VAL_5:.*]] = arith.constant 0 : index +// CONF2: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 +// CONF2: %[[VAL_7:.*]] = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> +// CONF2: %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16> +// CONF2: %[[VAL_9:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16> +// CONF2: scf.forall (%[[VAL_10:.*]], %[[VAL_11:.*]]) in (4, 16) { +// CONF2: %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>> +// CONF2: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_12]] : memref<64x64xbf16, strided<[64, 1], offset: ?>>) +// CONF2: %[[VAL_13:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_10]], 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> +// CONF2: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { +// CONF2: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { +// CONF2: scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { +// CONF2: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_2]] { +// CONF2: %[[VAL_18:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_14]], %[[VAL_17]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> +// CONF2: %[[VAL_19:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_15]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>> +// CONF2: %[[VAL_20:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_15]]] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>> +// CONF2: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "reduction", "parallel", 
"parallel", "reduction"]} ins(%[[VAL_18]], %[[VAL_19]] : memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>) outs(%[[VAL_20]] : memref<32x32xbf16, strided<[64, 1], offset: ?>>) { +// CONF2: ^bb0(%[[VAL_21:.*]]: bf16, %[[VAL_22:.*]]: bf16, %[[VAL_23:.*]]: bf16): +// CONF2: %[[VAL_24:.*]] = arith.mulf %[[VAL_21]], %[[VAL_22]] : bf16 +// CONF2: %[[VAL_25:.*]] = arith.addf %[[VAL_23]], %[[VAL_24]] : bf16 +// CONF2: linalg.yield %[[VAL_25]] : bf16 +// CONF2: } +// CONF2: } +// CONF2: } +// CONF2: } +// CONF2: } +// CONF2: } +// CONF2: return %[[VAL_8]] : memref<4x16x64x64xbf16> +// CONF2: } From 76534344eee6fd3aba1656ba729df6c517a3cf7f Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Sun, 2 Feb 2025 19:33:51 -0800 Subject: [PATCH 03/15] code re-factoring - adding comments, space etc.,. --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index fcdc645f6..cf50dd4b6 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -89,20 +89,18 @@ struct LinalgOpTiling : OpRewritePattern { mxnxkTile[2] = mxnxkTile[2] / tensorShape[3]; } - SmallVector swap_i = {0, 2, 1}; size_t i = 0; + SmallVector swap_i = {0, 2, 1}; std::map> inductionVars; // For M, N, and K loops scf::ForOp innermostForLoop; - // Creating the tiled loops for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end(); itrShapeMNK++, i++) { auto upperBound = dyn_cast(brgemmOp.getOperand(swap_i[i]).getType()) .getShape()[1]; - // Tile size should not be greater than the upperBound if ((*itrShapeMNK) > upperBound) return failure(); @@ -111,7 +109,6 @@ struct LinalgOpTiling : OpRewritePattern { Value zeroCst = rewriter.create(loc, 0); Value ubCstTiledLoop = rewriter.create(loc, upperBound); - Value stepCstTiledLoop = rewriter.create(loc, *itrShapeMNK); // Creates M, N, and K tile loops @@ -122,10 +119,10 @@ struct LinalgOpTiling : OpRewritePattern { // Stores the induction variable with respect to the operands mapping it's // subview. 
- if (i == 0) { + if (i == 0) { // Stores iv for M loop inductionVars[0][1] = loopOp.getInductionVar(); inductionVars[2][0] = loopOp.getInductionVar(); - } else if (i == 1) { + } else if (i == 1) { //stores iv for N loop, creates batch loop, and maps iv of batch loop inductionVars[1][2] = loopOp.getInductionVar(); inductionVars[2][1] = loopOp.getInductionVar(); // Creates reduction loop after the N loop @@ -140,7 +137,7 @@ struct LinalgOpTiling : OpRewritePattern { inductionVars[0][0] = redloopOp.getInductionVar(); inductionVars[1][0] = redloopOp.getInductionVar(); - } else if (i == 2) { + } else if (i == 2) { // stores iv for k-loop inductionVars[0][2] = loopOp.getInductionVar(); inductionVars[1][1] = loopOp.getInductionVar(); } From 09d9ae7b3344ba744992d29e0a2128630932caa0 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Mon, 3 Feb 2025 06:53:54 -0800 Subject: [PATCH 04/15] Added a integration test-case for bf16 type --- .../tile-brgemm-linalg-matmul-bf16.mlir | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 test/Integration/tile-brgemm-linalg-matmul-bf16.mlir diff --git a/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir b/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir new file mode 100644 index 000000000..7221f7cb9 --- /dev/null +++ b/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir @@ -0,0 +1,30 @@ +// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1 +// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2 +// RUN: diff %t.1 %t.2 +// RUN: rm %t.1 %t.2 + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> +module { + memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} + func.func @register_tile_bf16(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> + %expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16> + scf.forall (%arg1, %arg2) in (8, 32) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> + linalg.fill ins(%cst : bf16) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>) + %subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_0, %0 : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<32x16x32x2xbf16>) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_1: bf16, %out: bf16): + %1 = arith.mulf %in, %in_1 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + } + return %alloc : memref<8x32x32x32xbf16> + } +} + From ef08d660ad7b6be43e94fd3a0feb33330a896099 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 12 Feb 2025 
22:14:11 -0800 Subject: [PATCH 05/15] Added extra validation, made m,n,k tile size as mandatory --- benchmarks/config/base/base.json | 20 +++--- .../omp/mlir-fp32-vector-to-kernel.json | 64 +++++++++---------- .../omp/torch-dynamo-vector-to-kernel.json | 16 ++--- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 64 +++++++++++-------- .../tile-brgemm-linalg-matmul.mlir | 2 +- .../pass-tile-brgemm-linalg-matmul.mlir | 2 +- 6 files changed, 90 insertions(+), 78 deletions(-) diff --git a/benchmarks/config/base/base.json b/benchmarks/config/base/base.json index f8495339d..39cfa96b3 100644 --- a/benchmarks/config/base/base.json +++ b/benchmarks/config/base/base.json @@ -40,21 +40,21 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": ["avx512.*"] }, "gemm_fp32_mlir_vector_avx2": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": ["avx2"] }, "gemm_fp32_mlir_vector_sve": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32 '" ], + "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32,1 '" ], "extensions": ["asimd"] }, "gemm_bf16_dp2_mlir": { @@ -82,21 +82,21 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": ["avx512.*"] }, "mlp_fp32_mlir_vector_avx2": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": ["avx2" ] }, "mlp_fp32_mlir_vector_sve": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32,1 '" ], "extensions": ["asimd"] }, "mlp_bf16_dp2_mlir": { @@ -127,7 +127,7 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ 
"avx512.*" ] }, "fp32_3x1024_args_mlir": { @@ -141,7 +141,7 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=args --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "bf16_3x1024_const_mlir": { @@ -172,7 +172,7 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ], "environment": {}, - "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_args_mlir": { @@ -186,7 +186,7 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=args --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ], "environment": {}, - "flags": [ "-n", "100", "-run-args=' --def-parallel --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args=' --def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "bf16_3x1024_const_mlir": { diff --git a/benchmarks/config/omp/mlir-fp32-vector-to-kernel.json b/benchmarks/config/omp/mlir-fp32-vector-to-kernel.json index a5f47a4bd..bf6c3a9cf 100644 --- a/benchmarks/config/omp/mlir-fp32-vector-to-kernel.json +++ b/benchmarks/config/omp/mlir-fp32-vector-to-kernel.json @@ -5,28 +5,28 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ 
"-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] } }}, @@ -36,28 +36,28 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ "avx512.*" ] } }}, @@ -67,28 +67,28 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,4 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", 
"-run-args='--def-parallel --parallel-task-grid=4,4 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,4 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,4 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=1,4 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=1,4 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] } }}, @@ -98,28 +98,28 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,4 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,4 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,4 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,4 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=64,64,64" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=1,4 --vector-to-kernels --registerBlocking=4,64 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=1,4 --vector-to-kernels --registerBlocking=4,64,1 '" ], "extensions": [ "avx512.*" ] } }}, @@ -129,28 +129,28 
@@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] } }}, @@ -160,28 +160,28 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "8", 
"KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,16 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,16,1 '" ], "extensions": [ "avx2" ] } }}, @@ -191,28 +191,28 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,32,1 
-aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] } }}, @@ -222,28 +222,28 @@ "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] }, "fp32_3x1024_omp_4_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] }, "fp32_3x1024_omp_8_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] }, "fp32_3x1024_omp_16_mlir": { "type": "IR-GEN", "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ], "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,32 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ], "extensions": [ "asimd" ] } }} diff --git a/benchmarks/config/omp/torch-dynamo-vector-to-kernel.json b/benchmarks/config/omp/torch-dynamo-vector-to-kernel.json index f1d7b5ebf..da891b048 100644 --- a/benchmarks/config/omp/torch-dynamo-vector-to-kernel.json +++ b/benchmarks/config/omp/torch-dynamo-vector-to-kernel.json @@ -5,28 +5,28 @@ "type": "MLIR", "benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", 
"-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] }, "fp32_3x1024_omp_4_mlir": { "type": "MLIR", "benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] }, "fp32_3x1024_omp_8_mlir": { "type": "MLIR", "benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] }, "fp32_3x1024_omp_16_mlir": { "type": "MLIR", "benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] } }}, @@ -36,28 +36,28 @@ "type": "MLIR", "benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] }, "fp32_3x1024_omp_4_mlir": { "type": "MLIR", "benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] }, "fp32_3x1024_omp_8_mlir": { "type": "MLIR", "benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] }, "fp32_3x1024_omp_16_mlir": { "type": "MLIR", "benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir", "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" }, - "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32 '" ], + "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=8,32,1 '" ], "extensions": [ ] } }} diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp 
b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index cf50dd4b6..ac8e3bf16 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -1,5 +1,4 @@ -//===- BrgemmLinalgTiling.cpp -----------------------------------------*- -//C++-*-===// +//===- BrgemmLinalgTiling.cpp -----------------------------------------*-C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -28,6 +27,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" +#include "TPP/Transforms/Utils/VNNIUtils.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "brgemm-linalg-tiling" @@ -58,9 +58,9 @@ struct LinalgOpTiling : OpRewritePattern { return failure(); // Check whether the tile sizes are valid - if (options.registerTileShape.size() != 3 && - options.registerTileShape.size() != 2) - return failure(); + if (options.registerTileShape.size() != 3) + return rewriter.notifyMatchFailure(brgemmOp, + "Invalid user input tile sizes. Should be "); // Check the whether the operation is brmatmul fp32 or bf16 type using // reduction count @@ -69,23 +69,41 @@ struct LinalgOpTiling : OpRewritePattern { int reductionCount = std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(), utils::IteratorType::reduction); - if (reductionCount != 2 && reductionCount != 3) - return failure(); - // Get the register blocking tile shape from the user input - SmallVector mxnxkTile(3); - for (size_t i = 0; i < options.registerTileShape.size(); i++) { - mxnxkTile[i] = options.registerTileShape[i]; - } + if (reductionCount == 0) + return rewriter.notifyMatchFailure( + brgemmOp, "Batch matmul operation not supported yet"); + + if (reductionCount == 1) + return rewriter.notifyMatchFailure( + brgemmOp, "Matmul operation not supported yet"); + + if (reductionCount > 3) + return rewriter.notifyMatchFailure( + brgemmOp, "The operation is not a gemm"); - // Set the K tile to 1, if the user not provided (it is fp32 target) - if (options.registerTileShape.size() == 2) - mxnxkTile[2] = 1; + auto tensorShapeLhs = + dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); + auto tensorShapeRhs = + dyn_cast(brgemmOp.getOperand(1).getType()).getShape(); + + if (reductionCount == 2 && + (tensorShapeLhs.size() != 3 || tensorShapeRhs.size() != 3)) + return rewriter.notifyMatchFailure( + brgemmOp, "Invalid rank for batch reduce operation"); + + if (reductionCount == 3 && !vnni::utils::isInVnniLayout(brgemmOp)) + return rewriter.notifyMatchFailure( + brgemmOp, "Failed matching for batch reduce operation with vnni layout"); + + // Get the register blocking tile shape from the user input + SmallVector mxnxkTile(options.registerTileShape.begin(), + options.registerTileShape.end()); // k-tile size adjusted based on the vnni layout for bf16 type - auto tensorShape = + if (vnni::utils::isInVnniLayout(brgemmOp)) { + auto tensorShape = dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); - if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) { mxnxkTile[2] = mxnxkTile[2] / tensorShape[3]; } @@ -144,15 +162,9 @@ struct LinalgOpTiling : OpRewritePattern { } // DS to assist while creating new subviews with correct indices and shapes - SmallVector mxkTile(2); - SmallVector kxnTile(2); - SmallVector mxnTile(2); - mxkTile[0] = mxnxkTile[0]; - mxkTile[1] = mxnxkTile[2]; - kxnTile[0] = mxnxkTile[2]; - kxnTile[1] = mxnxkTile[1]; - 
mxnTile[0] = mxnxkTile[0]; - mxnTile[1] = mxnxkTile[1]; + SmallVector mxkTile{mxnxkTile[0], mxnxkTile[2]}; + SmallVector kxnTile{mxnxkTile[2], mxnxkTile[1]}; + SmallVector mxnTile{mxnxkTile[0], mxnxkTile[1]}; SmallVector> tileshapes{mxkTile, kxnTile, mxnTile}; // Creating subviews diff --git a/test/Integration/tile-brgemm-linalg-matmul.mlir b/test/Integration/tile-brgemm-linalg-matmul.mlir index e859543d3..3e8373723 100644 --- a/test/Integration/tile-brgemm-linalg-matmul.mlir +++ b/test/Integration/tile-brgemm-linalg-matmul.mlir @@ -1,5 +1,5 @@ // RUN: tpp-run -e entry --entry-point-result=void -print %s > %t.1 -// RUN: tpp-run -e entry --entry-point-result=void --vector-to-kernels --registerBlocking=8,32 %s -print > %t.2 +// RUN: tpp-run -e entry --entry-point-result=void --vector-to-kernels --registerBlocking=8,32,1 %s -print > %t.2 // RUN: diff %t.1 %t.2 // RUN: rm %t.1 %t.2 diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 5d647a842..5f9b23eba 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -1,4 +1,4 @@ -// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=8,32" --split-input-file | FileCheck -check-prefix=CONF1 %s +// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=8,32,1" --split-input-file | FileCheck -check-prefix=CONF1 %s // RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" --canonicalize --split-input-file | FileCheck -check-prefix=CONF2 %s module { From 974c3ebbabf97fae3f761cf8427763a8e7429359 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Thu, 13 Feb 2025 00:58:50 -0800 Subject: [PATCH 06/15] Code re-factoring, changes to test-cases --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 17 +-- .../tile-brgemm-linalg-matmul-bf16.mlir | 3 +- .../tile-brgemm-linalg-matmul.mlir | 2 +- .../pass-tile-brgemm-linalg-matmul.mlir | 140 +++++------------- 4 files changed, 49 insertions(+), 113 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index ac8e3bf16..f77035146 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -72,11 +72,11 @@ struct LinalgOpTiling : OpRewritePattern { if (reductionCount == 0) return rewriter.notifyMatchFailure( - brgemmOp, "Batch matmul operation not supported yet"); + brgemmOp, "Matmul operation not supported yet"); if (reductionCount == 1) return rewriter.notifyMatchFailure( - brgemmOp, "Matmul operation not supported yet"); + brgemmOp, "Batch matmul operation not supported yet"); if (reductionCount > 3) return rewriter.notifyMatchFailure( @@ -107,28 +107,27 @@ struct LinalgOpTiling : OpRewritePattern { mxnxkTile[2] = mxnxkTile[2] / tensorShape[3]; } - size_t i = 0; SmallVector swap_i = {0, 2, 1}; std::map> inductionVars; // For M, N, and K loops scf::ForOp innermostForLoop; // Creating the tiled loops - for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end(); - itrShapeMNK++, i++) { + for (auto [i, itrShapeMNK] : llvm::enumerate(mxnxkTile)) { auto upperBound = dyn_cast(brgemmOp.getOperand(swap_i[i]).getType()) .getShape()[1]; // Tile size should not be greater than the upperBound - if ((*itrShapeMNK) > upperBound) - return failure(); + if ((itrShapeMNK) > upperBound) + return rewriter.notifyMatchFailure( + brgemmOp, "Tile size is greater than the dimension"); Location loc = brgemmOp.getLoc(); Value zeroCst = rewriter.create(loc, 0); Value ubCstTiledLoop = 
rewriter.create(loc, upperBound); Value stepCstTiledLoop = - rewriter.create(loc, *itrShapeMNK); + rewriter.create(loc, itrShapeMNK); // Creates M, N, and K tile loops scf::ForOp loopOp = rewriter.create( brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop); @@ -201,7 +200,7 @@ struct LinalgOpTiling : OpRewritePattern { } auto subview = rewriter.create( - brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides); + brgemmOp.getLoc(), input, offsets, shape, strides); brgemmOp.setOperand(i, subview); } diff --git a/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir b/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir index 7221f7cb9..c8941fd88 100644 --- a/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir +++ b/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir @@ -1,7 +1,6 @@ // RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1 // RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2 -// RUN: diff %t.1 %t.2 -// RUN: rm %t.1 %t.2 +// RUN: fpcmp %t.1 %t.2 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> diff --git a/test/Integration/tile-brgemm-linalg-matmul.mlir b/test/Integration/tile-brgemm-linalg-matmul.mlir index 3e8373723..c96fd5921 100644 --- a/test/Integration/tile-brgemm-linalg-matmul.mlir +++ b/test/Integration/tile-brgemm-linalg-matmul.mlir @@ -1,6 +1,6 @@ // RUN: tpp-run -e entry --entry-point-result=void -print %s > %t.1 // RUN: tpp-run -e entry --entry-point-result=void --vector-to-kernels --registerBlocking=8,32,1 %s -print > %t.2 -// RUN: diff %t.1 %t.2 +// RUN: fpcmp %t.1 %t.2 // RUN: rm %t.1 %t.2 module { diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 5f9b23eba..69e4da8f1 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -1,5 +1,5 @@ // RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=8,32,1" --split-input-file | FileCheck -check-prefix=CONF1 %s -// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" --canonicalize --split-input-file | FileCheck -check-prefix=CONF2 %s +// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" --split-input-file | FileCheck -check-prefix=CONF2 %s module { func.func @gemm_do_register_tiling(%arg0: memref<16x32x16x32xf32>, %arg1: memref<32x32x32x32xf32>, %arg2: memref<16x32x16x32xf32>) { @@ -13,34 +13,18 @@ module { } } -// CONF1-LABEL: func.func @gemm_do_register_tiling( -// CONF1-SAME: %[[VAL_0:.*]]: memref<16x32x16x32xf32>, -// CONF1-SAME: %[[VAL_1:.*]]: memref<32x32x32x32xf32>, -// CONF1-SAME: %[[VAL_2:.*]]: memref<16x32x16x32xf32>) { -// CONF1: %[[VAL_3:.*]] = arith.constant 1 : index -// CONF1: %[[VAL_4:.*]] = arith.constant 32 : index -// CONF1: %[[VAL_5:.*]] = arith.constant 8 : index -// CONF1: %[[VAL_6:.*]] = arith.constant 16 : index -// CONF1: %[[VAL_7:.*]] = arith.constant 0 : index -// CONF1: scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) in (16, 32) { -// CONF1: %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_8]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> -// CONF1: %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: 
%[[VAL_12:.*]] = memref.subview %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>> -// CONF1: scf.for %[[VAL_13:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] { -// CONF1: scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_4]] { -// CONF1: scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] { -// CONF1: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] { -// CONF1: %[[VAL_17:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_13]], %[[VAL_16]]] [1, 8, 1] [1, 1, 1] : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>> -// CONF1: %[[VAL_18:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_14]]] [1, 1, 32] [1, 1, 1] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: %[[VAL_19:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] [8, 32] [1, 1] : memref<16x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_17]], %[[VAL_18]] : memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_19]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: return -// CONF1: } +// CONF1-LABEL: func.func @gemm_do_register_tiling +// CONF1: memref.subview +// CONF1-NEXT: memref.subview +// CONF1-NEXT: memref.subview +// CONF1-NEXT: scf.for +// CONF1-NEXT: scf.for +// CONF1-NEXT: scf.for +// CONF1-NEXT: scf.for +// CONF1-NEXT: memref.subview +// CONF1-NEXT: memref.subview +// CONF1-NEXT: memref.subview +// CONF1-NEXT: linalg.batch_reduce_matmul // ----- @@ -146,7 +130,7 @@ module { #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> module { memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @gemm_32tiles_do_tiling(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { + func.func @gemm_32tiles_do_tiling_bf16(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { %cst = arith.constant 0.000000e+00 : bf16 %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> @@ -166,37 +150,18 @@ module { } } -// CONF2: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> -// CONF2: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> -// CONF2: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> - -// CONF2-LABEL: memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} -// CONF2-LABEL: func.func @gemm_32tiles_do_tiling( -// CONF2-SAME: %[[VAL_0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { -// CONF2: %[[VAL_1:.*]] = arith.constant 1 : index -// CONF2: %[[VAL_2:.*]] = arith.constant 32 : index -// CONF2: %[[VAL_3:.*]] = arith.constant 0 : index -// CONF2: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : bf16 -// CONF2: %[[VAL_5:.*]] = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> -// CONF2: %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> -// CONF2: %[[VAL_7:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] 
output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16> -// CONF2: scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) in (8, 32) { -// CONF2: %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> -// CONF2: linalg.fill ins(%[[VAL_4]] : bf16) outs(%[[VAL_10]] : memref<32x32xbf16, strided<[32, 1], offset: ?>>) -// CONF2: %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_8]], 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> -// CONF2: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_1]] { -// CONF2: %[[VAL_13:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> -// CONF2: %[[VAL_14:.*]] = memref.subview %[[VAL_5]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<32x16x32x2xbf16> to memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -// CONF2: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_13]], %[[VAL_14]] : memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%[[VAL_10]] : memref<32x32xbf16, strided<[32, 1], offset: ?>>) { -// CONF2: ^bb0(%[[VAL_15:.*]]: bf16, %[[VAL_16:.*]]: bf16, %[[VAL_17:.*]]: bf16): -// CONF2: %[[VAL_18:.*]] = arith.mulf %[[VAL_15]], %[[VAL_16]] : bf16 -// CONF2: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : bf16 -// CONF2: linalg.yield %[[VAL_19]] : bf16 -// CONF2: } -// CONF2: } -// CONF2: } -// CONF2: return %[[VAL_6]] : memref<8x32x32x32xbf16> -// CONF2: } +// CONF2-LABEL: func.func @gemm_32tiles_do_tiling_bf16 +// CONF2: memref.subview +// CONF2-NEXT: linalg.fill +// CONF2-NEXT: memref.subview +// CONF2-NEXT: scf.for +// CONF2-NEXT: scf.for +// CONF2-NEXT: scf.for +// CONF2-NEXT: scf.for +// CONF2-NEXT: memref.subview +// CONF2-NEXT: memref.subview +// CONF2-NEXT: memref.subview +// CONF2-NEXT: linalg.generic // ----- @@ -205,7 +170,7 @@ module { #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> module { memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @gemm_64tiles_do_tiling(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> { + func.func @gemm_64tiles_do_tiling_bf16(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> { %cst = arith.constant 0.000000e+00 : bf16 %0 = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16> @@ -225,42 +190,15 @@ module { } } -// CONF2: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> -// CONF2: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> -// CONF2: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> -// CONF2-LABEL: memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} -// CONF2-LABEL: func.func @gemm_64tiles_do_tiling( -// CONF2-SAME: %[[VAL_0:.*]]: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> { -// CONF2: %[[VAL_1:.*]] = arith.constant 1 : index 
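As a quick sanity check of the tile arithmetic behind both the detailed CONF2 checks being removed here and the structural ones that replace them further down: for this 64-tile bf16 test, registerBlocking=32,32,32 with a VNNI factor of 2 (the trailing memref dimension) gives a K tile of 32/2 = 16, so the tiled kernel contains four scf.for loops, with M and N stepping by 32 over 64, the batch dimension of 16 stepping by 1, and the K/2 dimension of 32 stepping by 16; these bounds are read off the memref shapes in the test rather than asserted by the simplified checks.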
-// CONF2: %[[VAL_2:.*]] = arith.constant 16 : index -// CONF2: %[[VAL_3:.*]] = arith.constant 32 : index -// CONF2: %[[VAL_4:.*]] = arith.constant 64 : index -// CONF2: %[[VAL_5:.*]] = arith.constant 0 : index -// CONF2: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 -// CONF2: %[[VAL_7:.*]] = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> -// CONF2: %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16> -// CONF2: %[[VAL_9:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16> -// CONF2: scf.forall (%[[VAL_10:.*]], %[[VAL_11:.*]]) in (4, 16) { -// CONF2: %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>> -// CONF2: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_12]] : memref<64x64xbf16, strided<[64, 1], offset: ?>>) -// CONF2: %[[VAL_13:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_10]], 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> -// CONF2: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CONF2: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CONF2: scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { -// CONF2: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_2]] { -// CONF2: %[[VAL_18:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_14]], %[[VAL_17]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> -// CONF2: %[[VAL_19:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_15]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>> -// CONF2: %[[VAL_20:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_15]]] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>> -// CONF2: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_18]], %[[VAL_19]] : memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>) outs(%[[VAL_20]] : memref<32x32xbf16, strided<[64, 1], offset: ?>>) { -// CONF2: ^bb0(%[[VAL_21:.*]]: bf16, %[[VAL_22:.*]]: bf16, %[[VAL_23:.*]]: bf16): -// CONF2: %[[VAL_24:.*]] = arith.mulf %[[VAL_21]], %[[VAL_22]] : bf16 -// CONF2: %[[VAL_25:.*]] = arith.addf %[[VAL_23]], %[[VAL_24]] : bf16 -// CONF2: linalg.yield %[[VAL_25]] : bf16 -// CONF2: } -// CONF2: } -// CONF2: } -// CONF2: } -// CONF2: } -// CONF2: } -// CONF2: return %[[VAL_8]] : memref<4x16x64x64xbf16> -// CONF2: } +// CONF2-LABEL: func.func @gemm_64tiles_do_tiling_bf16 +// CONF2: memref.subview +// CONF2-NEXT: linalg.fill +// CONF2-NEXT: memref.subview +// CONF2-NEXT: scf.for +// CONF2-NEXT: scf.for +// CONF2-NEXT: scf.for +// CONF2-NEXT: scf.for +// CONF2-NEXT: memref.subview +// CONF2-NEXT: memref.subview +// CONF2-NEXT: memref.subview +// CONF2-NEXT: linalg.generic From 5d5908c9b2477fa82b331cf4355530ad82b42ad7 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Mon, 17 Feb 2025 23:53:20 -0800 Subject: [PATCH 07/15] 
Used upstream code for f32 type and updated unit tests --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 261 ++++++++++-------- .../pass-tile-brgemm-linalg-matmul.mlir | 203 ++++++-------- 2 files changed, 231 insertions(+), 233 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index f77035146..c7107bcab 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -1,4 +1,5 @@ -//===- BrgemmLinalgTiling.cpp -----------------------------------------*-C++-*-===// +//===- BrgemmLinalgTiling.cpp +//-----------------------------------------*-C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,14 +11,17 @@ // //===----------------------------------------------------------------------===// #include "TPP/Transforms/Transforms.h" +#include "TPP/Transforms/Utils/VNNIUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -27,7 +31,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" -#include "TPP/Transforms/Utils/VNNIUtils.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "brgemm-linalg-tiling" @@ -58,9 +61,9 @@ struct LinalgOpTiling : OpRewritePattern { return failure(); // Check whether the tile sizes are valid - if (options.registerTileShape.size() != 3) - return rewriter.notifyMatchFailure(brgemmOp, - "Invalid user input tile sizes. Should be "); + if (options.registerTileShape.size() != 3) + return rewriter.notifyMatchFailure( + brgemmOp, "Invalid user input tile sizes. 
Should be "); // Check the whether the operation is brmatmul fp32 or bf16 type using // reduction count @@ -71,16 +74,16 @@ struct LinalgOpTiling : OpRewritePattern { utils::IteratorType::reduction); if (reductionCount == 0) - return rewriter.notifyMatchFailure( - brgemmOp, "Matmul operation not supported yet"); + return rewriter.notifyMatchFailure(brgemmOp, + "Matmul operation not supported yet"); if (reductionCount == 1) return rewriter.notifyMatchFailure( brgemmOp, "Batch matmul operation not supported yet"); if (reductionCount > 3) - return rewriter.notifyMatchFailure( - brgemmOp, "The operation is not a gemm"); + return rewriter.notifyMatchFailure(brgemmOp, + "The operation is not a gemm"); auto tensorShapeLhs = dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); @@ -92,124 +95,154 @@ struct LinalgOpTiling : OpRewritePattern { return rewriter.notifyMatchFailure( brgemmOp, "Invalid rank for batch reduce operation"); - if (reductionCount == 3 && !vnni::utils::isInVnniLayout(brgemmOp)) + auto vnniOpt = vnni::utils::isInVnniLayout(brgemmOp); + if (reductionCount == 3 && !vnniOpt) return rewriter.notifyMatchFailure( - brgemmOp, "Failed matching for batch reduce operation with vnni layout"); + brgemmOp, + "Failed matching for batch reduce operation with vnni layout"); // Get the register blocking tile shape from the user input SmallVector mxnxkTile(options.registerTileShape.begin(), - options.registerTileShape.end()); + options.registerTileShape.end()); - // k-tile size adjusted based on the vnni layout for bf16 type - if (vnni::utils::isInVnniLayout(brgemmOp)) { + // We do manual tiling for bf16type with vnni layout. It seems the + // upstream tiling interface is broken for vnni layouts. + if (vnniOpt) { + // k-tile size adjusted based on the vnni layout for bf16 type auto tensorShape = - dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); - mxnxkTile[2] = mxnxkTile[2] / tensorShape[3]; - } + dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); + auto kTileVnni = mxnxkTile[2] / tensorShape[3]; + + if (kTileVnni > 0) { + mxnxkTile[2] = kTileVnni; + } else { + return rewriter.notifyMatchFailure( + brgemmOp, "Failed matching K tile size for batch reduce operation " + "with vnni layout. K tile size should be >= vnni layout"); + } - SmallVector swap_i = {0, 2, 1}; - std::map> inductionVars; - - // For M, N, and K loops - scf::ForOp innermostForLoop; - // Creating the tiled loops - for (auto [i, itrShapeMNK] : llvm::enumerate(mxnxkTile)) { - auto upperBound = - dyn_cast(brgemmOp.getOperand(swap_i[i]).getType()) - .getShape()[1]; - // Tile size should not be greater than the upperBound - if ((itrShapeMNK) > upperBound) - return rewriter.notifyMatchFailure( - brgemmOp, "Tile size is greater than the dimension"); - - Location loc = brgemmOp.getLoc(); - Value zeroCst = rewriter.create(loc, 0); - Value ubCstTiledLoop = - rewriter.create(loc, upperBound); - Value stepCstTiledLoop = - rewriter.create(loc, itrShapeMNK); - // Creates M, N, and K tile loops - scf::ForOp loopOp = rewriter.create( - brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop); - rewriter.setInsertionPointToStart(loopOp.getBody()); - innermostForLoop = loopOp; - - // Stores the induction variable with respect to the operands mapping it's - // subview. 
- if (i == 0) { // Stores iv for M loop - inductionVars[0][1] = loopOp.getInductionVar(); - inductionVars[2][0] = loopOp.getInductionVar(); - } else if (i == 1) { //stores iv for N loop, creates batch loop, and maps iv of batch loop - inductionVars[1][2] = loopOp.getInductionVar(); - inductionVars[2][1] = loopOp.getInductionVar(); - // Creates reduction loop after the N loop - Value ubCstReduction = rewriter.create( - loc, dyn_cast(brgemmOp.getOperand(0).getType()) - .getShape()[0]); - Value stepCstReduction = - rewriter.create(loc, 1); - scf::ForOp redloopOp = rewriter.create( - brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction); - rewriter.setInsertionPointToStart(redloopOp.getBody()); - inductionVars[0][0] = redloopOp.getInductionVar(); - inductionVars[1][0] = redloopOp.getInductionVar(); - - } else if (i == 2) { // stores iv for k-loop - inductionVars[0][2] = loopOp.getInductionVar(); - inductionVars[1][1] = loopOp.getInductionVar(); + SmallVector swap_i = {0, 2, 1}; + std::map> inductionVars; + // For M, N, and K loops + scf::ForOp innermostForLoop; + // Creating the tiled loops + for (auto [i, itrShapeMNK] : llvm::enumerate(mxnxkTile)) { + auto upperBound = + dyn_cast(brgemmOp.getOperand(swap_i[i]).getType()) + .getShape()[1]; + // Tile size should not be greater than the upperBound + if ((itrShapeMNK) > upperBound) + return rewriter.notifyMatchFailure( + brgemmOp, "Tile size is greater than the dimension"); + + Location loc = brgemmOp.getLoc(); + Value zeroCst = rewriter.create(loc, 0); + Value ubCstTiledLoop = + rewriter.create(loc, upperBound); + Value stepCstTiledLoop = + rewriter.create(loc, itrShapeMNK); + // Creates M, N, and K tile loops + scf::ForOp loopOp = rewriter.create( + brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop); + rewriter.setInsertionPointToStart(loopOp.getBody()); + innermostForLoop = loopOp; + + // Stores the induction variable with respect to the operands mapping + // it's subview. 
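To make the induction-variable bookkeeping below easier to follow, here is a sketch, written as comments and reflecting a reading of this hunk rather than anything added by the patch, of what inductionVars[operand][dim] ends up holding on the VNNI path, assuming the batch-reduce operand layouts A: batch x M x K/2 x vnni, B: batch x K/2 x N x vnni, C: M x N:
// Sketch only (not part of the patch): mapping of loop induction
// variables onto each operand's subview offsets.
//   inductionVars[0][0] = batch iv   // A dim 0 (batch)
//   inductionVars[0][1] = M iv       // A dim 1 (M)
//   inductionVars[0][2] = K iv       // A dim 2 (K / vnni factor)
//   inductionVars[1][0] = batch iv   // B dim 0 (batch)
//   inductionVars[1][1] = K iv       // B dim 1 (K / vnni factor)
//   inductionVars[1][2] = N iv       // B dim 2 (N)
//   inductionVars[2][0] = M iv       // C dim 0 (M)
//   inductionVars[2][1] = N iv       // C dim 1 (N)
// The trailing VNNI dimension of A and B is never offset; the subview
// creation below passes offset 0 and the full size for it.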
+ if (i == 0) { // Stores iv for M loop + inductionVars[0][1] = loopOp.getInductionVar(); + inductionVars[2][0] = loopOp.getInductionVar(); + } else if (i == 1) { // stores iv for N loop, creates batch loop, and + // maps iv of batch loop + inductionVars[1][2] = loopOp.getInductionVar(); + inductionVars[2][1] = loopOp.getInductionVar(); + // Creates reduction loop after the N loop + Value ubCstReduction = rewriter.create( + loc, dyn_cast(brgemmOp.getOperand(0).getType()) + .getShape()[0]); + Value stepCstReduction = + rewriter.create(loc, 1); + scf::ForOp redloopOp = rewriter.create( + brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction); + rewriter.setInsertionPointToStart(redloopOp.getBody()); + inductionVars[0][0] = redloopOp.getInductionVar(); + inductionVars[1][0] = redloopOp.getInductionVar(); + + } else if (i == 2) { // stores iv for k-loop + inductionVars[0][2] = loopOp.getInductionVar(); + inductionVars[1][1] = loopOp.getInductionVar(); + } } - } - // DS to assist while creating new subviews with correct indices and shapes - SmallVector mxkTile{mxnxkTile[0], mxnxkTile[2]}; - SmallVector kxnTile{mxnxkTile[2], mxnxkTile[1]}; - SmallVector mxnTile{mxnxkTile[0], mxnxkTile[1]}; - - SmallVector> tileshapes{mxkTile, kxnTile, mxnTile}; - // Creating subviews - for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) { - SmallVector offsets; - SmallVector indices; - SmallVector shape; - SmallVector strides; - - auto input = brgemmOp.getOperand(i); - auto tensorShape = dyn_cast(input.getType()).getShape(); - auto tileItr = tileshapes[i].begin(); - - // Iterates over the shape of each tensor and update its offsets, indices, - // shapes, strides with respect to tile sizes - for (size_t j = 0; j < tensorShape.size(); j++) { - if (j == 0 && (i < 2)) { // Updates the batch dimension - offsets.push_back(inductionVars[i][j]); - indices.push_back(1); - shape.push_back(rewriter.getIndexAttr(1)); - strides.push_back(rewriter.getIndexAttr(1)); - } else if (j < 3) { // Updates the M, N, and K dimensions - offsets.push_back(inductionVars[i][j]); - indices.push_back((*tileItr)); - shape.push_back(rewriter.getIndexAttr(*tileItr)); - strides.push_back(rewriter.getIndexAttr(1)); - tileItr++; - } else { // Just copies the vnni layout dimensions - offsets.push_back(rewriter.getIndexAttr(0)); - indices.push_back(tensorShape[j]); - shape.push_back(rewriter.getIndexAttr(tensorShape[j])); - strides.push_back(rewriter.getIndexAttr(1)); + // DS to assist while creating new subviews with correct indices and + // shapes + SmallVector mxkTile{mxnxkTile[0], mxnxkTile[2]}; + SmallVector kxnTile{mxnxkTile[2], mxnxkTile[1]}; + SmallVector mxnTile{mxnxkTile[0], mxnxkTile[1]}; + + SmallVector> tileshapes{mxkTile, kxnTile, mxnTile}; + // Creating subviews + for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) { + SmallVector offsets; + SmallVector indices; + SmallVector shape; + SmallVector strides; + + auto input = brgemmOp.getOperand(i); + auto tensorShape = dyn_cast(input.getType()).getShape(); + auto tileItr = tileshapes[i].begin(); + + // Iterates over the shape of each tensor and update its offsets, + // indices, shapes, strides with respect to tile sizes + for (size_t j = 0; j < tensorShape.size(); j++) { + if (j == 0 && (i < 2)) { // Updates the batch dimension + offsets.push_back(inductionVars[i][j]); + indices.push_back(1); + shape.push_back(rewriter.getIndexAttr(1)); + strides.push_back(rewriter.getIndexAttr(1)); + } else if (j < 3) { // Updates the M, N, and K dimensions + 
offsets.push_back(inductionVars[i][j]); + indices.push_back((*tileItr)); + shape.push_back(rewriter.getIndexAttr(*tileItr)); + strides.push_back(rewriter.getIndexAttr(1)); + tileItr++; + } else { // Just copies the vnni layout dimensions + offsets.push_back(rewriter.getIndexAttr(0)); + indices.push_back(tensorShape[j]); + shape.push_back(rewriter.getIndexAttr(tensorShape[j])); + strides.push_back(rewriter.getIndexAttr(1)); + } } + + auto subview = rewriter.create( + brgemmOp.getLoc(), input, offsets, shape, strides); + brgemmOp.setOperand(i, subview); } - auto subview = rewriter.create( - brgemmOp.getLoc(), input, offsets, shape, strides); - brgemmOp.setOperand(i, subview); - } + rewriter.setInsertionPoint( + innermostForLoop.getBody(), + std::prev(innermostForLoop.getBody()->end(), 1)); + auto clone = rewriter.clone(*brgemmOp); + brgemmOp.replaceAllUsesWith(clone); - rewriter.setInsertionPoint(innermostForLoop.getBody(), - std::prev(innermostForLoop.getBody()->end(), 1)); - auto clone = rewriter.clone(*brgemmOp); - brgemmOp.replaceAllUsesWith(clone); - if (brgemmOp->use_empty()) - rewriter.eraseOp(brgemmOp); + if (brgemmOp->use_empty()) + rewriter.eraseOp(brgemmOp); + + } else { + // f32 gets tiled with the help of upstream tiling interface + linalg::LinalgTilingOptions options; + options.setTileSizes({1, mxnxkTile[0], mxnxkTile[1], mxnxkTile[2]}); + options.setLoopType(linalg::LinalgTilingLoopType::Loops); + options.setInterchange({1, 2, 0, 3}); + + FailureOr tiledOp = + linalg::tileLinalgOp(rewriter, brgemmOp, options); + + if (failed(tiledOp)) { + return failure(); + } + rewriter.replaceOp(brgemmOp, tiledOp->op->getResults()); + } return success(); } diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 69e4da8f1..78829e0fd 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -13,115 +13,25 @@ module { } } -// CONF1-LABEL: func.func @gemm_do_register_tiling -// CONF1: memref.subview -// CONF1-NEXT: memref.subview -// CONF1-NEXT: memref.subview -// CONF1-NEXT: scf.for -// CONF1-NEXT: scf.for -// CONF1-NEXT: scf.for -// CONF1-NEXT: scf.for -// CONF1-NEXT: memref.subview -// CONF1-NEXT: memref.subview -// CONF1-NEXT: memref.subview -// CONF1-NEXT: linalg.batch_reduce_matmul - -// ----- -module { - memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @chainned_gemm_do_register_tiling(%arg0: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> - scf.forall (%arg1, %arg2) in (8, 48) { - %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) - %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> - linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) - } - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> - scf.forall (%arg1, %arg2) in (8, 
48) { - %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) - %subview_1 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> - linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) - } - scf.forall (%arg1, %arg2) in (8, 48) { - %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) - %subview_1 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> - linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) - } - return %alloc : memref<8x48x32x32xf32> - } -} - -// CONF1-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64} -// CONF1-LABEL: func.func @chainned_gemm_do_register_tiling( -// CONF1-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> { -// CONF1: %[[VAL_1:.*]] = arith.constant 1 : index -// CONF1: %[[VAL_2:.*]] = arith.constant 48 : index -// CONF1: %[[VAL_3:.*]] = arith.constant 8 : index -// CONF1: %[[VAL_4:.*]] = arith.constant 32 : index -// CONF1: %[[VAL_5:.*]] = arith.constant 0 : index -// CONF1: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CONF1: %[[VAL_7:.*]] = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32> -// CONF1: %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> -// CONF1: scf.forall (%[[VAL_9:.*]], %[[VAL_10:.*]]) in (8, 48) { -// CONF1: %[[VAL_11:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CONF1: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_11]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) -// CONF1: %[[VAL_12:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CONF1: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { -// CONF1: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { -// CONF1: scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { -// CONF1: %[[VAL_17:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_13]], %[[VAL_16]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: %[[VAL_18:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_14]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: %[[VAL_19:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_13]], %[[VAL_14]]] [8, 32] [1, 1] : 
memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_17]], %[[VAL_18]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_19]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: %[[VAL_20:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32> -// CONF1: scf.forall (%[[VAL_21:.*]], %[[VAL_22:.*]]) in (8, 48) { -// CONF1: %[[VAL_23:.*]] = memref.subview %[[VAL_20]]{{\[}}%[[VAL_21]], %[[VAL_22]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CONF1: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_23]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) -// CONF1: %[[VAL_24:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_21]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: scf.for %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CONF1: scf.for %[[VAL_26:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { -// CONF1: scf.for %[[VAL_27:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { -// CONF1: scf.for %[[VAL_28:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { -// CONF1: %[[VAL_29:.*]] = memref.subview %[[VAL_24]]{{\[}}%[[VAL_27]], %[[VAL_25]], %[[VAL_28]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: %[[VAL_30:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_27]], %[[VAL_28]], %[[VAL_26]]] [1, 1, 32] [1, 1, 1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: %[[VAL_31:.*]] = memref.subview %[[VAL_23]]{{\[}}%[[VAL_25]], %[[VAL_26]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_29]], %[[VAL_30]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_31]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: scf.forall (%[[VAL_32:.*]], %[[VAL_33:.*]]) in (8, 48) { -// CONF1: %[[VAL_34:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_32]], %[[VAL_33]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CONF1: linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_34]] : memref<32x32xf32, strided<[32, 1], offset: ?>>) -// CONF1: %[[VAL_35:.*]] = memref.subview %[[VAL_20]]{{\[}}%[[VAL_32]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: scf.for %[[VAL_36:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] { -// CONF1: scf.for %[[VAL_37:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_4]] { -// CONF1: scf.for %[[VAL_38:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] { -// CONF1: scf.for %[[VAL_39:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_1]] { -// CONF1: %[[VAL_40:.*]] = memref.subview %[[VAL_35]]{{\[}}%[[VAL_38]], %[[VAL_36]], %[[VAL_39]]] [1, 8, 1] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: %[[VAL_41:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_38]], %[[VAL_39]], %[[VAL_37]]] [1, 1, 32] [1, 1, 
1] : memref<48x32x32xf32> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> -// CONF1: %[[VAL_42:.*]] = memref.subview %[[VAL_34]]{{\[}}%[[VAL_36]], %[[VAL_37]]] [8, 32] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> -// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_40]], %[[VAL_41]] : memref<1x8x1xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_42]] : memref<8x32xf32, strided<[32, 1], offset: ?>>) -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: } -// CONF1: return %[[VAL_8]] : memref<8x48x32x32xf32> -// CONF1: } +// CONF1-LABEL: func.func @gemm_do_register_tiling +// CONF1-DAG: %[[C1:.+]] = arith.constant 1 : index +// CONF1-DAG: %[[C32:.+]] = arith.constant 32 : index +// CONF1-DAG: %[[C8:.+]] = arith.constant 8 : index +// CONF1-DAG: %[[C16:.+]] = arith.constant 16 : index +// CONF1-DAG: %[[C0:.+]] = arith.constant 0 : index +// CONF1: scf.forall (%arg3, %arg4) in (16, 32) { +// CONF1-NEXT: %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> +// CONF1-NEXT: %subview_0 = memref.subview %arg1[%arg4, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1-NEXT: %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>> +// CONF1-NEXT: scf.for %[[I:.+]] = %[[C0]] to %[[C16]] step %[[C8]] { +// CONF1-NEXT: scf.for %[[J:.+]] = %[[C0]] to %[[C32]] step %[[C32]] { +// CONF1-NEXT: scf.for %[[K:.+]] = %[[C0]] to %[[C32]] step %[[C1]] { +// CONF1-NEXT: scf.for %[[L:.+]] = %[[C0]] to %[[C32]] step %[[C1]] { +// CONF1-NEXT: %subview_2 = memref.subview %subview[%[[K]], %[[I]], %[[L]]] [1, 8, 1] [1, 1, 1] : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>> +// CONF1-NEXT: %subview_3 = memref.subview %subview_0[%[[K]], %[[L]], %[[J]]] [1, 1, 32] [1, 1, 1] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>> +// CONF1-NEXT: %subview_4 = memref.subview %subview_1[%[[I]], %[[J]]] [8, 32] [1, 1] : memref<16x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>> +// CONF1-NEXT: linalg.batch_reduce_matmul ins(%subview_2, %subview_3 : memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%subview_4 : memref<8x32xf32, strided<[32, 1], offset: ?>>) // ----- @@ -191,14 +101,69 @@ module { } // CONF2-LABEL: func.func @gemm_64tiles_do_tiling_bf16 -// CONF2: memref.subview -// CONF2-NEXT: linalg.fill -// CONF2-NEXT: memref.subview -// CONF2-NEXT: scf.for -// CONF2-NEXT: scf.for -// CONF2-NEXT: scf.for -// CONF2-NEXT: scf.for -// CONF2-NEXT: memref.subview -// CONF2-NEXT: memref.subview -// CONF2-NEXT: memref.subview -// CONF2-NEXT: linalg.generic +// CONF2-DAG: %[[C1:.+]] = arith.constant 1 : index +// CONF2-DAG: %[[C32:.+]] = arith.constant 32 : index +// CONF2-DAG: %[[C64:.+]] = arith.constant 64 : index +// CONF2-DAG: %[[C16:.+]] = arith.constant 16 : index +// CONF2-DAG: %[[C0:.+]] = arith.constant 0 : index +// CONF2: %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>> +// 
CONF2-NEXT: linalg.fill ins(%cst : bf16) outs(%subview : memref<64x64xbf16, strided<[64, 1], offset: ?>>) +// CONF2-NEXT: %subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> +// CONF2-NEXT: scf.for %[[I:.+]] = %[[C0]] to %[[C64]] step %[[C32]] { +// CONF2-NEXT: scf.for %[[J:.+]] = %[[C0]] to %[[C64]] step %[[C32]] { +// CONF2-NEXT: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C1]] { +// CONF2-NEXT: scf.for %[[L:.+]] = %[[C0]] to %[[C32]] step %[[C16]] { +// CONF2-NEXT: %subview_1 = memref.subview %subview_0[%[[K]], %[[I]], %[[L]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> +// CONF2-NEXT: %subview_2 = memref.subview %0[%[[K]], %[[L]], %[[J]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>> +// CONF2-NEXT: %subview_3 = memref.subview %subview[%[[I]], %[[J]]] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>> +// CONF2-NEXT: linalg.generic + +// ----- + +module { + func.func @brgemm_tensor_type_no_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> { + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %0 : tensor<256x256xf32> + } +} + + +// CONF1-LABEL: func.func @brgemm_tensor_type_no_tiling +func.func @brgemm_tensor_type_no_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> { +// CONF1-NOT: scf.for +// CONF1-NOT: scf.for +// CONF1-NOT: scf.for +// CONF1-NOT: scf.for +// CONF1-NOT: memref.subview +// CONF1-NOT: memref.subview +// CONF1-NOT: memref.subview + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %0 : tensor<256x256xf32> +} + +// ----- + +module { + func.func @matmul_no_tiling(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) { + linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>) + outs(%arg2 : memref<64x64xf32>) + return + } +} + + +// CONF1-LABEL: func.func @matmul_no_tiling +func.func @matmul_no_tiling(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) { +// CONF1-NOT: scf.for +// CONF1-NOT: scf.for +// CONF1-NOT: scf.for +// CONF1-NOT: scf.for +// CONF1-NOT: memref.subview +// CONF1-NOT: memref.subview +// CONF1-NOT: memref.subview + linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>) + outs(%arg2 : memref<64x64xf32>) + return +} + + From ddde637f8f1c71dff6589c673c18ad8a388ea4a8 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Tue, 18 Feb 2025 03:02:18 -0800 Subject: [PATCH 08/15] Few re-factoring --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 18 ++++----- .../tile-brgemm-linalg-matmul-bf16.mlir | 2 +- .../pass-tile-brgemm-linalg-matmul.mlir | 38 ++++++++----------- 3 files changed, 25 insertions(+), 33 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index c7107bcab..4a9275992 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ 
b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -1,5 +1,4 @@ -//===- BrgemmLinalgTiling.cpp -//-----------------------------------------*-C++-*-===// +//===- BrgemmLinalgTiling.cpp--------------------------------------*-C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -113,13 +112,14 @@ struct LinalgOpTiling : OpRewritePattern { dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); auto kTileVnni = mxnxkTile[2] / tensorShape[3]; - if (kTileVnni > 0) { - mxnxkTile[2] = kTileVnni; - } else { - return rewriter.notifyMatchFailure( + // Note: We make an assumption that the k tile size is divisible to + // the powers of 2. + if (kTileVnni < 1) + return rewriter.notifyMatchFailure( brgemmOp, "Failed matching K tile size for batch reduce operation " "with vnni layout. K tile size should be >= vnni layout"); - } + + mxnxkTile[2] = kTileVnni; SmallVector swap_i = {0, 2, 1}; std::map> inductionVars; @@ -182,7 +182,7 @@ struct LinalgOpTiling : OpRewritePattern { SmallVector> tileshapes{mxkTile, kxnTile, mxnTile}; // Creating subviews - for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) { + for (size_t i = 0, opSize = brgemmOp.getNumOperands(); i < opSize; i++) { SmallVector offsets; SmallVector indices; SmallVector shape; @@ -194,7 +194,7 @@ struct LinalgOpTiling : OpRewritePattern { // Iterates over the shape of each tensor and update its offsets, // indices, shapes, strides with respect to tile sizes - for (size_t j = 0; j < tensorShape.size(); j++) { + for (size_t j = 0, tSize = tensorShape.size(); j < tSize; j++) { if (j == 0 && (i < 2)) { // Updates the batch dimension offsets.push_back(inductionVars[i][j]); indices.push_back(1); diff --git a/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir b/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir index c8941fd88..d14803ca7 100644 --- a/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir +++ b/test/Integration/tile-brgemm-linalg-matmul-bf16.mlir @@ -1,6 +1,6 @@ // RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1 // RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2 -// RUN: fpcmp %t.1 %t.2 +// RUN: fpcmp -r 0.001 %t.1 %t.2 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 78829e0fd..9bc70439b 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -61,17 +61,21 @@ module { } // CONF2-LABEL: func.func @gemm_32tiles_do_tiling_bf16 -// CONF2: memref.subview -// CONF2-NEXT: linalg.fill -// CONF2-NEXT: memref.subview -// CONF2-NEXT: scf.for -// CONF2-NEXT: scf.for -// CONF2-NEXT: scf.for -// CONF2-NEXT: scf.for -// CONF2-NEXT: memref.subview -// CONF2-NEXT: memref.subview -// CONF2-NEXT: memref.subview -// CONF2-NEXT: linalg.generic +// CONF2-DAG: %[[C1:.+]] = arith.constant 1 : index +// CONF2-DAG: %[[C32:.+]] = arith.constant 32 : index +// CONF2-DAG: %[[C16:.+]] = arith.constant 16 : index +// CONF2-DAG: %[[C0:.+]] = arith.constant 0 : index +// CONF2: %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> +// CONF2-NEXT: linalg.fill ins(%cst : bf16) 
outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>)
+// CONF2-NEXT: %subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
+// CONF2-NEXT: scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C32]] {
+// CONF2-NEXT: scf.for %[[J:.+]] = %[[C0]] to %[[C32]] step %[[C32]] {
+// CONF2-NEXT: scf.for %[[K:.+]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CONF2-NEXT: scf.for %[[L:.+]] = %[[C0]] to %[[C16]] step %[[C16]] {
+// CONF2-NEXT: %subview_1 = memref.subview %subview_0[%[[K]], %[[I]], %[[L]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
+// CONF2-NEXT: %subview_2 = memref.subview %0[%[[K]], %[[L]], %[[J]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<32x16x32x2xbf16> to memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
+// CONF2-NEXT: %subview_3 = memref.subview %subview[%[[I]], %[[J]]] [32, 32] [1, 1] : memref<32x32xbf16, strided<[32, 1], offset: ?>> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+// CONF2-NEXT: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_1, %subview_2 : memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_3 : memref<32x32xbf16, strided<[32, 1], offset: ?>>)

 // -----

@@ -131,12 +135,6 @@ module {

 // CONF1-LABEL: func.func @brgemm_tensor_type_no_tiling
 func.func @brgemm_tensor_type_no_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
 // CONF1-NOT: scf.for
-// CONF1-NOT: scf.for
-// CONF1-NOT: scf.for
-// CONF1-NOT: scf.for
-// CONF1-NOT: memref.subview
-// CONF1-NOT: memref.subview
-// CONF1-NOT: memref.subview
   %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %0 : tensor<256x256xf32>
 }
@@ -155,12 +153,6 @@ module {

 // CONF1-LABEL: func.func @matmul_no_tiling
 func.func @matmul_no_tiling(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) {
 // CONF1-NOT: scf.for
-// CONF1-NOT: scf.for
-// CONF1-NOT: scf.for
-// CONF1-NOT: scf.for
-// CONF1-NOT: memref.subview
-// CONF1-NOT: memref.subview
-// CONF1-NOT: memref.subview
   linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>)
                 outs(%arg2 : memref<64x64xf32>)
   return
 }

From 3306413fffc3170a9ff3b36394b520483a80f6f7 Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Tue, 18 Feb 2025 05:08:17 -0800
Subject: [PATCH 09/15] With modified tile sizes and interchange options, upstream works for bf16 vnni

---
 lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 142 +++-------------------
 1 file changed, 16 insertions(+), 126 deletions(-)

diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
index 4a9275992..9043c9c8b 100644
--- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
+++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -11,18 +11,10 @@
 //===----------------------------------------------------------------------===//
 #include "TPP/Transforms/Transforms.h"
 #include "TPP/Transforms/Utils/VNNIUtils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
-#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/PatternMatch.h" @@ -104,8 +96,10 @@ struct LinalgOpTiling : OpRewritePattern { SmallVector mxnxkTile(options.registerTileShape.begin(), options.registerTileShape.end()); - // We do manual tiling for bf16type with vnni layout. It seems the - // upstream tiling interface is broken for vnni layouts. + linalg::LinalgTilingOptions options; + options.setLoopType(linalg::LinalgTilingLoopType::Loops); + FailureOr tiledOp; + if (vnniOpt) { // k-tile size adjusted based on the vnni layout for bf16 type auto tensorShape = @@ -120,129 +114,25 @@ struct LinalgOpTiling : OpRewritePattern { "with vnni layout. K tile size should be >= vnni layout"); mxnxkTile[2] = kTileVnni; + // Tile options for bf16 type with vnni layout + options.setTileSizes({1, 0, mxnxkTile[0], mxnxkTile[1], mxnxkTile[2]}); + options.setInterchange({2, 3, 0, 4, 1}); + tiledOp = + linalg::tileLinalgOp(rewriter, brgemmOp, options); - SmallVector swap_i = {0, 2, 1}; - std::map> inductionVars; - // For M, N, and K loops - scf::ForOp innermostForLoop; - // Creating the tiled loops - for (auto [i, itrShapeMNK] : llvm::enumerate(mxnxkTile)) { - auto upperBound = - dyn_cast(brgemmOp.getOperand(swap_i[i]).getType()) - .getShape()[1]; - // Tile size should not be greater than the upperBound - if ((itrShapeMNK) > upperBound) - return rewriter.notifyMatchFailure( - brgemmOp, "Tile size is greater than the dimension"); - - Location loc = brgemmOp.getLoc(); - Value zeroCst = rewriter.create(loc, 0); - Value ubCstTiledLoop = - rewriter.create(loc, upperBound); - Value stepCstTiledLoop = - rewriter.create(loc, itrShapeMNK); - // Creates M, N, and K tile loops - scf::ForOp loopOp = rewriter.create( - brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop); - rewriter.setInsertionPointToStart(loopOp.getBody()); - innermostForLoop = loopOp; - - // Stores the induction variable with respect to the operands mapping - // it's subview. 
- if (i == 0) { // Stores iv for M loop - inductionVars[0][1] = loopOp.getInductionVar(); - inductionVars[2][0] = loopOp.getInductionVar(); - } else if (i == 1) { // stores iv for N loop, creates batch loop, and - // maps iv of batch loop - inductionVars[1][2] = loopOp.getInductionVar(); - inductionVars[2][1] = loopOp.getInductionVar(); - // Creates reduction loop after the N loop - Value ubCstReduction = rewriter.create( - loc, dyn_cast(brgemmOp.getOperand(0).getType()) - .getShape()[0]); - Value stepCstReduction = - rewriter.create(loc, 1); - scf::ForOp redloopOp = rewriter.create( - brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction); - rewriter.setInsertionPointToStart(redloopOp.getBody()); - inductionVars[0][0] = redloopOp.getInductionVar(); - inductionVars[1][0] = redloopOp.getInductionVar(); - - } else if (i == 2) { // stores iv for k-loop - inductionVars[0][2] = loopOp.getInductionVar(); - inductionVars[1][1] = loopOp.getInductionVar(); - } - } - - // DS to assist while creating new subviews with correct indices and - // shapes - SmallVector mxkTile{mxnxkTile[0], mxnxkTile[2]}; - SmallVector kxnTile{mxnxkTile[2], mxnxkTile[1]}; - SmallVector mxnTile{mxnxkTile[0], mxnxkTile[1]}; - - SmallVector> tileshapes{mxkTile, kxnTile, mxnTile}; - // Creating subviews - for (size_t i = 0, opSize = brgemmOp.getNumOperands(); i < opSize; i++) { - SmallVector offsets; - SmallVector indices; - SmallVector shape; - SmallVector strides; - - auto input = brgemmOp.getOperand(i); - auto tensorShape = dyn_cast(input.getType()).getShape(); - auto tileItr = tileshapes[i].begin(); - - // Iterates over the shape of each tensor and update its offsets, - // indices, shapes, strides with respect to tile sizes - for (size_t j = 0, tSize = tensorShape.size(); j < tSize; j++) { - if (j == 0 && (i < 2)) { // Updates the batch dimension - offsets.push_back(inductionVars[i][j]); - indices.push_back(1); - shape.push_back(rewriter.getIndexAttr(1)); - strides.push_back(rewriter.getIndexAttr(1)); - } else if (j < 3) { // Updates the M, N, and K dimensions - offsets.push_back(inductionVars[i][j]); - indices.push_back((*tileItr)); - shape.push_back(rewriter.getIndexAttr(*tileItr)); - strides.push_back(rewriter.getIndexAttr(1)); - tileItr++; - } else { // Just copies the vnni layout dimensions - offsets.push_back(rewriter.getIndexAttr(0)); - indices.push_back(tensorShape[j]); - shape.push_back(rewriter.getIndexAttr(tensorShape[j])); - strides.push_back(rewriter.getIndexAttr(1)); - } - } - - auto subview = rewriter.create( - brgemmOp.getLoc(), input, offsets, shape, strides); - brgemmOp.setOperand(i, subview); - } - - rewriter.setInsertionPoint( - innermostForLoop.getBody(), - std::prev(innermostForLoop.getBody()->end(), 1)); - auto clone = rewriter.clone(*brgemmOp); - brgemmOp.replaceAllUsesWith(clone); - - if (brgemmOp->use_empty()) - rewriter.eraseOp(brgemmOp); - - } else { - // f32 gets tiled with the help of upstream tiling interface - linalg::LinalgTilingOptions options; + } else { + // Tile options for f32 type. 
options.setTileSizes({1, mxnxkTile[0], mxnxkTile[1], mxnxkTile[2]}); - options.setLoopType(linalg::LinalgTilingLoopType::Loops); options.setInterchange({1, 2, 0, 3}); - - FailureOr tiledOp = + tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, options); + } - if (failed(tiledOp)) { + if (failed(tiledOp)) { return failure(); } - rewriter.replaceOp(brgemmOp, tiledOp->op->getResults()); - } + rewriter.replaceOp(brgemmOp, tiledOp->op->getResults()); + return success(); } From e7682e29d82eb5016da99ded19f70bf5f369d7e2 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 19 Feb 2025 01:46:47 -0800 Subject: [PATCH 10/15] tile sizes and interchange options are adjusted with respect to maps --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 103 ++++++++++++++++------ 1 file changed, 77 insertions(+), 26 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index 9043c9c8b..42de847ca 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -92,45 +92,96 @@ struct LinalgOpTiling : OpRewritePattern { brgemmOp, "Failed matching for batch reduce operation with vnni layout"); - // Get the register blocking tile shape from the user input - SmallVector mxnxkTile(options.registerTileShape.begin(), - options.registerTileShape.end()); - - linalg::LinalgTilingOptions options; - options.setLoopType(linalg::LinalgTilingLoopType::Loops); + // Tiling with the help of upstream APIs + linalg::LinalgTilingOptions tilingOptions; + tilingOptions.setLoopType(linalg::LinalgTilingLoopType::Loops); FailureOr tiledOp; + // Get rank and map of linalg op + unsigned rankA = + (dyn_cast((brgemmOp->getOperand(0)).getType())).getRank(); + unsigned rankB = + (dyn_cast((brgemmOp->getOperand(1)).getType())).getRank(); + AffineMap mapA = + brgemmOp.getMatchingIndexingMap(&brgemmOp->getOpOperand(0)); + AffineMap mapB = + brgemmOp.getMatchingIndexingMap(&brgemmOp->getOpOperand(1)); + if (vnniOpt) { // k-tile size adjusted based on the vnni layout for bf16 type - auto tensorShape = + auto shape = dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); - auto kTileVnni = mxnxkTile[2] / tensorShape[3]; + auto kTileVnni = options.registerTileShape[2] / shape[3]; - // Note: We make an assumption that the k tile size is divisible to + // Note: We make an assumption that the k tile size is divisible to // the powers of 2. - if (kTileVnni < 1) - return rewriter.notifyMatchFailure( + if (kTileVnni < 1 || (kTileVnni % 2 != 0)) + return rewriter.notifyMatchFailure( brgemmOp, "Failed matching K tile size for batch reduce operation " - "with vnni layout. K tile size should be >= vnni layout"); - - mxnxkTile[2] = kTileVnni; - // Tile options for bf16 type with vnni layout - options.setTileSizes({1, 0, mxnxkTile[0], mxnxkTile[1], mxnxkTile[2]}); - options.setInterchange({2, 3, 0, 4, 1}); - tiledOp = - linalg::tileLinalgOp(rewriter, brgemmOp, options); + "with vnni layout. 
K tile size should be >= vnni layout " + "and divisible by 2"); + + // Calculating the tile sizes based on affine map for bf16 type with vnni + auto vnniDim = + (dyn_cast(mapA.getResult(rankA - 1))).getPosition(); + auto dimM = + (dyn_cast(mapA.getResult(rankA - 3))).getPosition(); + auto dimN = + (dyn_cast(mapB.getResult(rankB - 2))).getPosition(); + auto dimBR = + (dyn_cast(mapA.getResult(rankA - 4))).getPosition(); + auto dimK = + (dyn_cast(mapA.getResult(rankA - 2))).getPosition(); + + // To set the loop interchange options + SmallVector tileSizes(5); + tileSizes[dimBR] = 1; + tileSizes[dimM] = options.registerTileShape[0]; + tileSizes[dimN] = options.registerTileShape[1]; + tileSizes[dimK] = kTileVnni; + tileSizes[vnniDim] = 0; + + tilingOptions.setTileSizes({tileSizes[0], tileSizes[1], tileSizes[2], + tileSizes[3], tileSizes[4]}); + tilingOptions.setInterchange({dimM, dimN, dimBR, dimK, vnniDim}); + tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions); } else { - // Tile options for f32 type. - options.setTileSizes({1, mxnxkTile[0], mxnxkTile[1], mxnxkTile[2]}); - options.setInterchange({1, 2, 0, 3}); - tiledOp = - linalg::tileLinalgOp(rewriter, brgemmOp, options); + + // Calculating the tile sizes based on affine map for fp32 type + auto dimM = + (dyn_cast(mapA.getResult(rankA - 2))).getPosition(); + auto dimN = + (dyn_cast(mapB.getResult(rankB - 1))).getPosition(); + auto dimBR = + (dyn_cast(mapA.getResult(rankA - 3))).getPosition(); + auto dimK = + (dyn_cast(mapA.getResult(rankA - 1))).getPosition(); + + // Checks dimensions are aligned with the iterator types + if (brgemmIteratorTypes[dimM] != mlir::utils::IteratorType::parallel || + brgemmIteratorTypes[dimN] != mlir::utils::IteratorType::parallel || + brgemmIteratorTypes[dimBR] != mlir::utils::IteratorType::reduction || + brgemmIteratorTypes[dimK] != mlir::utils::IteratorType::reduction) + return rewriter.notifyMatchFailure( + brgemmOp, "Failed macthing with iterator types and dimension"); + + // To set the loop interchange options + SmallVector tileSizes(4); + tileSizes[dimBR] = 1; + tileSizes[dimM] = options.registerTileShape[0]; + tileSizes[dimN] = options.registerTileShape[1]; + tileSizes[dimK] = options.registerTileShape[2]; + + tilingOptions.setTileSizes( + {tileSizes[0], tileSizes[1], tileSizes[2], tileSizes[3]}); + tilingOptions.setInterchange({dimM, dimN, dimBR, dimK}); + tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions); } if (failed(tiledOp)) { - return failure(); - } + return failure(); + } rewriter.replaceOp(brgemmOp, tiledOp->op->getResults()); return success(); From 0923d3e8ad725830cfb6ab8ff5ff2bd3ea3c8fee Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 19 Feb 2025 01:58:19 -0800 Subject: [PATCH 11/15] tile sizes and interchange options are adjusted with respect to maps --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index 42de847ca..74fce16ce 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -76,13 +76,13 @@ struct LinalgOpTiling : OpRewritePattern { return rewriter.notifyMatchFailure(brgemmOp, "The operation is not a gemm"); - auto tensorShapeLhs = + auto shapeLhs = dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); - auto tensorShapeRhs = + auto shapeRhs = dyn_cast(brgemmOp.getOperand(1).getType()).getShape(); if (reductionCount == 2 && - 
(tensorShapeLhs.size() != 3 || tensorShapeRhs.size() != 3)) + (shapeLhs.size() != 3 || shapeRhs.size() != 3)) return rewriter.notifyMatchFailure( brgemmOp, "Invalid rank for batch reduce operation"); @@ -115,11 +115,11 @@ struct LinalgOpTiling : OpRewritePattern { // Note: We make an assumption that the k tile size is divisible to // the powers of 2. - if (kTileVnni < 1 || (kTileVnni % 2 != 0)) + if (kTileVnni < 1 || (options.registerTileShape[2] % shape[3] != 0)) return rewriter.notifyMatchFailure( brgemmOp, "Failed matching K tile size for batch reduce operation " "with vnni layout. K tile size should be >= vnni layout " - "and divisible by 2"); + "and divisible by vnni layout"); // Calculating the tile sizes based on affine map for bf16 type with vnni auto vnniDim = @@ -164,7 +164,7 @@ struct LinalgOpTiling : OpRewritePattern { brgemmIteratorTypes[dimBR] != mlir::utils::IteratorType::reduction || brgemmIteratorTypes[dimK] != mlir::utils::IteratorType::reduction) return rewriter.notifyMatchFailure( - brgemmOp, "Failed macthing with iterator types and dimension"); + brgemmOp, "Failed matching with iterator types and dimension"); // To set the loop interchange options SmallVector tileSizes(4); From 34e5283af05acc0c65cfbea81357de63491039ec Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 19 Feb 2025 06:48:45 -0800 Subject: [PATCH 12/15] code-refactoring + 1 updated test-case --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 32 ++++++++----------- .../pass-tile-brgemm-linalg-matmul.mlir | 8 ++--- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index 74fce16ce..eea07fd7b 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -64,22 +64,21 @@ struct LinalgOpTiling : OpRewritePattern { std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(), utils::IteratorType::reduction); - if (reductionCount == 0) + if (reductionCount == 0 || reductionCount > 3) return rewriter.notifyMatchFailure(brgemmOp, - "Matmul operation not supported yet"); + "Excepted GEMM like operation"); if (reductionCount == 1) return rewriter.notifyMatchFailure( brgemmOp, "Batch matmul operation not supported yet"); - if (reductionCount > 3) - return rewriter.notifyMatchFailure(brgemmOp, - "The operation is not a gemm"); + auto shapeTypeLhs = + dyn_cast(brgemmOp.getOperand(0).getType()); + auto shapeTypeRhs = + dyn_cast(brgemmOp.getOperand(1).getType()); - auto shapeLhs = - dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); - auto shapeRhs = - dyn_cast(brgemmOp.getOperand(1).getType()).getShape(); + auto shapeLhs = shapeTypeLhs.getShape(); + auto shapeRhs = shapeTypeRhs.getShape(); if (reductionCount == 2 && (shapeLhs.size() != 3 || shapeRhs.size() != 3)) @@ -98,10 +97,8 @@ struct LinalgOpTiling : OpRewritePattern { FailureOr tiledOp; // Get rank and map of linalg op - unsigned rankA = - (dyn_cast((brgemmOp->getOperand(0)).getType())).getRank(); - unsigned rankB = - (dyn_cast((brgemmOp->getOperand(1)).getType())).getRank(); + unsigned rankA = shapeTypeLhs.getRank(); + unsigned rankB = shapeTypeRhs.getRank(); AffineMap mapA = brgemmOp.getMatchingIndexingMap(&brgemmOp->getOpOperand(0)); AffineMap mapB = @@ -109,13 +106,11 @@ struct LinalgOpTiling : OpRewritePattern { if (vnniOpt) { // k-tile size adjusted based on the vnni layout for bf16 type - auto shape = - dyn_cast(brgemmOp.getOperand(0).getType()).getShape(); - auto kTileVnni = options.registerTileShape[2] 
/ shape[3]; + auto kTileVnni = options.registerTileShape[2] / shapeLhs[3]; // Note: We make an assumption that the k tile size is divisible to // the powers of 2. - if (kTileVnni < 1 || (options.registerTileShape[2] % shape[3] != 0)) + if (kTileVnni < 1 || (options.registerTileShape[2] % shapeLhs[3] != 0)) return rewriter.notifyMatchFailure( brgemmOp, "Failed matching K tile size for batch reduce operation " "with vnni layout. K tile size should be >= vnni layout " @@ -144,7 +139,6 @@ struct LinalgOpTiling : OpRewritePattern { tilingOptions.setTileSizes({tileSizes[0], tileSizes[1], tileSizes[2], tileSizes[3], tileSizes[4]}); tilingOptions.setInterchange({dimM, dimN, dimBR, dimK, vnniDim}); - tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions); } else { @@ -176,9 +170,9 @@ struct LinalgOpTiling : OpRewritePattern { tilingOptions.setTileSizes( {tileSizes[0], tileSizes[1], tileSizes[2], tileSizes[3]}); tilingOptions.setInterchange({dimM, dimN, dimBR, dimK}); - tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions); } + tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions); if (failed(tiledOp)) { return failure(); } diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 9bc70439b..4e6d7109c 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -79,8 +79,8 @@ module { // ----- -#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> -#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> module { memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} @@ -135,6 +135,7 @@ module { // CONF1-LABEL: func.func @brgemm_tensor_type_no_tiling func.func @brgemm_tensor_type_no_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> { // CONF1-NOT: scf.for +// CONF2-NOT: scf.for %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %0 : tensor<256x256xf32> } @@ -153,9 +154,8 @@ module { // CONF1-LABEL: func.func @matmul_no_tiling func.func @matmul_no_tiling(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) { // CONF1-NOT: scf.for +// CONF2-NOT: scf.for linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>) outs(%arg2 : memref<64x64xf32>) return } - - From a178fe2b5ada32b981613df8cc8193fd9c936d6a Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 19 Feb 2025 06:56:59 -0800 Subject: [PATCH 13/15] 1 updated test-case for different permutations of affine-map --- test/Passes/pass-tile-brgemm-linalg-matmul.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 4e6d7109c..64f379556 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -79,9 +79,9 @@ module { // ----- -#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)> -#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> +#map = affine_map<(d0, d1, d2, d3, d4) -> 
(d0, d3, d1, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d2)> module { memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} func.func @gemm_64tiles_do_tiling_bf16(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> { From 746aa8fc5fcaddd284f897a11e9e68e5c0ddcaba Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 19 Feb 2025 07:34:34 -0800 Subject: [PATCH 14/15] minor refactoring tileSizes option --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index eea07fd7b..274b13434 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -136,8 +136,7 @@ struct LinalgOpTiling : OpRewritePattern { tileSizes[dimK] = kTileVnni; tileSizes[vnniDim] = 0; - tilingOptions.setTileSizes({tileSizes[0], tileSizes[1], tileSizes[2], - tileSizes[3], tileSizes[4]}); + tilingOptions.setTileSizes(tileSizes); tilingOptions.setInterchange({dimM, dimN, dimBR, dimK, vnniDim}); } else { @@ -167,8 +166,7 @@ struct LinalgOpTiling : OpRewritePattern { tileSizes[dimN] = options.registerTileShape[1]; tileSizes[dimK] = options.registerTileShape[2]; - tilingOptions.setTileSizes( - {tileSizes[0], tileSizes[1], tileSizes[2], tileSizes[3]}); + tilingOptions.setTileSizes(tileSizes); tilingOptions.setInterchange({dimM, dimN, dimBR, dimK}); } From f7172590973998ef3bfc6fff14737f7743cd695a Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 19 Feb 2025 09:26:32 -0800 Subject: [PATCH 15/15] validation of parallel count, remove func.func in test-case --- lib/TPP/Transforms/BrgemmLinalgTiling.cpp | 9 ++++++--- test/Passes/pass-tile-brgemm-linalg-matmul.mlir | 9 --------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp index 274b13434..fb99e3ff2 100644 --- a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp +++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp @@ -64,7 +64,11 @@ struct LinalgOpTiling : OpRewritePattern { std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(), utils::IteratorType::reduction); - if (reductionCount == 0 || reductionCount > 3) + int parallelCount = + std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(), + utils::IteratorType::parallel); + + if (reductionCount == 0 || reductionCount > 3 || parallelCount != 2) return rewriter.notifyMatchFailure(brgemmOp, "Excepted GEMM like operation"); @@ -94,7 +98,6 @@ struct LinalgOpTiling : OpRewritePattern { // Tiling with the help of upstream APIs linalg::LinalgTilingOptions tilingOptions; tilingOptions.setLoopType(linalg::LinalgTilingLoopType::Loops); - FailureOr tiledOp; // Get rank and map of linalg op unsigned rankA = shapeTypeLhs.getRank(); @@ -170,7 +173,7 @@ struct LinalgOpTiling : OpRewritePattern { tilingOptions.setInterchange({dimM, dimN, dimBR, dimK}); } - tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions); + FailureOr tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions); if (failed(tiledOp)) { return failure(); } diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir index 64f379556..94d2368cc 100644 --- a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir +++ 
b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir @@ -133,12 +133,8 @@ module { // CONF1-LABEL: func.func @brgemm_tensor_type_no_tiling -func.func @brgemm_tensor_type_no_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> { // CONF1-NOT: scf.for // CONF2-NOT: scf.for - %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> - return %0 : tensor<256x256xf32> -} // ----- @@ -152,10 +148,5 @@ module { // CONF1-LABEL: func.func @matmul_no_tiling -func.func @matmul_no_tiling(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) { // CONF1-NOT: scf.for // CONF2-NOT: scf.for - linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>) - outs(%arg2 : memref<64x64xf32>) - return -}
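For reference, the call pattern that the final version of BrgemmLinalgTiling.cpp converges on is the upstream linalg tiling driver. Below is a minimal sketch of that pattern, assuming the usual MLIR pass boilerplate around it; the helper name tileBrgemmWithUpstreamInterface and the hard-coded example tile sizes are illustrative assumptions, while the MLIR calls themselves (LinalgTilingOptions, setTileSizes, setInterchange, setLoopType, tileLinalgOp) are the ones already used in the patches above.

// Hedged sketch, not part of the patch series.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Tiles a batch-reduce matmul with the upstream helper: the batch dimension
// is tiled by 1 and M/N/K by an example register tile, with M and N placed
// as the outermost loops via the interchange vector.
static LogicalResult
tileBrgemmWithUpstreamInterface(PatternRewriter &rewriter,
                                linalg::LinalgOp brgemmOp) {
  linalg::LinalgTilingOptions tilingOptions;
  // Materialize the tile loops as plain scf.for loops.
  tilingOptions.setLoopType(linalg::LinalgTilingLoopType::Loops);
  // {batch, M, N, K} tile sizes; the concrete values are example assumptions.
  tilingOptions.setTileSizes({1, 8, 32, 1});
  // Loop order after tiling: M, N, batch, K.
  tilingOptions.setInterchange({1, 2, 0, 3});

  FailureOr<linalg::TiledLinalgOp> tiledOp =
      linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions);
  if (failed(tiledOp))
    return failure();

  // Replace the original op with the results of the tiled computation.
  rewriter.replaceOp(brgemmOp, tiledOp->op->getResults());
  return success();
}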