Skip to content

Commit

Permalink
Code refactoring, changes to test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Arun Thangamani committed Feb 13, 2025
1 parent ef08d66 commit 974c3eb
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 113 deletions.
17 changes: 8 additions & 9 deletions lib/TPP/Transforms/BrgemmLinalgTiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {

if (reductionCount == 0)
return rewriter.notifyMatchFailure(
brgemmOp, "Batch matmul operation not supported yet");
brgemmOp, "Matmul operation not supported yet");

if (reductionCount == 1)
return rewriter.notifyMatchFailure(
brgemmOp, "Matmul operation not supported yet");
brgemmOp, "Batch matmul operation not supported yet");

if (reductionCount > 3)
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -107,28 +107,27 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
mxnxkTile[2] = mxnxkTile[2] / tensorShape[3];
}

size_t i = 0;
SmallVector<int> swap_i = {0, 2, 1};
std::map<int, std::map<int, Value>> inductionVars;

// For M, N, and K loops
scf::ForOp innermostForLoop;
// Creating the tiled loops
for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end();
itrShapeMNK++, i++) {
for (auto [i, itrShapeMNK] : llvm::enumerate(mxnxkTile)) {
auto upperBound =
dyn_cast<MemRefType>(brgemmOp.getOperand(swap_i[i]).getType())
.getShape()[1];
// Tile size should not be greater than the upperBound
if ((*itrShapeMNK) > upperBound)
return failure();
if ((itrShapeMNK) > upperBound)
return rewriter.notifyMatchFailure(
brgemmOp, "Tile size is greater than the dimension");

Location loc = brgemmOp.getLoc();
Value zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ubCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
Value stepCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeMNK);
rewriter.create<arith::ConstantIndexOp>(loc, itrShapeMNK);
// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(
brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
Expand Down Expand Up @@ -201,7 +200,7 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
}

auto subview = rewriter.create<memref::SubViewOp>(
brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);
brgemmOp.getLoc(), input, offsets, shape, strides);
brgemmOp.setOperand(i, subview);
}

Expand Down
3 changes: 1 addition & 2 deletions test/Integration/tile-brgemm-linalg-matmul-bf16.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2
// RUN: rm %t.1 %t.2
// RUN: fpcmp %t.1 %t.2

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
Expand Down
2 changes: 1 addition & 1 deletion test/Integration/tile-brgemm-linalg-matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: tpp-run -e entry --entry-point-result=void -print %s > %t.1
// RUN: tpp-run -e entry --entry-point-result=void --vector-to-kernels --registerBlocking=8,32,1 %s -print > %t.2
// RUN: diff %t.1 %t.2
// RUN: fpcmp %t.1 %t.2
// RUN: rm %t.1 %t.2

module {
Expand Down
140 changes: 39 additions & 101 deletions test/Passes/pass-tile-brgemm-linalg-matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=8,32,1" --split-input-file | FileCheck -check-prefix=CONF1 %s
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" --canonicalize --split-input-file | FileCheck -check-prefix=CONF2 %s
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" --split-input-file | FileCheck -check-prefix=CONF2 %s

module {
func.func @gemm_do_register_tiling(%arg0: memref<16x32x16x32xf32>, %arg1: memref<32x32x32x32xf32>, %arg2: memref<16x32x16x32xf32>) {
Expand All @@ -13,34 +13,18 @@ module {
}
}

// CONF1-LABEL: func.func @gemm_do_register_tiling(
// CONF1-SAME: %[[VAL_0:.*]]: memref<16x32x16x32xf32>,
// CONF1-SAME: %[[VAL_1:.*]]: memref<32x32x32x32xf32>,
// CONF1-SAME: %[[VAL_2:.*]]: memref<16x32x16x32xf32>) {
// CONF1: %[[VAL_3:.*]] = arith.constant 1 : index
// CONF1: %[[VAL_4:.*]] = arith.constant 32 : index
// CONF1: %[[VAL_5:.*]] = arith.constant 8 : index
// CONF1: %[[VAL_6:.*]] = arith.constant 16 : index
// CONF1: %[[VAL_7:.*]] = arith.constant 0 : index
// CONF1: scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) in (16, 32) {
// CONF1: %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_8]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>
// CONF1: %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// CONF1: %[[VAL_12:.*]] = memref.subview %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
// CONF1: scf.for %[[VAL_13:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] {
// CONF1: scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_4]] {
// CONF1: scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] {
// CONF1: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] {
// CONF1: %[[VAL_17:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_13]], %[[VAL_16]]] [1, 8, 1] [1, 1, 1] : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> to memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>>
// CONF1: %[[VAL_18:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_14]]] [1, 1, 32] [1, 1, 1] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>
// CONF1: %[[VAL_19:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] [8, 32] [1, 1] : memref<16x32xf32, strided<[32, 1], offset: ?>> to memref<8x32xf32, strided<[32, 1], offset: ?>>
// CONF1: linalg.batch_reduce_matmul ins(%[[VAL_17]], %[[VAL_18]] : memref<1x8x1xf32, strided<[512, 32, 1], offset: ?>>, memref<1x1x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_19]] : memref<8x32xf32, strided<[32, 1], offset: ?>>)
// CONF1: }
// CONF1: }
// CONF1: }
// CONF1: }
// CONF1: }
// CONF1: return
// CONF1: }
// CONF1-LABEL: func.func @gemm_do_register_tiling
// CONF1: memref.subview
// CONF1-NEXT: memref.subview
// CONF1-NEXT: memref.subview
// CONF1-NEXT: scf.for
// CONF1-NEXT: scf.for
// CONF1-NEXT: scf.for
// CONF1-NEXT: scf.for
// CONF1-NEXT: memref.subview
// CONF1-NEXT: memref.subview
// CONF1-NEXT: memref.subview
// CONF1-NEXT: linalg.batch_reduce_matmul

// -----

Expand Down Expand Up @@ -146,7 +130,7 @@ module {
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
module {
memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
func.func @gemm_32tiles_do_tiling(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
func.func @gemm_32tiles_do_tiling_bf16(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
%cst = arith.constant 0.000000e+00 : bf16
%0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
Expand All @@ -166,37 +150,18 @@ module {
}
}

// CONF2: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
// CONF2: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
// CONF2: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>

// CONF2-LABEL: memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
// CONF2-LABEL: func.func @gemm_32tiles_do_tiling(
// CONF2-SAME: %[[VAL_0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
// CONF2: %[[VAL_1:.*]] = arith.constant 1 : index
// CONF2: %[[VAL_2:.*]] = arith.constant 32 : index
// CONF2: %[[VAL_3:.*]] = arith.constant 0 : index
// CONF2: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : bf16
// CONF2: %[[VAL_5:.*]] = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
// CONF2: %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
// CONF2: %[[VAL_7:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16>
// CONF2: scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) in (8, 32) {
// CONF2: %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
// CONF2: linalg.fill ins(%[[VAL_4]] : bf16) outs(%[[VAL_10]] : memref<32x32xbf16, strided<[32, 1], offset: ?>>)
// CONF2: %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_8]], 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
// CONF2: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_1]] {
// CONF2: %[[VAL_13:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
// CONF2: %[[VAL_14:.*]] = memref.subview %[[VAL_5]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<32x16x32x2xbf16> to memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
// CONF2: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_13]], %[[VAL_14]] : memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%[[VAL_10]] : memref<32x32xbf16, strided<[32, 1], offset: ?>>) {
// CONF2: ^bb0(%[[VAL_15:.*]]: bf16, %[[VAL_16:.*]]: bf16, %[[VAL_17:.*]]: bf16):
// CONF2: %[[VAL_18:.*]] = arith.mulf %[[VAL_15]], %[[VAL_16]] : bf16
// CONF2: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : bf16
// CONF2: linalg.yield %[[VAL_19]] : bf16
// CONF2: }
// CONF2: }
// CONF2: }
// CONF2: return %[[VAL_6]] : memref<8x32x32x32xbf16>
// CONF2: }
// CONF2-LABEL: func.func @gemm_32tiles_do_tiling_bf16
// CONF2: memref.subview
// CONF2-NEXT: linalg.fill
// CONF2-NEXT: memref.subview
// CONF2-NEXT: scf.for
// CONF2-NEXT: scf.for
// CONF2-NEXT: scf.for
// CONF2-NEXT: scf.for
// CONF2-NEXT: memref.subview
// CONF2-NEXT: memref.subview
// CONF2-NEXT: memref.subview
// CONF2-NEXT: linalg.generic

// -----

Expand All @@ -205,7 +170,7 @@ module {
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
module {
memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
func.func @gemm_64tiles_do_tiling(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> {
func.func @gemm_64tiles_do_tiling_bf16(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> {
%cst = arith.constant 0.000000e+00 : bf16
%0 = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16>
Expand All @@ -225,42 +190,15 @@ module {
}
}

// CONF2: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
// CONF2: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
// CONF2: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
// CONF2-LABEL: memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
// CONF2-LABEL: func.func @gemm_64tiles_do_tiling(
// CONF2-SAME: %[[VAL_0:.*]]: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> {
// CONF2: %[[VAL_1:.*]] = arith.constant 1 : index
// CONF2: %[[VAL_2:.*]] = arith.constant 16 : index
// CONF2: %[[VAL_3:.*]] = arith.constant 32 : index
// CONF2: %[[VAL_4:.*]] = arith.constant 64 : index
// CONF2: %[[VAL_5:.*]] = arith.constant 0 : index
// CONF2: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16
// CONF2: %[[VAL_7:.*]] = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16>
// CONF2: %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16>
// CONF2: %[[VAL_9:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16>
// CONF2: scf.forall (%[[VAL_10:.*]], %[[VAL_11:.*]]) in (4, 16) {
// CONF2: %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>>
// CONF2: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_12]] : memref<64x64xbf16, strided<[64, 1], offset: ?>>)
// CONF2: %[[VAL_13:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_10]], 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
// CONF2: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] {
// CONF2: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] {
// CONF2: scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_1]] {
// CONF2: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_2]] {
// CONF2: %[[VAL_18:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_14]], %[[VAL_17]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
// CONF2: %[[VAL_19:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_15]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>
// CONF2: %[[VAL_20:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_15]]] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>>
// CONF2: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_18]], %[[VAL_19]] : memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>) outs(%[[VAL_20]] : memref<32x32xbf16, strided<[64, 1], offset: ?>>) {
// CONF2: ^bb0(%[[VAL_21:.*]]: bf16, %[[VAL_22:.*]]: bf16, %[[VAL_23:.*]]: bf16):
// CONF2: %[[VAL_24:.*]] = arith.mulf %[[VAL_21]], %[[VAL_22]] : bf16
// CONF2: %[[VAL_25:.*]] = arith.addf %[[VAL_23]], %[[VAL_24]] : bf16
// CONF2: linalg.yield %[[VAL_25]] : bf16
// CONF2: }
// CONF2: }
// CONF2: }
// CONF2: }
// CONF2: }
// CONF2: }
// CONF2: return %[[VAL_8]] : memref<4x16x64x64xbf16>
// CONF2: }
// CONF2-LABEL: func.func @gemm_64tiles_do_tiling_bf16
// CONF2: memref.subview
// CONF2-NEXT: linalg.fill
// CONF2-NEXT: memref.subview
// CONF2-NEXT: scf.for
// CONF2-NEXT: scf.for
// CONF2-NEXT: scf.for
// CONF2-NEXT: scf.for
// CONF2-NEXT: memref.subview
// CONF2-NEXT: memref.subview
// CONF2-NEXT: memref.subview
// CONF2-NEXT: linalg.generic

0 comments on commit 974c3eb

Please sign in to comment.