From db5bab2b2eb5090c8b496674c5f45b399c908cd0 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 10 Jan 2025 15:47:26 +0100 Subject: [PATCH 01/26] Get VNNI factor from DLTI --- include/TPP/Transforms/Utils/VNNIUtils.h | 7 +++-- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 33 ++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index 58d1c73bd..343e4d8be 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -20,6 +20,7 @@ class OpOperand; class AffineDimExpr; class AffineMap; class VectorType; +class Operation; namespace linalg { class LinalgOp; @@ -35,8 +36,10 @@ enum class VnniOperandRank { BRGEMM_OUTS = 3 }; -// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8. -std::optional getVnniBlockingFactor(Type type); +// Return the VNNI blocking factor. +// Optionally, operation can be provided to give access to DLTI. +std::optional getVnniBlockingFactor(Type type, + Operation *op = nullptr); // Return true if the memref is in VNNI layout with rank `expectedRank`. bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref); diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index dd44c247f..0715b2c0a 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TPP/Transforms/Utils/VNNIUtils.h" + #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -20,10 +21,38 @@ namespace mlir { namespace vnni { namespace utils { -std::optional getVnniBlockingFactor(Type type) { +std::optional getVnniBlockingFactor(Type type, Operation *op) { auto elementType = getElementTypeOrSelf(type); - if (elementType.isBF16()) + if (elementType.isBF16()) { + // Check if a VNNI factor hint is associated to the IR via DLTI. 
+ auto deriveVnniFromDLTI = [&]() -> std::optional { + if (!op) + return std::nullopt; + ModuleOp moduleOp = op->getParentOfType(); + if (!moduleOp) + return std::nullopt; + TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec(); + if (!sysSpec) + return std::nullopt; + auto deviceId = StringAttr::get(moduleOp->getContext(), "CPU"); + auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId); + if (!deviceSpec) + return std::nullopt; + auto tileSizeId = StringAttr::get(moduleOp->getContext(), "vnni"); + DataLayoutEntryInterface entry = + (*deviceSpec).getSpecForIdentifier(tileSizeId); + if (!entry) + return std::nullopt; + Attribute value = entry.getValue(); + if (auto intAttr = llvm::dyn_cast(value)) + return intAttr.getInt(); + return std::nullopt; + }; + if (auto vnniFactor = deriveVnniFromDLTI()) + return *vnniFactor; + return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); + } return std::nullopt; } From 795b14c58175abe9521a7fb8817679de170e5e3a Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 10 Jan 2025 16:40:49 +0100 Subject: [PATCH 02/26] WIP update tests --- test/BF16/Integration/matmul-pbf16.mlir | 50 ------- test/BF16/Integration/mlir-gen-bf16.mlir | 28 ++-- .../BF16/Integration/mlp-all-bf16-tpprun.mlir | 137 ------------------ .../BF16/Integration/tpp-run-splat-shape.mlir | 2 +- test/BF16/Integration/vnni-xsmm-vs-loops.mlir | 29 +--- 5 files changed, 23 insertions(+), 223 deletions(-) delete mode 100644 test/BF16/Integration/matmul-pbf16.mlir delete mode 100644 test/BF16/Integration/mlp-all-bf16-tpprun.mlir diff --git a/test/BF16/Integration/matmul-pbf16.mlir b/test/BF16/Integration/matmul-pbf16.mlir deleted file mode 100644 index f2434271d..000000000 --- a/test/BF16/Integration/matmul-pbf16.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: tpp-run %s -print \ -// RUN: -e entry -entry-point-result=void | \ -// RUN: FileCheck %s - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @matmultpp(%A: memref<4x8xbf16>, - %B: memref<4x4x2xbf16>, %C: memref<4x4xbf16>) { - %expanded = memref.expand_shape %A [[0], [1, 2]] output_shape [4, 4, 2] - : memref<4x8xbf16> into memref<4x4x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %B : memref<4x4x2xbf16>, memref<4x4x2xbf16>) - outs(%C : memref<4x4xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -func.func @entry() { - %c0 = arith.constant 0 : index - %f0 = arith.constant 1.0 : bf16 - %da = memref.alloc() :memref<4x8xbf16> - linalg.fill ins(%f0 : bf16) outs(%da : memref<4x8xbf16>) - // Call kernel. 
- %0 = memref.alloc() : memref<4x4x2xbf16> - linalg.fill ins(%f0:bf16) outs (%0: memref<4x4x2xbf16>) - %D = memref.alloc() : memref<4x4xbf16> - %zero = arith.constant 0.0 : bf16 - linalg.fill ins(%zero : bf16) outs(%D:memref<4x4xbf16>) - call @matmultpp(%da, %0, %D) - : (memref<4x8xbf16>, memref<4x4x2xbf16>, memref<4x4xbf16>)->() - - // - // CHECK:( ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ) ) - // - %d1 = arith.constant -1.0 : bf16 - - %v0 = vector.transfer_read %D[%c0, %c0], %d1 : memref<4x4xbf16>, vector<4x4xbf16> - %f1 = arith.extf %v0:vector<4x4xbf16> to vector<4x4xf32> - vector.print %f1 : vector<4x4xf32> - - return -} diff --git a/test/BF16/Integration/mlir-gen-bf16.mlir b/test/BF16/Integration/mlir-gen-bf16.mlir index a0db89a6b..97035a7d1 100644 --- a/test/BF16/Integration/mlir-gen-bf16.mlir +++ b/test/BF16/Integration/mlir-gen-bf16.mlir @@ -1,28 +1,28 @@ // MLP without softmax (can't print packed version for now) -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void // Matmul only -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void // Kernel - matmul -// RUN: mlir-gen --kernel=args --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 -// RUN: mlir-gen --output=named --kernel=args --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 +// RUN: mlir-gen --kernel=args --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 +// RUN: mlir-gen --output=named --kernel=args --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 // Kernel - fc -// RUN: mlir-gen --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 -// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 +// RUN: mlir-gen --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry 
-entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 +// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 // BF16/VNNI execution -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// GEN-MATMUL-BF16: ( 11, 11, 11, 11, 11, 11, 11, 11, 11, 11 ) +// GEN-MATMUL-BF16: ( 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17 ) -// GEN-FC-BF16: ( 12, 12, 12, 12, 12, 12, 12, 12, 12, 12 ) +// GEN-FC-BF16: ( 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18 ) // PERF: {{[0-9]+}}{{.?}}{{[0-9e-]+}} diff --git a/test/BF16/Integration/mlp-all-bf16-tpprun.mlir b/test/BF16/Integration/mlp-all-bf16-tpprun.mlir deleted file mode 100644 index 5f7968719..000000000 --- a/test/BF16/Integration/mlp-all-bf16-tpprun.mlir +++ /dev/null @@ -1,137 +0,0 @@ -// RUN: tpp-run %s \ -// RUN: -e entry -entry-point-result=void - -memref.global "private" constant @arg1 : memref<128x512x2xbf16> = dense<1.00e+00> -memref.global "private" constant @arg3 : memref<256x1024x2xbf16> = dense<1.00e+00> -memref.global "private" constant @arg5 : memref<512x2048x2xbf16> = dense<1.00e+00> -memref.global "private" constant @arg7 : memref<1024x1000x2xbf16> = dense<1.00e+00> - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> -#map3 = affine_map<(d0, d1) -> (d0, d1)> -#map4 = affine_map<(d0, d1) -> (d1)> - -func.func @entry(%arg0: memref<128x256xbf16>, %arg2: memref<512xbf16>, %arg4: memref<1024xbf16>, - %arg6: memref<2048xbf16>, %arg8: memref<1000xbf16>, %arg9: memref<128x512xbf16>, - %arg10: memref<128x1024xbf16>, %arg11: memref<128x2048xbf16>, %arg12: 
memref<128x1000xbf16>) { - %c0 = arith.constant 0.0 : bf16 - linalg.generic { - indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg2: memref<512xbf16>) outs(%arg9: memref<128x512xbf16>) { - ^bb0(%in: bf16, %out: bf16): - linalg.yield %in : bf16 - } - - %e0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 128, 2] - : memref<128x256xbf16> into memref<128x128x2xbf16> - %relayout_arg0 = memref.get_global @arg1:memref<128x512x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%e0, %relayout_arg0 : memref<128x128x2xbf16>, memref<128x512x2xbf16>) - outs(%arg9 : memref<128x512xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - linalg.generic { - indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg9 : memref<128x512xbf16>) outs(%arg9 : memref<128x512xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %2 = arith.maximumf %in, %c0 : bf16 - linalg.yield %2 : bf16 - } - - linalg.generic { - indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg4: memref<1024xbf16>) outs(%arg10: memref<128x1024xbf16>) { - ^bb0(%in: bf16, %out: bf16): - linalg.yield %in : bf16 - } - - %e1 = memref.expand_shape %arg9 [[0], [1, 2]] output_shape [128, 256, 2] - : memref<128x512xbf16> into memref<128x256x2xbf16> - %relayout_arg12 = memref.get_global @arg3:memref<256x1024x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%e1, %relayout_arg12 : memref<128x256x2xbf16>, memref<256x1024x2xbf16>) - outs(%arg10 : memref<128x1024xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - linalg.generic { - indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg10 : memref<128x1024xbf16>) outs(%arg10 : memref<128x1024xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %2 = arith.maximumf %in, %c0 : bf16 - linalg.yield %2 : bf16 - } - - linalg.generic { - indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg6: memref<2048xbf16>) outs(%arg11: memref<128x2048xbf16>) { - ^bb0(%in: bf16, %out: bf16): - linalg.yield %in : bf16 - } - - %relayout_arg11 = memref.get_global @arg5:memref<512x2048x2xbf16> - %e2 = memref.expand_shape %arg10 [[0], [1, 2]] output_shape [128, 512, 2] - : memref<128x1024xbf16> into memref<128x512x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%e2, %relayout_arg11 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>) - outs(%arg11 : memref<128x2048xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - linalg.generic { - indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg11 : memref<128x2048xbf16>) outs(%arg11 : memref<128x2048xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %2 = arith.maximumf %in, %c0 : bf16 - linalg.yield %2 : bf16 - } - - linalg.generic { - indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg8: memref<1000xbf16>) outs(%arg12: memref<128x1000xbf16>) { - ^bb0(%in: bf16, %out: bf16): - linalg.yield %in : bf16 - } - - 
%relayout_arg10 = memref.get_global @arg7:memref<1024x1000x2xbf16> - %e3 = memref.expand_shape %arg11 [[0], [1, 2]] output_shape [128, 1024, 2] - : memref<128x2048xbf16> into memref<128x1024x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%e3, %relayout_arg10 : memref<128x1024x2xbf16>, memref<1024x1000x2xbf16>) - outs(%arg12 : memref<128x1000xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - linalg.generic { - indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} - ins(%arg12 : memref<128x1000xbf16>) outs(%arg12 : memref<128x1000xbf16>) { - ^bb0(%in: bf16, %out: bf16): - %2 = arith.maximumf %in, %c0 : bf16 - linalg.yield %2 : bf16 - } - - %threshold = arith.constant 1.0 : bf16 - %c4 = arith.constant 2.74878e+11: bf16 - %interim4 = memref.alloc(): memref<128x1000xbf16> - linalg.fill ins(%c4:bf16) outs(%interim4: memref<128x1000xbf16>) - check.expect_almost_eq(%interim4, %arg12, %threshold): memref<128x1000xbf16>, memref<128x1000xbf16>, bf16 - return -} diff --git a/test/BF16/Integration/tpp-run-splat-shape.mlir b/test/BF16/Integration/tpp-run-splat-shape.mlir index 624aeb754..935586599 100644 --- a/test/BF16/Integration/tpp-run-splat-shape.mlir +++ b/test/BF16/Integration/tpp-run-splat-shape.mlir @@ -41,7 +41,7 @@ func.func @entry(%arg0: tensor<4x8x8x8xbf16>, %output: tensor<4x8x8x8xbf16>) -> // due to compile time packing. // CHECK-NOT: memref.global "private" constant @__constant_{{.*}}: memref<8x8xbf16> // CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<4x8x8x8xbf16> -// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x4x8x2xbf16> +// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x4x8x{{[2|4|8]}}xbf16> // CHECK: xsmm_brgemm_invoke // CHECK: xsmm_binary_invoke // CHECK: xsmm_unary_invoke diff --git a/test/BF16/Integration/vnni-xsmm-vs-loops.mlir b/test/BF16/Integration/vnni-xsmm-vs-loops.mlir index 2a8419395..0f7eb99d1 100644 --- a/test/BF16/Integration/vnni-xsmm-vs-loops.mlir +++ b/test/BF16/Integration/vnni-xsmm-vs-loops.mlir @@ -1,26 +1,13 @@ -// RUN: tpp-run %s -print -seed 123 \ +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 \ +// RUN: --tiles=16,16,16 --float-type=bf16 | \ +// RUN: tpp-opt --pack-vnni | \ +// RUN: tpp-run -print -seed 123 \ // RUN: -e entry -entry-point-result=void > %t.xsmm -// RUN: tpp-run %s -print -seed 123 -linalg-to-loops \ +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 \ +// RUN: --tiles=16,16,16 --float-type=bf16 | \ +// RUN: tpp-opt --pack-vnni | \ +// RUN: tpp-run -print -seed 123 -linalg-to-loops \ // RUN: -e entry -entry-point-result=void > %t.loops // RUN: fpcmp -r 0.01 %t.xsmm %t.loops - -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6, d3)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> - -func.func @entry(%arg0: tensor<2x2x7x4x2xbf16>, %arg1: tensor<2x2x4x5x2xbf16>, - %arg2: tensor<2x2x7x5xbf16>) -> tensor<2x2x7x5xbf16> { - %1 = linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : tensor<2x2x7x4x2xbf16>, tensor<2x2x4x5x2xbf16>) - outs(%arg2 : 
tensor<2x2x7x5xbf16>) { - ^bb0(%in: bf16, %in_0: bf16, %out: bf16): - %2 = arith.mulf %in, %in_0 : bf16 - %3 = arith.addf %out, %2 : bf16 - linalg.yield %3 : bf16 - } -> tensor<2x2x7x5xbf16> - return %1 : tensor<2x2x7x5xbf16> -} From e36dac9d51b594a9a142821d1044f00fdecd6b83 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 13 Jan 2025 15:11:43 +0100 Subject: [PATCH 03/26] WIP update vnni matchers --- test/BF16/Integration/tpp-run-splat-shape.mlir | 2 +- test/BF16/brgemm-tpp.mlir | 8 ++++---- test/BF16/brgemm-vnni.mlir | 16 ++++++++-------- test/BF16/matmul-untiled-vnni.mlir | 2 +- test/BF16/matmul-vnni.mlir | 16 ++++++++-------- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/test/BF16/Integration/tpp-run-splat-shape.mlir b/test/BF16/Integration/tpp-run-splat-shape.mlir index 935586599..4a865ab09 100644 --- a/test/BF16/Integration/tpp-run-splat-shape.mlir +++ b/test/BF16/Integration/tpp-run-splat-shape.mlir @@ -41,7 +41,7 @@ func.func @entry(%arg0: tensor<4x8x8x8xbf16>, %output: tensor<4x8x8x8xbf16>) -> // due to compile time packing. // CHECK-NOT: memref.global "private" constant @__constant_{{.*}}: memref<8x8xbf16> // CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<4x8x8x8xbf16> -// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x4x8x{{[2|4|8]}}xbf16> +// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x{{[4|2]}}x8x{{2|4}}xbf16> // CHECK: xsmm_brgemm_invoke // CHECK: xsmm_binary_invoke // CHECK: xsmm_unary_invoke diff --git a/test/BF16/brgemm-tpp.mlir b/test/BF16/brgemm-tpp.mlir index 78caff1ed..08408e57b 100644 --- a/test/BF16/brgemm-tpp.mlir +++ b/test/BF16/brgemm-tpp.mlir @@ -14,10 +14,10 @@ func.func @brgemm(%arg0: tensor<32x4x4xbf16>, %arg1: tensor<32x4x4xbf16>, // CHECK-LABEL: brgemm // CHECK-SAME: %[[ARG0:.+]]: tensor<32x4x4xbf16>, %[[ARG1:.+]]: tensor<32x4x4xbf16>, // CHECK-SAME: %[[ARG2:.+]]: tensor<4x4xbf16> -// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape [32, 4, 2, 2] : tensor<32x4x4xbf16> into tensor<32x4x2x2xbf16> -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x2x4x2xbf16> -// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [1] inner_tiles = [2] -// CHECK-SAME: into %[[EMPTY]] : tensor<32x4x4xbf16> -> tensor<32x2x4x2xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape{{.*}}: tensor<32x4x4xbf16> into tensor<32x4x{{2|1}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x2x4x{{2|4}}xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [1] inner_tiles = [{{2|4}}] +// CHECK-SAME: into %[[EMPTY]] : tensor<32x4x4xbf16> -> tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: %{{.+}} = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"] diff --git a/test/BF16/brgemm-vnni.mlir b/test/BF16/brgemm-vnni.mlir index aa6d069d2..5970ebec4 100644 --- a/test/BF16/brgemm-vnni.mlir +++ b/test/BF16/brgemm-vnni.mlir @@ -14,11 +14,11 @@ func.func @brgemm(%arg0: tensor<32x4x4xbf16>, %arg1: tensor<32x4x4xbf16>, // CHECK-LABEL: brgemm // CHECK-SAME: %[[ARG0:.+]]: tensor<32x4x4xbf16>, %[[ARG1:.+]]: tensor<32x4x4xbf16>, // CHECK-SAME: %[[ARG2:.+]]: tensor<4x4xbf16> -// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape [32, 4, 2, 2] : tensor<32x4x4xbf16> into tensor<32x4x2x2xbf16> -// CHECK: %[[EMPTY:.+]] 
= tensor.empty() : tensor<32x2x4x2xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape{{.*}}: tensor<32x4x4xbf16> into tensor<32x4x{{2|1}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] -// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [2] into %[[EMPTY]] -// CHECK-SAME: : tensor<32x4x4xbf16> -> tensor<32x2x4x2xbf16> +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [{{2|4}}] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x4x4xbf16> -> tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"] @@ -69,10 +69,10 @@ func.func @prepacked_matmul(%pack: tensor<4x4x32x32xbf16>, %pack_0: tensor<4x4x3 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4x32x32xbf16>, %[[ARG1:.+]]: tensor<4x4x32x32xbf16>, // CHECK-SAME: %[[ARG2:.+]]: tensor<4x4x32x32xbf16> // CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]] -// CHECK-SAME: output_shape [4, 4, 32, 16, 2] : tensor<4x4x32x32xbf16> into tensor<4x4x32x16x2xbf16> -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4x16x32x2xbf16> -// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [2] inner_tiles = [2] into %[[EMPTY]] -// CHECK-SAME: : tensor<4x4x32x32xbf16> -> tensor<4x4x16x32x2xbf16> +// CHECK-SAME: output_shape{{.*}}: tensor<4x4x32x32xbf16> into tensor<4x4x32x{{16|8}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4x{{16|8}}x32x{{2|4}}xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [2] inner_tiles = [{{2|4}}] into %[[EMPTY]] +// CHECK-SAME: : tensor<4x4x32x32xbf16> -> tensor<4x4x{{16|8}}x32x{{2|4}}xbf16> // CHECK: {{.+}} = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "reduction"] diff --git a/test/BF16/matmul-untiled-vnni.mlir b/test/BF16/matmul-untiled-vnni.mlir index 2609ca90f..7a47d9b07 100644 --- a/test/BF16/matmul-untiled-vnni.mlir +++ b/test/BF16/matmul-untiled-vnni.mlir @@ -26,7 +26,7 @@ func.func @blocked_matmul(%arg0: tensor<32x64x4x4xbf16>, %arg1: tensor<128x64x4x // CHECK: %[[ARG0:.*]]: tensor<32x64x4x4xbf16>, // CHECK: %[[ARG1:.*]]: tensor<128x64x4x4xbf16>, // CHECK: %[[ARG2:.*]]: tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> { -// CHECK: %[[PACKBUF:.*]] = tensor.empty() : tensor<128x64x2x4x2xbf16> +// CHECK: %[[PACKBUF:.*]] = tensor.empty() : tensor<128x64x{{2|1}}x4x{{2|4}}xbf16> // CHECK: linalg.generic // CHECK: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "reduction"] diff --git a/test/BF16/matmul-vnni.mlir b/test/BF16/matmul-vnni.mlir index 2d4a5ffda..24e83a8b3 100644 --- a/test/BF16/matmul-vnni.mlir +++ b/test/BF16/matmul-vnni.mlir @@ -25,17 +25,17 @@ func.func @matmul_static( // CHECK: %[[PACK_0:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] // CHECK-SAME: into %{{.+}} : tensor<512x1024xbf16> -> tensor<32x16x32x32xbf16> // CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1], [2], [3, 4]] -// CHECK-SAME: output_shape [8, 16, 32, 16, 2] : tensor<8x16x32x32xbf16> into tensor<8x16x32x16x2xbf16> -// CHECK: %[[EMPTY_2:.+]] = tensor.empty() : 
tensor<32x16x16x32x2xbf16> -// CHECK: %[[PACK_1:.+]] = tensor.pack %[[PACK_0]] inner_dims_pos = [2] inner_tiles = [2] into %[[EMPTY_2]] -// CHECK-SAME: : tensor<32x16x32x32xbf16> -> tensor<32x16x16x32x2xbf16> +// CHECK-SAME: output_shape{{.*}}: tensor<8x16x32x32xbf16> into tensor<8x16x32x{{16|8}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY_2:.+]] = tensor.empty() : tensor<32x16x{{16|8}}x32x{{2|4}}xbf16> +// CHECK: %[[PACK_1:.+]] = tensor.pack %[[PACK_0]] inner_dims_pos = [2] inner_tiles = [{{2|4}}] into %[[EMPTY_2]] +// CHECK-SAME: : tensor<32x16x32x32xbf16> -> tensor<32x16x{{16|8}}x32x{{2|4}}xbf16> // CHECK: %{{.+}} = scf.forall (%[[ARG3:.+]], %[[ARG4:.+]]) in (8, 32) shared_outs(%[[ARG5:.+]] = %[[ARG2]]) // CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[ARG3]]) // CHECK: %[[APPLY_1:.+]] = affine.apply #[[MAP]](%[[ARG4]]) -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[VNNI_A]][%[[ARG3]], 0, 0, 0, 0] [1, 16, 32, 16, 2] [1, 1, 1, 1, 1] -// CHECK-SAME: : tensor<8x16x32x16x2xbf16> to tensor<16x32x16x2xbf16> -// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[PACK_1]][%[[ARG4]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] -// CHECK-SAME: : tensor<32x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[VNNI_A]][%[[ARG3]], 0, 0, 0, 0] [1, 16, 32, {{16|8}}, {{2|4}}] [1, 1, 1, 1, 1] +// CHECK-SAME: : tensor<8x16x32x{{16|8}}x{{2|4}}xbf16> to tensor<16x32x{{16|8}}x{{2|4}}xbf16> +// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[PACK_1]][%[[ARG4]], 0, 0, 0, 0] [1, 16, {{16|8}}, 32, {{2|4}}] [1, 1, 1, 1, 1] +// CHECK-SAME: : tensor<32x16x{{16|8}}x32x{{2|4}}xbf16> to tensor<16x{{16|8}}x32x{{2|4}}xbf16> // CHECK: %[[SLICE_3:.+]] = tensor.extract_slice %[[ARG5]][%[[APPLY]], %[[APPLY_1]]] [32, 32] [1, 1] // CHECK-SAME: : tensor<256x1024xbf16> to tensor<32x32xbf16> // CHECK: %[[GEMM:.+]] = linalg.generic From 4a1c34c543978368e1f64fa3d616917fdf7bc71d Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 12:37:08 +0100 Subject: [PATCH 04/26] Enable BF16 tests only on x86 arch --- test/BF16/Integration/lit.local.cfg | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/BF16/Integration/lit.local.cfg b/test/BF16/Integration/lit.local.cfg index 48448ac5c..fefbd9c07 100644 --- a/test/BF16/Integration/lit.local.cfg +++ b/test/BF16/Integration/lit.local.cfg @@ -19,6 +19,28 @@ def has_support(feature): return True +def is_arch(target): + # Arch detection not working on Windows + if sys.platform in ['win32']: + return False + + try: + cmd = subprocess.Popen( + ['uname', '-m'], stdout=subprocess.PIPE) + except OSError: + return False + + out = cmd.stdout.read().decode('ascii') + cmd.wait() + + return out == target + + # AVX512 and SVE should have BF16 support if not has_support('avx512') and not has_support('avx2') and not has_support('sve'): config.unsupported = True +# Enable only on x86 +# Other targets may use different VNNI blocking scheme that is not compatible with +# prepacked shapes in some of the tests +if not is_arch("x86_64"): + config.unsupported = True From 2e8b1983fdc7cae2093ac8dc5d60e0b2bf0acf4e Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 13:05:27 +0100 Subject: [PATCH 05/26] Silence diffs --- test/Integration/hoist-vector-transfer-brgemm.mlir | 2 +- test/Integration/vector-contract-to-fma.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Integration/hoist-vector-transfer-brgemm.mlir b/test/Integration/hoist-vector-transfer-brgemm.mlir index 37190b68c..3a1bab701 
100644 --- a/test/Integration/hoist-vector-transfer-brgemm.mlir +++ b/test/Integration/hoist-vector-transfer-brgemm.mlir @@ -1,6 +1,6 @@ // RUN: tpp-run -e entry --entry-point-result=void -print %s > %t.1 // RUN: tpp-opt %s --loop-invariant-code-motion --vectorization-pass --loop-invariant-code-motion --hoist-vector-transfer | tpp-run -e entry --entry-point-result=void -print > %t.2 -// RUN: diff %t.1 %t.2 +// RUN: diff -q %t.1 %t.2 // RUN: rm %t.1 %t.2 module { diff --git a/test/Integration/vector-contract-to-fma.mlir b/test/Integration/vector-contract-to-fma.mlir index 4d03e8bb8..0005db730 100644 --- a/test/Integration/vector-contract-to-fma.mlir +++ b/test/Integration/vector-contract-to-fma.mlir @@ -1,6 +1,6 @@ // RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.1 // RUN: tpp-opt %s --vector-contract-to-fma | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.2 -// RUN: diff %t.1 %t.2 +// RUN: diff -q %t.1 %t.2 // RUN: rm %t.1 %t.2 // DIFF-NOT: {{.}} From d9539d4f0800867438bd81cd90af9a112f8fa242 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 13:26:04 +0100 Subject: [PATCH 06/26] Use blocking factor --- lib/TPP/Transforms/ToBlockLayoutAndBack.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp index 52c0a5f17..3011af560 100644 --- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp +++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp @@ -414,7 +414,8 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter, return rewriter.notifyMatchFailure(brgemmOp, "unsupported blocking factor for type"); } - SmallVector tilesOnK = {rewriter.getI64IntegerAttr(2)}; + SmallVector tilesOnK = { + rewriter.getI64IntegerAttr(*blockingFactor)}; Location loc = brgemmOp.getLoc(); // Reshape input A. 
From 55ea6a476559cc6d63a54902ac0f86af0562268d Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 13:27:57 +0100 Subject: [PATCH 07/26] Fix test --- test/BF16/brgemm-tpp.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/BF16/brgemm-tpp.mlir b/test/BF16/brgemm-tpp.mlir index 08408e57b..7ab922e62 100644 --- a/test/BF16/brgemm-tpp.mlir +++ b/test/BF16/brgemm-tpp.mlir @@ -15,7 +15,7 @@ func.func @brgemm(%arg0: tensor<32x4x4xbf16>, %arg1: tensor<32x4x4xbf16>, // CHECK-SAME: %[[ARG0:.+]]: tensor<32x4x4xbf16>, %[[ARG1:.+]]: tensor<32x4x4xbf16>, // CHECK-SAME: %[[ARG2:.+]]: tensor<4x4xbf16> // CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape{{.*}}: tensor<32x4x4xbf16> into tensor<32x4x{{2|1}}x{{2|4}}xbf16> -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x2x4x{{2|4}}xbf16> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [1] inner_tiles = [{{2|4}}] // CHECK-SAME: into %[[EMPTY]] : tensor<32x4x4xbf16> -> tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: %{{.+}} = linalg.generic From 3c6c6e676eff595ca708cbc3320262002c725686 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 13:35:37 +0100 Subject: [PATCH 08/26] Generalize VNNI check --- lib/TPP/IR/MatcherUtils.cpp | 2 +- lib/TPP/Transforms/ToBlockLayoutAndBack.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/TPP/IR/MatcherUtils.cpp b/lib/TPP/IR/MatcherUtils.cpp index c74e6d5af..2622277b3 100644 --- a/lib/TPP/IR/MatcherUtils.cpp +++ b/lib/TPP/IR/MatcherUtils.cpp @@ -115,7 +115,7 @@ std::pair isMatmulVnniOp(linalg::GenericOp linalgOp, // At this point, the operation is a valid matmul contraction. // Finally, ensure that it is in VNNI layout. - bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor); + bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp); return std::make_pair(isVnniMatmul, hasBatch); } diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp index 3011af560..f23865846 100644 --- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp +++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp @@ -339,7 +339,7 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter, "unsupported blocking factor for type"); } - if (vnni::utils::isInVnniLayout(matmulOp, *blockingFactor)) { + if (vnni::utils::isInVnniLayout(matmulOp)) { return rewriter.notifyMatchFailure(matmulOp, "already packed to VNNI"); } From eb42726b4bb0f412bf8651485688d968528390e3 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 13:54:35 +0100 Subject: [PATCH 09/26] Hook ops for DLTI vnni factor --- .../Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp | 7 ++++--- .../Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp | 2 +- lib/TPP/IR/MatcherUtils.cpp | 4 ++-- lib/TPP/Transforms/ToBlockLayoutAndBack.cpp | 5 +++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp index e1e003694..efc1ab5c6 100644 --- a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp @@ -1068,8 +1068,9 @@ struct ConvertVnniPacking : public OpRewritePattern { if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) return failure(); // Ajust ldo based on the VNNI factor. 
- unaryInfo.ldo = stridesOnOutput->front() / - *vnni::utils::getVnniBlockingFactor(out.getType()); + unaryInfo.ldo = + stridesOnOutput->front() / + *vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp); auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get( rewriter.getContext(), xsmm::UnaryFlags::NONE)); xsmm::UnaryKindAttr kind = @@ -1112,7 +1113,7 @@ struct ConvertGenericToVnniMatmulLikeOp // Take the whole reduction dim size. Account for the VNNI factor (ensured // by the earlier check) that splits the K dim in the shape. std::optional vnniFactor = - vnni::utils::getVnniBlockingFactor(bufferB.getType()); + vnni::utils::getVnniBlockingFactor(bufferB.getType(), genericOp); if (!vnniFactor) return rewriter.notifyMatchFailure(genericOp, "failed to determine VNNI factor"); diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp index 04c1f6fdb..e309eb384 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp @@ -99,7 +99,7 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp, if (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE, outType)) { // Adjust ldo based on vnni factor - auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType); + auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType, transposeOp); unaryInfo.ldo = unaryInfo.ldo / vnniFactor; } else { std::swap(unaryInfo.m, unaryInfo.n); diff --git a/lib/TPP/IR/MatcherUtils.cpp b/lib/TPP/IR/MatcherUtils.cpp index 2622277b3..a0d045364 100644 --- a/lib/TPP/IR/MatcherUtils.cpp +++ b/lib/TPP/IR/MatcherUtils.cpp @@ -40,8 +40,8 @@ getIteratorPos(linalg::LinalgOp linalgOp, AffineMap indexingMap, std::pair isMatmulVnniOp(linalg::GenericOp linalgOp, SmallVectorImpl *operands) { bool hasBatch = false; - auto blockingFactor = - vnni::utils::getVnniBlockingFactor(linalgOp->getOperands()[0].getType()); + auto blockingFactor = vnni::utils::getVnniBlockingFactor( + linalgOp->getOperands()[0].getType(), linalgOp); if (!blockingFactor) return std::make_pair(false, hasBatch); diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp index f23865846..f547d392d 100644 --- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp +++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp @@ -333,7 +333,7 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter, OpOperand &operandB = matmulOp->getOpOperand(1); auto blockingFactor = - vnni::utils::getVnniBlockingFactor(operandB.get().getType()); + vnni::utils::getVnniBlockingFactor(operandB.get().getType(), matmulOp); if (!blockingFactor) { return rewriter.notifyMatchFailure(matmulOp, "unsupported blocking factor for type"); @@ -409,7 +409,8 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter, Value operandB = brgemmOp.getInputs()[1]; // Blocking factor on the `k` dimension. 
- auto blockingFactor = vnni::utils::getVnniBlockingFactor(operandB.getType()); + auto blockingFactor = + vnni::utils::getVnniBlockingFactor(operandB.getType(), brgemmOp); if (!blockingFactor) { return rewriter.notifyMatchFailure(brgemmOp, "unsupported blocking factor for type"); From b36ee9bb2bd05cde1ffca34114985a776f48f9c0 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 13:55:29 +0100 Subject: [PATCH 10/26] Refactor --- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index 0715b2c0a..1bebc72d0 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -38,9 +38,9 @@ std::optional getVnniBlockingFactor(Type type, Operation *op) { auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId); if (!deviceSpec) return std::nullopt; - auto tileSizeId = StringAttr::get(moduleOp->getContext(), "vnni"); + auto vnniId = StringAttr::get(moduleOp->getContext(), "vnni"); DataLayoutEntryInterface entry = - (*deviceSpec).getSpecForIdentifier(tileSizeId); + (*deviceSpec).getSpecForIdentifier(vnniId); if (!entry) return std::nullopt; Attribute value = entry.getValue(); From b90f258d39e13a39ff7431238b7cdf7f807e0606 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 14:04:07 +0100 Subject: [PATCH 11/26] Fix factor for lit tests --- .../LinalgToXsmm/linalg-to-brgemm.mlir | 32 +- .../LinalgToXsmm/linalg-to-gemm.mlir | 276 ++++++++++-------- .../LinalgToXsmm/linalg-to-unary.mlir | 75 +++-- 3 files changed, 221 insertions(+), 162 deletions(-) diff --git a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir index 647a5881d..df2fc9451 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir @@ -285,21 +285,25 @@ func.func @simple_brgemm(%arg0: memref<2x32x32xf32>, %arg1: memref<2x32x32xf32>, #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> - -func.func @vnni_brgemm_interchanged(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] - : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_brgemm_interchanged(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] + : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) + outs(%arg2 : memref<32x32xbf16>) { + ^bb0(%in: bf16, 
%in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm_interchanged diff --git a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir index de4d88ad1..027404311 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir @@ -149,22 +149,28 @@ func.func @mha_projection(%arg0: memref<512x8x64xf32>, %arg1: memref<64x32x512xf #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @square_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +// Fix VNNI blocking factor for lit testing. +// Prevent mismatches due to target specific VNNI factors. +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @square_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: square_vnni_gemm @@ -179,20 +185,24 @@ func.func @square_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ? 
#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @expanded_arg_vnni_gemm(%arg0: memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expanded_arg_vnni_gemm(%arg0: memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expanded_arg_vnni_gemm @@ -211,21 +221,26 @@ func.func @expanded_arg_vnni_gemm(%arg0: memref<64x32x2xbf16, strided<[64, 2, 1] #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> // Require a transpose on C, before mapping to vnni Gemm. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm @@ -239,21 +254,26 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> // Not VNNI layout. A factor of 5 is not VNNI. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x160xbf16, strided<[160, 1], offset: ?>>, - %arg1: memref<32x64x5xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 5] - : memref<64x160xbf16, strided<[160, 1], offset: ?>> into memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>>, memref<32x64x5xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x160xbf16, strided<[160, 1], offset: ?>>, + %arg1: memref<32x64x5xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 5] + : memref<64x160xbf16, strided<[160, 1], offset: ?>> into memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>>, memref<32x64x5xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm @@ -267,19 +287,24 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x160xbf16, strided<[160 #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> // Require a transpose on A, before mapping to vnni Gemm. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm @@ -294,21 +319,26 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<32x64x2xbf16, strided<[12 // Make sure we can handle interchange on the iterators, but with the right // access patterns. -func.func @vnni_gemm_interchanged(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_gemm_interchanged(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: vnni_gemm_interchanged @@ -353,21 +383,26 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, 
strided<[64, #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> -func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 8, 2] - : memref<64x16xbf16, strided<[64, 1], offset: ?>> into memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 8, 2] + : memref<64x16xbf16, strided<[64, 1], offset: ?>> into memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: non_square_vnni_gemm @@ -383,21 +418,26 @@ func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16, strided<[64, 1], offse #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> -func.func @non_square_vnni_gemm_1(%arg0: memref<4x16xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<8x64x2xbf16>, %arg2: memref<4x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 8, 2] - : memref<4x16xbf16, strided<[64, 1], offset: ?>> into memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) - outs(%arg2 : memref<4x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @non_square_vnni_gemm_1(%arg0: memref<4x16xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<8x64x2xbf16>, %arg2: memref<4x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 8, 2] + : memref<4x16xbf16, strided<[64, 1], offset: ?>> into memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, 
%arg1 : memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) + outs(%arg2 : memref<4x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: non_square_vnni_gemm_1 diff --git a/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir b/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir index 217491ebe..0f880eea0 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir @@ -295,14 +295,19 @@ func.func @identity_3(%arg0: memref<128x1xf32>, %arg1: memref<128x512xf32>) { // ----- -func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>>, - %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xbf16, strided<[512, 1], offset: ?>> - into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>) - outs(%arg1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>>, + %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { + %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] + : memref<32x32xbf16, strided<[512, 1], offset: ?>> + into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> + linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>) + outs(%arg1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] + return + } } // CHECK-LABEL: vnni_packing @@ -313,14 +318,19 @@ func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>> // ----- -func.func @not_vnni_packing(%arg0 : memref<32x32xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) { - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xf32, strided<[512, 1], offset: ?>> - into memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>>) - outs(%arg1 : memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @not_vnni_packing(%arg0 : memref<32x32xf32, strided<[512, 1], offset: ?>>, + %arg1: memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) { + %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] + : memref<32x32xf32, strided<[512, 1], offset: ?>> + into memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>> + linalg.transpose ins(%expand_shape : memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>>) + outs(%arg1 : memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] + return + } } // CHECK-LABEL: not_vnni_packing @@ -351,21 +361,26 @@ func.func @identity_4(%arg0: memref<1024xbf16>, %arg1: memref<128x1024xbf16>) { #map = affine_map<(d0) -> (d0 * 32)> -func.func @vnni_packing_1(%arg1: memref<128x128xbf16>, %arg2: 
memref<4x4x16x32x2xbf16>) { - scf.forall (%arg3, %arg4) in (4, 4) { - %0 = affine.apply #map(%arg4) - %1 = affine.apply #map(%arg3) - %subview = memref.subview %arg1[%0, %1] [32, 32] [1, 1] - : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> - %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] - : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>) - outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) - permutation = [0, 2, 1] +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing_1(%arg1: memref<128x128xbf16>, %arg2: memref<4x4x16x32x2xbf16>) { + scf.forall (%arg3, %arg4) in (4, 4) { + %0 = affine.apply #map(%arg4) + %1 = affine.apply #map(%arg3) + %subview = memref.subview %arg1[%0, %1] [32, 32] [1, 1] + : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> + %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] + : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape[16, 2, 32] + : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> + linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>) + outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) + permutation = [0, 2, 1] + } + return } - return } // CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0 * 32)> From e7abc16e567d66b443d68271e64d9a482f1a66d7 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 14:47:17 +0100 Subject: [PATCH 12/26] Further dlti tests --- .../LinalgToXsmm/linalg-to-brgemm.mlir | 97 +++++++++++-------- .../LinalgToXsmm/linalg-to-gemm.mlir | 35 ++++--- 2 files changed, 76 insertions(+), 56 deletions(-) diff --git a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir index df2fc9451..a83787d6c 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir @@ -320,20 +320,25 @@ module attributes { #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> -func.func @vnni_brgemm(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] - : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_brgemm(%arg0: memref<16x32x32xbf16>, %arg1: 
memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] + : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) + outs(%arg2 : memref<32x32xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm @@ -350,22 +355,27 @@ func.func @vnni_brgemm(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> -func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, - %arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, - %arg2: memref<8x8xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [8, 8, 4, 2] - : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>) - outs(%arg2 : memref<8x8xbf16>) { - ^bb0(%in: bf16, %in_9: bf16, %out: bf16): - %11 = arith.mulf %in, %in_9 : bf16 - %12 = arith.addf %out, %11 : bf16 - linalg.yield %12 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, + %arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, + %arg2: memref<8x8xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [8, 8, 4, 2] + : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>) + outs(%arg2 : memref<8x8xbf16>) { + ^bb0(%in: bf16, %in_9: bf16, %out: bf16): + %11 = arith.mulf %in, %in_9 : bf16 + %12 = arith.addf %out, %11 : bf16 + linalg.yield %12 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm_strided @@ -383,20 +393,25 @@ func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], off #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d2)> -func.func @vnni_brgemm_require_transpose_on_C(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] - : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = 
arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_brgemm_require_transpose_on_C(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] + : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) + outs(%arg2 : memref<32x32xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm_require_transpose_on_C diff --git a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir index 027404311..c4e1e9eca 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir @@ -355,21 +355,26 @@ module attributes { #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> // Not VNNI layout. The VNNI is not innermost in the access pattern for B. -func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<2x64x32xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<2x64x32xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<2x64x32xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<2x64x32xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm From d5d5ef53b4c7a0569b4541a12f8e92357ca1fa5b Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 15 Jan 2025 15:10:18 +0100 Subject: [PATCH 13/26] Hook dlti in more places --- include/TPP/Transforms/Utils/VNNIUtils.h | 13 +++++++--- lib/TPP/Dialect/Xsmm/XsmmOps.cpp | 11 +++++---- lib/TPP/Dialect/Xsmm/XsmmVerify.cpp | 9 ++++--- 
lib/TPP/Transforms/Utils/VNNIUtils.cpp | 30 ++++++++++++++++-------- 4 files changed, 43 insertions(+), 20 deletions(-) diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index 343e4d8be..fdda42d4e 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -42,12 +42,19 @@ std::optional getVnniBlockingFactor(Type type, Operation *op = nullptr); // Return true if the memref is in VNNI layout with rank `expectedRank`. -bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref); +// Optionally, the check can be constrained to a specific VNNI blocking factor. +bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref, + std::optional blockingFactor = std::nullopt); // Return true if the vector is in VNNI layout with rank `expectedRank`. -bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector); +// Optionally, the check can be constrained to a specific VNNI blocking factor. +bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector, + std::optional blockingFactor = std::nullopt); -bool isInVnniLayout(int64_t expectedRank, VectorType vector); +// Return true if the vector is in VNNI layout with rank `expectedRank`. +// Optionally, the check can be constrained to a specific VNNI blocking factor. +bool isInVnniLayout(int64_t expectedRank, VectorType vector, + std::optional blockingFactor = std::nullopt); // Return true if the operation is in VNNI layout. // Optionally, the check can be constrained to a specific VNNI blocking factor. diff --git a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp index c1e3209a9..f69667051 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp @@ -455,10 +455,13 @@ LogicalResult GemmOp::verify() { auto memref = dyn_cast(memrefOperands[idx].getType()); assert(memref && (memref.getRank() == 2 || memref.getRank() == 3)); - if (memref.getRank() == 3 && - !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM, - memref)) { - return emitOpError() << "expect VNNI layout for operand: " << actualIdx; + if (memref.getRank() == 3) { + auto vnniFactor = vnni::utils::getVnniBlockingFactor(memref); + if (!vnniFactor || (*vnniFactor) % 2 != 0 || + !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM, + memref)) { + return emitOpError() << "expect VNNI layout for operand: " << actualIdx; + } } } return success(); diff --git a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp index 5d5abc45f..df8071ba8 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp @@ -71,21 +71,24 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) { : vnni::utils::VnniOperandRank::GEMM; // VNNI flags must be consistent with the memref shapes. 
+ auto vnniFactor = vnni::utils::getVnniBlockingFactor(operandA, gemmOp); ArrayAttr flags = dispatchOp->getFlags(); for (auto flag : flags) { int64_t gemmFlag = cast(flag).getInt(); if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_A) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA)) { + !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA, + vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand A or invalid VNNI_A flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_B) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB)) { + !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB, + vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand B or invalid VNNI_B flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_C) && - !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC)) { + !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand C or invalid VNNI_C flags"); } diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index 1bebc72d0..41d00ca6c 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -59,12 +59,16 @@ std::optional getVnniBlockingFactor(Type type, Operation *op) { // Until we have a better way to express the VNNI layout (see: #563), it is up // to the callee to specify the expected rank in the VNNI layout as the rank // depends on the operations we are dealing with. -bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref) { +bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref, + std::optional blockingFactor) { if (memref.getRank() != static_cast(expectedRank) || - !memref.getElementType().isBF16()) { + !memref.getElementType().isBF16()) return false; - } - return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref); + + if (blockingFactor && memref.getShape().back() != *blockingFactor) + return false; + + return true; } bool isInVnniLayout(linalg::LinalgOp linalgOp, @@ -138,15 +142,21 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, return true; } -bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector) { - return isInVnniLayout(static_cast(expectedRank), vector); +bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector, + std::optional blockingFactor) { + return isInVnniLayout(static_cast(expectedRank), vector, + blockingFactor); } -bool isInVnniLayout(int64_t expectedRank, VectorType vector) { - if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) { +bool isInVnniLayout(int64_t expectedRank, VectorType vector, + std::optional blockingFactor) { + if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) return false; - } - return vector.getShape().back() == vnni::utils::getVnniBlockingFactor(vector); + + if (blockingFactor && vector.getShape().back() != *blockingFactor) + return false; + + return true; } } // namespace utils From cee0e00be646972b53c7d170bb49a4398c0fbf99 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 13:35:21 +0100 Subject: [PATCH 14/26] Update tests --- .../VectorToXsmm/vector-to-transpose.mlir | 55 ++-- test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir | 15 +- test/Integration/transpose-bf16.mlir | 19 +- test/Passes/DefaultPipeline/linalg.mlir | 63 ++-- test/Passes/DefaultPipeline/vnni.mlir | 272 ++++++++++-------- test/Passes/DefaultPipeline/xsmm.mlir | 82 +++--- 
test/Passes/xsmm-combine.mlir | 202 +++++++------ 7 files changed, 388 insertions(+), 320 deletions(-) diff --git a/test/Conversion/VectorToXsmm/vector-to-transpose.mlir b/test/Conversion/VectorToXsmm/vector-to-transpose.mlir index 57af7099c..99610d1b8 100644 --- a/test/Conversion/VectorToXsmm/vector-to-transpose.mlir +++ b/test/Conversion/VectorToXsmm/vector-to-transpose.mlir @@ -41,14 +41,19 @@ func.func @transpose_op_3d_f32(%arg0: memref<5x3x5xf32>, %arg1: memref<5x5x3xf32 // CHECK-NOT: call @xsmm_unary_invoke // ----- -func.func @vnni_packing_2d_bf16(%arg0: memref<32x32xbf16, strided<[512, 1], offset: ?>>, %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { - %cst = arith.constant 0.000000e+00 : bf16 - %c0 = arith.constant 0 : index - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[512, 1], offset: ?>> into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> - %0 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>, vector<16x2x32xbf16> - %1 = vector.transpose %0, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> - vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing_2d_bf16(%arg0: memref<32x32xbf16, strided<[512, 1], offset: ?>>, %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { + %cst = arith.constant 0.000000e+00 : bf16 + %c0 = arith.constant 0 : index + %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[512, 1], offset: ?>> into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> + %0 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>, vector<16x2x32xbf16> + %1 = vector.transpose %0, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> + vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + return + } } // CHECK-LABEL: func.func @vnni_packing_2d_bf16( @@ -87,20 +92,25 @@ func.func @not_vnni_packing_2d_f32(%arg0: memref<32x32xf32, strided<[512, 1], of // ----- #map = affine_map<(d0) -> (d0 * 32)> -func.func @vnni_packing_2d_bf16_forall(%arg0: memref<128x128xbf16>, %arg1: memref<4x4x16x32x2xbf16>) { - %cst = arith.constant 0.000000e+00 : bf16 - %c0 = arith.constant 0 : index - scf.forall (%arg2, %arg3) in (4, 4) { - %0 = affine.apply #map(%arg3) - %1 = affine.apply #map(%arg2) - %subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> - %subview_0 = memref.subview %arg1[%arg2, %arg3, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> - %2 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>, vector<16x2x32xbf16> - %3 = 
vector.transpose %2, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> - vector.transfer_write %3, %subview_0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing_2d_bf16_forall(%arg0: memref<128x128xbf16>, %arg1: memref<4x4x16x32x2xbf16>) { + %cst = arith.constant 0.000000e+00 : bf16 + %c0 = arith.constant 0 : index + scf.forall (%arg2, %arg3) in (4, 4) { + %0 = affine.apply #map(%arg3) + %1 = affine.apply #map(%arg2) + %subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> + %subview_0 = memref.subview %arg1[%arg2, %arg3, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> + %2 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>, vector<16x2x32xbf16> + %3 = vector.transpose %2, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> + vector.transfer_write %3, %subview_0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + } + return + } } // CHECK-LABEL: func.func @vnni_packing_2d_bf16_forall( @@ -126,4 +136,3 @@ func.func @vnni_packing_2d_bf16_forall(%arg0: memref<128x128xbf16>, %arg1: memre // CHECK-NEXT: %[[indexCast2:.*]] = arith.index_cast %[[intptr0]] // CHECK-NEXT: %[[inttoptr2:.*]] = llvm.inttoptr %[[indexCast2]] // CHECK: func.call @xsmm_unary_invoke(%[[c2_i64]], %[[dispatch]], %[[inttoptr]], %[[offset]], %[[inttoptr2]], %[[offset_1]]) - diff --git a/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir b/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir index 353408670..eec22627b 100644 --- a/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir +++ b/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir @@ -30,9 +30,14 @@ func.func @identity(%arg0: f32, %arg1: memref<1x1xf32>) { // ----- -func.func @gemm(%arg0: memref<3x6x2xbf16>, %arg1: memref<6x6xbf16>) { - %0 = xsmm.gemm.dispatch [6, 6, 6, 6, 6, 6] flags = (vnni_a) data_type = bf16 - xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x6x2xbf16>, memref<3x6x2xbf16>, memref<6x6xbf16>) -> () - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @gemm(%arg0: memref<3x6x2xbf16>, %arg1: memref<6x6xbf16>) { + %0 = xsmm.gemm.dispatch [6, 6, 6, 6, 6, 6] flags = (vnni_a) data_type = bf16 + xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : + (i64, memref<3x6x2xbf16>, memref<3x6x2xbf16>, memref<6x6xbf16>) -> () + return + } } diff --git a/test/Integration/transpose-bf16.mlir b/test/Integration/transpose-bf16.mlir index d4f4472a1..862f055a1 100644 --- a/test/Integration/transpose-bf16.mlir +++ b/test/Integration/transpose-bf16.mlir @@ -3,13 +3,18 @@ // RUN: tpp-opt --default-tpp-passes="vector-to-xsmm" %s -mlir-print-ir-after=vectorization-pass 2>&1 | FileCheck %s --check-prefix=VECTOR // RUN: tpp-run --vector-to-XSMM %s -e entry -entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s 
--check-prefix=XSMM -func.func @entry(%arg0 : tensor<4x4xbf16>, %arg1 : tensor<2x4x2xbf16>)-> tensor<2x4x2xbf16> { - %expand_shape = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape[2, 2, 4] - : tensor<4x4xbf16> - into tensor<2x2x4xbf16> - %retval = linalg.transpose ins(%expand_shape : tensor<2x2x4xbf16>) - outs(%arg1 : tensor<2x4x2xbf16>) permutation = [0, 2, 1] - return %retval: tensor<2x4x2xbf16> +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @entry(%arg0 : tensor<4x4xbf16>, %arg1 : tensor<2x4x2xbf16>)-> tensor<2x4x2xbf16> { + %expand_shape = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape[2, 2, 4] + : tensor<4x4xbf16> + into tensor<2x2x4xbf16> + %retval = linalg.transpose ins(%expand_shape : tensor<2x2x4xbf16>) + outs(%arg1 : tensor<2x4x2xbf16>) permutation = [0, 2, 1] + return %retval: tensor<2x4x2xbf16> + } } // VECTOR: vector.transfer_read diff --git a/test/Passes/DefaultPipeline/linalg.mlir b/test/Passes/DefaultPipeline/linalg.mlir index e32309ef1..0c7f18f78 100644 --- a/test/Passes/DefaultPipeline/linalg.mlir +++ b/test/Passes/DefaultPipeline/linalg.mlir @@ -207,36 +207,41 @@ func.func @brgemm(%arg0: memref<2x3x4xf32>, %arg1: memref<2x4x3xf32>, %arg2: mem // CHECK-SAME: %[[ARG0:.+]]: memref<64x4x4xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<64x2x4x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<4x4xbf16> -func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, - %arg2: memref<4x4xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_brgemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [64, 4, 2, 2] - : memref<64x4x4xbf16> into memref<64x4x2x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<64x4x2x2xbf16>, memref<64x2x4x2xbf16>) - outs(%arg2 : memref<4x4xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, + %arg2: memref<4x4xbf16>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: call @xsmm_brgemm_dispatch + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : 
i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [64, 4, 2, 2] + : memref<64x4x4xbf16> into memref<64x4x2x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<64x4x2x2xbf16>, memref<64x2x4x2xbf16>) + outs(%arg2 : memref<4x4xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // ----- diff --git a/test/Passes/DefaultPipeline/vnni.mlir b/test/Passes/DefaultPipeline/vnni.mlir index ae54cde9e..7cebef489 100644 --- a/test/Passes/DefaultPipeline/vnni.mlir +++ b/test/Passes/DefaultPipeline/vnni.mlir @@ -8,39 +8,44 @@ // CHECK-SAME: %[[ARG0:.+]]: memref<128x1024xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<512x2048x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<128x2048xbf16>) -func.func @matmul_tensor(%arg0: tensor<128x1024xbf16>, - %arg1: tensor<512x2048x2xbf16>, - %arg2: tensor<128x2048xbf16>) -> tensor<128x2048xbf16> { - // CHECK: %[[of:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_gemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] - : tensor<128x1024xbf16> into tensor<128x512x2xbf16> - %result = linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : tensor<128x512x2xbf16>, tensor<512x2048x2xbf16>) - outs(%arg2 : tensor<128x2048xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } -> tensor<128x2048xbf16> - - return %result : tensor<128x2048xbf16> +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @matmul_tensor(%arg0: tensor<128x1024xbf16>, + %arg1: tensor<512x2048x2xbf16>, + %arg2: tensor<128x2048xbf16>) -> tensor<128x2048xbf16> { + // CHECK: %[[of:.*]] = 
arith.constant 0 : index + // CHECK: call @xsmm_gemm_dispatch + + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] + : tensor<128x1024xbf16> into tensor<128x512x2xbf16> + %result = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : tensor<128x512x2xbf16>, tensor<512x2048x2xbf16>) + outs(%arg2 : tensor<128x2048xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } -> tensor<128x2048xbf16> + + return %result : tensor<128x2048xbf16> + } } // ----- @@ -53,38 +58,43 @@ func.func @matmul_tensor(%arg0: tensor<128x1024xbf16>, // CHECK-SAME: %[[ARG0:.+]]: memref<128x1024xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<512x2048x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<128x2048xbf16>) -func.func @matmul_memref(%arg0: memref<128x1024xbf16>, - %arg1: memref<512x2048x2xbf16>, - %arg2: memref<128x2048xbf16>) -> memref<128x2048xbf16> { - // CHECK: call @xsmm_gemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] - : memref<128x1024xbf16> into memref<128x512x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>) - outs(%arg2 : memref<128x2048xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @matmul_memref(%arg0: memref<128x1024xbf16>, + %arg1: memref<512x2048x2xbf16>, + %arg2: memref<128x2048xbf16>) 
-> memref<128x2048xbf16> { + // CHECK: call @xsmm_gemm_dispatch + + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] + : memref<128x1024xbf16> into memref<128x512x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>) + outs(%arg2 : memref<128x2048xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + + return %arg2 : memref<128x2048xbf16> } - - return %arg2 : memref<128x2048xbf16> } // ----- @@ -97,39 +107,44 @@ func.func @matmul_memref(%arg0: memref<128x1024xbf16>, // CHECK: %[[ARG0:.+]]: memref<4x256x512xbf16>, // CHECK: %[[ARG1:.+]]: memref<4x512x1024xbf16>, // CHECK: %[[ARG2:.+]]: memref<256x1024xbf16>) -func.func @brgemm_static_tensor(%arg0: tensor<4x256x512xbf16>, %arg1: tensor<4x512x1024xbf16>, %arg2: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> { - // CHECK: %[[alloc:.*]] = memref.alloc{{.*}}: memref<4x256x1024x2xbf16> - %0 = tensor.empty() : tensor<4x256x1024x2xbf16> - %1 = tensor.pack %arg1 inner_dims_pos = [1] inner_tiles = [2] into %0 : tensor<4x512x1024xbf16> -> tensor<4x256x1024x2xbf16> - - // CHECK: call @xsmm_brgemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[alloc]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] - : tensor<4x256x512xbf16> into tensor<4x256x256x2xbf16> - %2 = linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %1 : tensor<4x256x256x2xbf16>, tensor<4x256x1024x2xbf16>) - outs(%arg2 : tensor<256x1024xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 
: bf16 - linalg.yield %6 : bf16 - } -> tensor<256x1024xbf16> - - return %2 : tensor<256x1024xbf16> +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_static_tensor(%arg0: tensor<4x256x512xbf16>, %arg1: tensor<4x512x1024xbf16>, %arg2: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> { + // CHECK: %[[alloc:.*]] = memref.alloc{{.*}}: memref<4x256x1024x2xbf16> + %0 = tensor.empty() : tensor<4x256x1024x2xbf16> + %1 = tensor.pack %arg1 inner_dims_pos = [1] inner_tiles = [2] into %0 : tensor<4x512x1024xbf16> -> tensor<4x256x1024x2xbf16> + + // CHECK: call @xsmm_brgemm_dispatch + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[alloc]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] + : tensor<4x256x512xbf16> into tensor<4x256x256x2xbf16> + %2 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %1 : tensor<4x256x256x2xbf16>, tensor<4x256x1024x2xbf16>) + outs(%arg2 : tensor<256x1024xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } -> tensor<256x1024xbf16> + + return %2 : tensor<256x1024xbf16> + } } // ----- @@ -142,34 +157,39 @@ func.func @brgemm_static_tensor(%arg0: tensor<4x256x512xbf16>, %arg1: tensor<4x5 // CHECK: %[[ARG0:.+]]: memref<4x256x512xbf16>, // CHECK: %[[ARG1:.+]]: memref<4x256x1024x2xbf16>, // CHECK: %[[ARG2:.+]]: memref<256x1024xbf16>) -func.func @brgemm_static_memref(%arg0: memref<4x256x512xbf16>, %arg1: memref<4x256x1024x2xbf16>, %arg2: memref<256x1024xbf16>) -> memref<256x1024xbf16> { - // CHECK: call @xsmm_brgemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] - : memref<4x256x512xbf16> 
into memref<4x256x256x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<4x256x256x2xbf16>, memref<4x256x1024x2xbf16>) - outs(%arg2 : memref<256x1024xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_static_memref(%arg0: memref<4x256x512xbf16>, %arg1: memref<4x256x1024x2xbf16>, %arg2: memref<256x1024xbf16>) -> memref<256x1024xbf16> { + // CHECK: call @xsmm_brgemm_dispatch + + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] + : memref<4x256x512xbf16> into memref<4x256x256x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<4x256x256x2xbf16>, memref<4x256x1024x2xbf16>) + outs(%arg2 : memref<256x1024xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + + return %arg2 : memref<256x1024xbf16> } - - return %arg2 : memref<256x1024xbf16> } diff --git a/test/Passes/DefaultPipeline/xsmm.mlir b/test/Passes/DefaultPipeline/xsmm.mlir index bee500e22..76fbb1976 100644 --- a/test/Passes/DefaultPipeline/xsmm.mlir +++ b/test/Passes/DefaultPipeline/xsmm.mlir @@ -220,30 +220,35 @@ func.func @brgemm(%arg0: memref<2x3x4xf32>, %arg1: memref<2x4x3xf32>, %arg2: mem // CHECK-SAME: %[[ARG0:.+]]: memref<64x4x4xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<64x2x4x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<4x4xbf16> -func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, - %arg2: memref<4x4xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_brgemm_dispatch +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, + %arg2: memref<4x4xbf16>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: call @xsmm_brgemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x4x4xbf16> -> index - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + // CHECK: %[[ptr0:.*]] = 
memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x4x4xbf16> -> index + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<64x2x4x2xbf16> -> index - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<64x2x4x2xbf16> -> index + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] : memref<4x4xbf16> -> index - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] : memref<4x4xbf16> -> index + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %c64_i64 = arith.constant 64 : i64 - %0 = xsmm.brgemm.dispatch [4, 4, 4, 4, 4, 4, 16, 16] flags = (vnni_b) data_type = bf16 - xsmm.brgemm(data_type = bf16, %0, %arg0, %arg1, %arg2, %c64_i64) - : (i64, memref<64x4x4xbf16>, memref<64x2x4x2xbf16>, memref<4x4xbf16>, i64) -> () + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] + %c64_i64 = arith.constant 64 : i64 + %0 = xsmm.brgemm.dispatch [4, 4, 4, 4, 4, 4, 16, 16] flags = (vnni_b) data_type = bf16 + xsmm.brgemm(data_type = bf16, %0, %arg0, %arg1, %arg2, %c64_i64) + : (i64, memref<64x4x4xbf16>, memref<64x2x4x2xbf16>, memref<4x4xbf16>, i64) -> () - return + return + } } // ----- @@ -282,28 +287,33 @@ func.func @gemm(%A: memref<4x8xf32>, // CHECK-SAME: %[[ARG0:.+]]: memref<6x10xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<5x6x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<6x6xbf16> -func.func @gemm_bf16(%arg0: memref<6x10xbf16>, %arg1: memref<5x6x2xbf16>, - %arg2: memref<6x6xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_gemm_dispatch +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @gemm_bf16(%arg0: memref<6x10xbf16>, %arg1: memref<5x6x2xbf16>, + %arg2: memref<6x6xbf16>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: call @xsmm_gemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr 
%[[ptr_cast1]] : i64 to !llvm.ptr + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %0 = xsmm.gemm.dispatch [6, 6, 10, 10, 6, 6] flags = (vnni_b) data_type = bf16 - xsmm.gemm(data_type = bf16, %0, %arg0, %arg1, %arg2) : (i64, memref<6x10xbf16>, memref<5x6x2xbf16>, memref<6x6xbf16>) -> () + // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] + %0 = xsmm.gemm.dispatch [6, 6, 10, 10, 6, 6] flags = (vnni_b) data_type = bf16 + xsmm.gemm(data_type = bf16, %0, %arg0, %arg1, %arg2) : (i64, memref<6x10xbf16>, memref<5x6x2xbf16>, memref<6x6xbf16>) -> () - return + return + } } // ----- diff --git a/test/Passes/xsmm-combine.mlir b/test/Passes/xsmm-combine.mlir index 8edf7a2bc..e8d6ad5c5 100644 --- a/test/Passes/xsmm-combine.mlir +++ b/test/Passes/xsmm-combine.mlir @@ -133,39 +133,44 @@ func.func @none_on_binary_add(%arg0: memref<256x128xf32>) -> memref<256x512xf32> // ----- -memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_32xbf16: memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_32xbf16: memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} -// Bcast_col_in0 flag set on binary add -func.func @bcast_col_in0_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c4_i64 = arith.constant 4 : i64 - %c8_i64 = arith.constant 8 : i64 - %cst = arith.constant 0.000000e+00 : bf16 - %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> - %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> - %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> - %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : 
memref<8x8x32x32xbf16> - %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 - %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = bf16 - %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { - %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - xsmm.binary add(data_type = bf16, %5, %2, %subview, %subview) : (i64, memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - scf.reduce + // Bcast_col_in0 flag set on binary add + func.func @bcast_col_in0_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %cst = arith.constant 0.000000e+00 : bf16 + %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> + %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> + %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> + %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16> + %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 + %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = bf16 + %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> + scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { + %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> + %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () + xsmm.binary add(data_type = bf16, %5, %2, %subview, %subview) : (i64, memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + xsmm.unary relu(data_type = bf16, %6, %subview, 
%subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + scf.reduce + } + return %alloc_1 : memref<256x512xbf16> } - return %alloc_1 : memref<256x512xbf16> } // CHECK-LABEL: func.func @bcast_col_in0_on_binary_add_bf16( @@ -176,39 +181,44 @@ func.func @bcast_col_in0_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memr // ----- -memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_32xbf16: memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_32xbf16: memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} -// Bcast_col_in1 flag set on binary add -func.func @bcast_col_in1_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c4_i64 = arith.constant 4 : i64 - %c8_i64 = arith.constant 8 : i64 - %cst = arith.constant 0.000000e+00 : bf16 - %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> - %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> - %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> - %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16> - %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 - %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in1) data_type = bf16 - %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { - %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, 
memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - scf.reduce + // Bcast_col_in1 flag set on binary add + func.func @bcast_col_in1_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %cst = arith.constant 0.000000e+00 : bf16 + %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> + %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> + %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> + %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16> + %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 + %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in1) data_type = bf16 + %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> + scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { + %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> + %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () + xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + scf.reduce + } + return %alloc_1 : memref<256x512xbf16> } - return %alloc_1 : memref<256x512xbf16> } // CHECK-LABEL: func.func @bcast_col_in1_on_binary_add_bf16( @@ -220,39 +230,44 @@ func.func @bcast_col_in1_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memr // ----- -memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_32x32xbf16: memref<32x32xbf16> = dense<1.000000e+00> {alignment = 128 : i64} +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_32x32xbf16: memref<32x32xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -// None flag set on binary add 
-func.func @none_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c4_i64 = arith.constant 4 : i64 - %c8_i64 = arith.constant 8 : i64 - %cst = arith.constant 0.000000e+00 : bf16 - %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> - %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> - %2 = memref.get_global @__constant_32x32xbf16 : memref<32x32xbf16> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> - %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16> - %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 - %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = bf16 - %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { - %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32x32xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - scf.reduce + // None flag set on binary add + func.func @none_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %cst = arith.constant 0.000000e+00 : bf16 + %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> + %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> + %2 = memref.get_global @__constant_32x32xbf16 : memref<32x32xbf16> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> + %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16> + %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 + %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = bf16 + %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> + scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { + %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : 
memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> + %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () + xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32x32xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + scf.reduce + } + return %alloc_1 : memref<256x512xbf16> } - return %alloc_1 : memref<256x512xbf16> } // CHECK-LABEL: func.func @none_on_binary_add_bf16( @@ -305,4 +320,3 @@ func.func @forward(%arg0: memref<256x1024xf32>) -> memref<256x1024xf32> { // CHECK: xsmm.fused_brgemm(data_type = f32, %[[temp2]], %[[subview_2]], %{{.*}}, %[[subview]], %{{.*}} %[[c32_i64]]) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> () // CHECK: } // CHECK: return %{{.*}} : memref<256x1024xf32> - From 28eaa453eca2285de7ed5b84e92356172a924dd3 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 13:35:29 +0100 Subject: [PATCH 15/26] Use fpcmp in test --- test/Integration/vector-contract-to-fma.mlir | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/Integration/vector-contract-to-fma.mlir b/test/Integration/vector-contract-to-fma.mlir index 0005db730..c13b340a8 100644 --- a/test/Integration/vector-contract-to-fma.mlir +++ b/test/Integration/vector-contract-to-fma.mlir @@ -1,9 +1,7 @@ // RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.1 // RUN: tpp-opt %s --vector-contract-to-fma | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.2 -// RUN: diff -q %t.1 %t.2 -// RUN: rm %t.1 %t.2 +// RUN: fpcmp -r 0.001 %t.1 %t.2 -// DIFF-NOT: {{.}} #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> From 1706f4b155fce63b213b9394ff5e9acec956b815 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 13:35:32 +0100 Subject: [PATCH 16/26] Formatting --- test/BF16/Integration/lit.local.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/test/BF16/Integration/lit.local.cfg b/test/BF16/Integration/lit.local.cfg index fefbd9c07..495067c38 100644 --- a/test/BF16/Integration/lit.local.cfg +++ b/test/BF16/Integration/lit.local.cfg @@ -39,6 +39,7 @@ def is_arch(target): # AVX512 and SVE should have BF16 support if not has_support('avx512') and not has_support('avx2') and not has_support('sve'): config.unsupported = True + # Enable only on x86 # Other targets may use different VNNI blocking scheme that is not compatible with # prepacked shapes in some of the tests From a1e54ce79937addd44cb2b13983cedaa0e935c67 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 13:43:24 +0100 Subject: [PATCH 17/26] Fix arch detection --- test/BF16/Integration/lit.local.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/BF16/Integration/lit.local.cfg b/test/BF16/Integration/lit.local.cfg 
index 495067c38..5f963c2b5 100644 --- a/test/BF16/Integration/lit.local.cfg +++ b/test/BF16/Integration/lit.local.cfg @@ -33,7 +33,7 @@ def is_arch(target): out = cmd.stdout.read().decode('ascii') cmd.wait() - return out == target + return target in out # AVX512 and SVE should have BF16 support @@ -43,5 +43,5 @@ if not has_support('avx512') and not has_support('avx2') and not has_support('sv # Enable only on x86 # Other targets may use different VNNI blocking scheme that is not compatible with # prepacked shapes in some of the tests -if not is_arch("x86_64"): +if not is_arch('x86'): config.unsupported = True From ad31e0f1f3875ff5f0456ad96b4963847cc6a571 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 13:59:58 +0100 Subject: [PATCH 18/26] Restore deleted tests --- test/BF16/Integration/matmul-pbf16.mlir | 50 +++++++ .../BF16/Integration/mlp-all-bf16-tpprun.mlir | 137 ++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 test/BF16/Integration/matmul-pbf16.mlir create mode 100644 test/BF16/Integration/mlp-all-bf16-tpprun.mlir diff --git a/test/BF16/Integration/matmul-pbf16.mlir b/test/BF16/Integration/matmul-pbf16.mlir new file mode 100644 index 000000000..f2434271d --- /dev/null +++ b/test/BF16/Integration/matmul-pbf16.mlir @@ -0,0 +1,50 @@ +// RUN: tpp-run %s -print \ +// RUN: -e entry -entry-point-result=void | \ +// RUN: FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +func.func @matmultpp(%A: memref<4x8xbf16>, + %B: memref<4x4x2xbf16>, %C: memref<4x4xbf16>) { + %expanded = memref.expand_shape %A [[0], [1, 2]] output_shape [4, 4, 2] + : memref<4x8xbf16> into memref<4x4x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %B : memref<4x4x2xbf16>, memref<4x4x2xbf16>) + outs(%C : memref<4x4xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return +} + +func.func @entry() { + %c0 = arith.constant 0 : index + %f0 = arith.constant 1.0 : bf16 + %da = memref.alloc() :memref<4x8xbf16> + linalg.fill ins(%f0 : bf16) outs(%da : memref<4x8xbf16>) + // Call kernel. 
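+  // Note (annotation): B below is allocated directly in the VNNI-packed layout
+  // memref<4x4x2xbf16>, i.e. K/2 x N x 2 for K = 8, matching the shape that
+  // @matmultpp expands A into. With both inputs filled with 1.0 the reduction
+  // spans K = 8 elements, which is why the CHECK below expects a matrix of 8s.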
+ %0 = memref.alloc() : memref<4x4x2xbf16> + linalg.fill ins(%f0:bf16) outs (%0: memref<4x4x2xbf16>) + %D = memref.alloc() : memref<4x4xbf16> + %zero = arith.constant 0.0 : bf16 + linalg.fill ins(%zero : bf16) outs(%D:memref<4x4xbf16>) + call @matmultpp(%da, %0, %D) + : (memref<4x8xbf16>, memref<4x4x2xbf16>, memref<4x4xbf16>)->() + + // + // CHECK:( ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ) ) + // + %d1 = arith.constant -1.0 : bf16 + + %v0 = vector.transfer_read %D[%c0, %c0], %d1 : memref<4x4xbf16>, vector<4x4xbf16> + %f1 = arith.extf %v0:vector<4x4xbf16> to vector<4x4xf32> + vector.print %f1 : vector<4x4xf32> + + return +} diff --git a/test/BF16/Integration/mlp-all-bf16-tpprun.mlir b/test/BF16/Integration/mlp-all-bf16-tpprun.mlir new file mode 100644 index 000000000..5f7968719 --- /dev/null +++ b/test/BF16/Integration/mlp-all-bf16-tpprun.mlir @@ -0,0 +1,137 @@ +// RUN: tpp-run %s \ +// RUN: -e entry -entry-point-result=void + +memref.global "private" constant @arg1 : memref<128x512x2xbf16> = dense<1.00e+00> +memref.global "private" constant @arg3 : memref<256x1024x2xbf16> = dense<1.00e+00> +memref.global "private" constant @arg5 : memref<512x2048x2xbf16> = dense<1.00e+00> +memref.global "private" constant @arg7 : memref<1024x1000x2xbf16> = dense<1.00e+00> + +#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +#map4 = affine_map<(d0, d1) -> (d1)> + +func.func @entry(%arg0: memref<128x256xbf16>, %arg2: memref<512xbf16>, %arg4: memref<1024xbf16>, + %arg6: memref<2048xbf16>, %arg8: memref<1000xbf16>, %arg9: memref<128x512xbf16>, + %arg10: memref<128x1024xbf16>, %arg11: memref<128x2048xbf16>, %arg12: memref<128x1000xbf16>) { + %c0 = arith.constant 0.0 : bf16 + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg2: memref<512xbf16>) outs(%arg9: memref<128x512xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %e0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 128, 2] + : memref<128x256xbf16> into memref<128x128x2xbf16> + %relayout_arg0 = memref.get_global @arg1:memref<128x512x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e0, %relayout_arg0 : memref<128x128x2xbf16>, memref<128x512x2xbf16>) + outs(%arg9 : memref<128x512xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg9 : memref<128x512xbf16>) outs(%arg9 : memref<128x512xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg4: memref<1024xbf16>) outs(%arg10: memref<128x1024xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %e1 = memref.expand_shape %arg9 [[0], [1, 2]] output_shape [128, 256, 2] + : memref<128x512xbf16> into memref<128x256x2xbf16> + %relayout_arg12 = memref.get_global @arg3:memref<256x1024x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e1, %relayout_arg12 : memref<128x256x2xbf16>, 
memref<256x1024x2xbf16>) + outs(%arg10 : memref<128x1024xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg10 : memref<128x1024xbf16>) outs(%arg10 : memref<128x1024xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg6: memref<2048xbf16>) outs(%arg11: memref<128x2048xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %relayout_arg11 = memref.get_global @arg5:memref<512x2048x2xbf16> + %e2 = memref.expand_shape %arg10 [[0], [1, 2]] output_shape [128, 512, 2] + : memref<128x1024xbf16> into memref<128x512x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e2, %relayout_arg11 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>) + outs(%arg11 : memref<128x2048xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg11 : memref<128x2048xbf16>) outs(%arg11 : memref<128x2048xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg8: memref<1000xbf16>) outs(%arg12: memref<128x1000xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %relayout_arg10 = memref.get_global @arg7:memref<1024x1000x2xbf16> + %e3 = memref.expand_shape %arg11 [[0], [1, 2]] output_shape [128, 1024, 2] + : memref<128x2048xbf16> into memref<128x1024x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e3, %relayout_arg10 : memref<128x1024x2xbf16>, memref<1024x1000x2xbf16>) + outs(%arg12 : memref<128x1000xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg12 : memref<128x1000xbf16>) outs(%arg12 : memref<128x1000xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + %threshold = arith.constant 1.0 : bf16 + %c4 = arith.constant 2.74878e+11: bf16 + %interim4 = memref.alloc(): memref<128x1000xbf16> + linalg.fill ins(%c4:bf16) outs(%interim4: memref<128x1000xbf16>) + check.expect_almost_eq(%interim4, %arg12, %threshold): memref<128x1000xbf16>, memref<128x1000xbf16>, bf16 + return +} From 76bcc85b45730ee50c2597db989df9447a8aa010 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 17:00:06 +0100 Subject: [PATCH 19/26] Add DLTI helper --- include/TPP/Transforms/Utils/DLTIUtils.h | 30 ++++++++++++++++++++++ lib/TPP/Transforms/Utils/CMakeLists.txt | 1 + lib/TPP/Transforms/Utils/DLTIUtils.cpp | 32 ++++++++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 include/TPP/Transforms/Utils/DLTIUtils.h create mode 100644 lib/TPP/Transforms/Utils/DLTIUtils.cpp diff --git 
a/include/TPP/Transforms/Utils/DLTIUtils.h b/include/TPP/Transforms/Utils/DLTIUtils.h new file mode 100644 index 000000000..3e1e17ddc --- /dev/null +++ b/include/TPP/Transforms/Utils/DLTIUtils.h @@ -0,0 +1,30 @@ +//===- DLTIUtils.h -----------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_TRANSFORMS_UTILS_DLTIUTILS_H +#define TPP_TRANSFORMS_UTILS_DLTIUTILS_H + +#include "mlir/Dialect/DLTI/DLTI.h" + +namespace llvm { +class StringRef; +} // namespace llvm + +namespace mlir { +namespace dlti { +namespace utils { + +// Perform a DLTI-query using string keys. +FailureOr query(Operation *op, ArrayRef keys, + bool emitError = false); + +} // namespace utils +} // namespace dlti +} // namespace mlir + +#endif // TPP_TRANSFORMS_UTILS_DLTIUTILS_H diff --git a/lib/TPP/Transforms/Utils/CMakeLists.txt b/lib/TPP/Transforms/Utils/CMakeLists.txt index 4e7e484a8..a6b0c4501 100644 --- a/lib/TPP/Transforms/Utils/CMakeLists.txt +++ b/lib/TPP/Transforms/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TPPTransformsUtils BuilderUtils.cpp + DLTIUtils.cpp TensorInit.cpp TensorInitFloat.cpp TensorInitInt.cpp diff --git a/lib/TPP/Transforms/Utils/DLTIUtils.cpp b/lib/TPP/Transforms/Utils/DLTIUtils.cpp new file mode 100644 index 000000000..8fac42db2 --- /dev/null +++ b/lib/TPP/Transforms/Utils/DLTIUtils.cpp @@ -0,0 +1,32 @@ +//===- DLTIUtils.cpp ---------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Transforms/Utils/DLTIUtils.h" + +namespace mlir { +namespace dlti { +namespace utils { + +FailureOr query(Operation *op, ArrayRef keys, + bool emitError) { + if (!op) + return failure(); + + auto ctx = op->getContext(); + SmallVector entryKeys; + for (auto &key : keys) { + auto entry = StringAttr::get(ctx, key); + entryKeys.push_back(entry); + } + + return dlti::query(op, entryKeys, emitError); +} + +} // namespace utils +} // namespace dlti +} // namespace mlir From e7d8a9647f35994e7f9e8e1cc0b5afef2117b372 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 17:04:31 +0100 Subject: [PATCH 20/26] Address comments --- include/TPP/Transforms/Utils/VNNIUtils.h | 19 +++---- lib/TPP/Dialect/Xsmm/XsmmOps.cpp | 3 +- .../TileConsumerAndFuseProducers.cpp | 25 ++------- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 53 ++++--------------- 4 files changed, 21 insertions(+), 79 deletions(-) diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index fdda42d4e..7eb739fd8 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -15,7 +15,7 @@ namespace mlir { class Type; -class MemRefType; +class ShapedType; class OpOperand; class AffineDimExpr; class AffineMap; @@ -37,26 +37,21 @@ enum class VnniOperandRank { }; // Return the VNNI blocking factor. -// Optionally, operation can be provided to give access to DLTI. +// Optionally, an operation can be provided to give access to DLTI. 
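+// Minimal usage sketch (placeholder names): pass an op whose enclosing module
+// may carry a DLTI spec with a "CPU" device entry keyed "vnni", e.g.
+//   auto factor = vnni::utils::getVnniBlockingFactor(memrefType, someOp);
+// Without such a hint, bf16 falls back to the factor libxsmm reports via
+// libxsmm_cpuid_dot_pack_factor.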
std::optional getVnniBlockingFactor(Type type, Operation *op = nullptr); -// Return true if the memref is in VNNI layout with rank `expectedRank`. +// Return true if the shaped type is in VNNI layout with rank `expectedRank`. // Optionally, the check can be constrained to a specific VNNI blocking factor. -bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref, +bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, std::optional blockingFactor = std::nullopt); -// Return true if the vector is in VNNI layout with rank `expectedRank`. +// Return true if the shaped type is in VNNI layout with rank `expectedRank`. // Optionally, the check can be constrained to a specific VNNI blocking factor. -bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector, +bool isInVnniLayout(int64_t expectedRank, ShapedType shape, std::optional blockingFactor = std::nullopt); -// Return true if the vector is in VNNI layout with rank `expectedRank`. -// Optionally, the check can be constrained to a specific VNNI blocking factor. -bool isInVnniLayout(int64_t expectedRank, VectorType vector, - std::optional blockingFactor = std::nullopt); - -// Return true if the operation is in VNNI layout. +// Return true if the linalg operation is in VNNI layout. // Optionally, the check can be constrained to a specific VNNI blocking factor. bool isInVnniLayout(linalg::LinalgOp linalgOp, std::optional blockingFactor = std::nullopt); diff --git a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp index f69667051..bf04eb496 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp @@ -456,8 +456,7 @@ LogicalResult GemmOp::verify() { assert(memref && (memref.getRank() == 2 || memref.getRank() == 3)); if (memref.getRank() == 3) { - auto vnniFactor = vnni::utils::getVnniBlockingFactor(memref); - if (!vnniFactor || (*vnniFactor) % 2 != 0 || + if (memref.getShape().back() % 2 != 0 || !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM, memref)) { return emitOpError() << "expect VNNI layout for operand: " << actualIdx; diff --git a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp index 540a10b22..3f24c1bdd 100644 --- a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp +++ b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp @@ -8,6 +8,7 @@ #include "TPP/Passes.h" #include "TPP/Transforms/Transforms.h" +#include "TPP/Transforms/Utils/DLTIUtils.h" #include "TPP/Transforms/Utils/TransformUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -457,28 +458,10 @@ static int64_t getTileForDim(linalg::LinalgOp linalgOp, unsigned dim) { int64_t tile = 32; // Check if a tile size hint is associated to the IR via DLTI. 
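  // Such a hint would presumably be attached at module level in the same style
  // as the VNNI specs used in the tests (sketch, key spelling mirrors those tests):
  //   module attributes { "#dlti.sys_spec" = #dlti.target_system_spec<
  //     "CPU" = #dlti.target_device_spec<"tile_size" = 32 : i32>> } { ... }
  // and is retrieved below through the {"CPU", "tile_size"} key path.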
- auto deriveFromDLTI = [&](ModuleOp moduleOp) { - if (!moduleOp) - return; - TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec(); - if (!sysSpec) - return; - auto deviceId = StringAttr::get(linalgOp->getContext(), "CPU"); - auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId); - if (!deviceSpec) - return; - auto tileSizeId = StringAttr::get(linalgOp->getContext(), "tile_size"); - DataLayoutEntryInterface entry = - (*deviceSpec).getSpecForIdentifier(tileSizeId); - if (!entry) - return; - Attribute value = entry.getValue(); - if (auto intAttr = llvm::dyn_cast(value)) + auto tileValue = dlti::utils::query(linalgOp, {"CPU", "tile_size"}); + if (succeeded(tileValue)) + if (auto intAttr = llvm::dyn_cast(*tileValue)) tile = intAttr.getInt(); - // TODO: might want to print a warning if tile_size exists as a key but the - // associated attribute has an unexpected type. - }; - deriveFromDLTI(linalgOp->getParentOfType()); SmallVector loopsRange = linalgOp.getStaticLoopRanges(); if (loopsRange[dim] == ShapedType::kDynamic) diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index 41d00ca6c..d760eefd3 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TPP/Transforms/Utils/VNNIUtils.h" +#include "TPP/Transforms/Utils/DLTIUtils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -25,52 +26,16 @@ std::optional getVnniBlockingFactor(Type type, Operation *op) { auto elementType = getElementTypeOrSelf(type); if (elementType.isBF16()) { // Check if a VNNI factor hint is associated to the IR via DLTI. - auto deriveVnniFromDLTI = [&]() -> std::optional { - if (!op) - return std::nullopt; - ModuleOp moduleOp = op->getParentOfType(); - if (!moduleOp) - return std::nullopt; - TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec(); - if (!sysSpec) - return std::nullopt; - auto deviceId = StringAttr::get(moduleOp->getContext(), "CPU"); - auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId); - if (!deviceSpec) - return std::nullopt; - auto vnniId = StringAttr::get(moduleOp->getContext(), "vnni"); - DataLayoutEntryInterface entry = - (*deviceSpec).getSpecForIdentifier(vnniId); - if (!entry) - return std::nullopt; - Attribute value = entry.getValue(); - if (auto intAttr = llvm::dyn_cast(value)) + auto vnniValue = dlti::utils::query(op, {"CPU", "vnni"}); + if (succeeded(vnniValue)) + if (auto intAttr = llvm::dyn_cast(*vnniValue)) return intAttr.getInt(); - return std::nullopt; - }; - if (auto vnniFactor = deriveVnniFromDLTI()) - return *vnniFactor; return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); } return std::nullopt; } -// Until we have a better way to express the VNNI layout (see: #563), it is up -// to the callee to specify the expected rank in the VNNI layout as the rank -// depends on the operations we are dealing with. -bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref, - std::optional blockingFactor) { - if (memref.getRank() != static_cast(expectedRank) || - !memref.getElementType().isBF16()) - return false; - - if (blockingFactor && memref.getShape().back() != *blockingFactor) - return false; - - return true; -} - bool isInVnniLayout(linalg::LinalgOp linalgOp, std::optional blockingFactor) { // Narrow down type operations - VNNI only applies to contractions. 
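  // Here "VNNI layout" means the contraction's reduction dimension is split so
  // that the innermost dimension of both inputs carries the blocking factor,
  // e.g. a bf16 batch-reduce matmul with 5x32x64 and 5x64x32 operands packed
  // to 5x32x32x2 for factor 2 (see test/Passes/pack-vnni.mlir).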
@@ -142,18 +107,18 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, return true; } -bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector, +bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, std::optional blockingFactor) { - return isInVnniLayout(static_cast(expectedRank), vector, + return isInVnniLayout(static_cast(expectedRank), shape, blockingFactor); } -bool isInVnniLayout(int64_t expectedRank, VectorType vector, +bool isInVnniLayout(int64_t expectedRank, ShapedType shape, std::optional blockingFactor) { - if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) + if (shape.getRank() != expectedRank || !shape.getElementType().isBF16()) return false; - if (blockingFactor && vector.getShape().back() != *blockingFactor) + if (blockingFactor && shape.getShape().back() != *blockingFactor) return false; return true; From 8d72e4f6a35abe98747900e2f8f6a80d37310402 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 17:16:08 +0100 Subject: [PATCH 21/26] Remove test tweak --- test/BF16/Integration/tpp-run-splat-shape.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/BF16/Integration/tpp-run-splat-shape.mlir b/test/BF16/Integration/tpp-run-splat-shape.mlir index 4a865ab09..624aeb754 100644 --- a/test/BF16/Integration/tpp-run-splat-shape.mlir +++ b/test/BF16/Integration/tpp-run-splat-shape.mlir @@ -41,7 +41,7 @@ func.func @entry(%arg0: tensor<4x8x8x8xbf16>, %output: tensor<4x8x8x8xbf16>) -> // due to compile time packing. // CHECK-NOT: memref.global "private" constant @__constant_{{.*}}: memref<8x8xbf16> // CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<4x8x8x8xbf16> -// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x{{[4|2]}}x8x{{2|4}}xbf16> +// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x4x8x2xbf16> // CHECK: xsmm_brgemm_invoke // CHECK: xsmm_binary_invoke // CHECK: xsmm_unary_invoke From 9aa7b9b9bd65272b9e826afd584f18faebe21b5e Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 17:42:43 +0100 Subject: [PATCH 22/26] Validate VNNI factor --- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 15 +++- test/Passes/pack-vnni.mlir | 104 +++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 test/Passes/pack-vnni.mlir diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index d760eefd3..c185eedbd 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -23,16 +23,23 @@ namespace vnni { namespace utils { std::optional getVnniBlockingFactor(Type type, Operation *op) { + int64_t blockingFactor = 0; + auto elementType = getElementTypeOrSelf(type); if (elementType.isBF16()) { // Check if a VNNI factor hint is associated to the IR via DLTI. 
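    // The hint is attached at module level, e.g. (as in test/Passes/pack-vnni.mlir):
    //   module attributes { "#dlti.sys_spec" = #dlti.target_system_spec<
    //     "CPU" = #dlti.target_device_spec<"vnni" = 2 : i32>> } { ... }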
auto vnniValue = dlti::utils::query(op, {"CPU", "vnni"}); - if (succeeded(vnniValue)) + if (succeeded(vnniValue)) { if (auto intAttr = llvm::dyn_cast(*vnniValue)) - return intAttr.getInt(); - - return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); + blockingFactor = intAttr.getInt(); + } else { + blockingFactor = libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); + } } + + if (blockingFactor != 0 && blockingFactor % 2 == 0) + return blockingFactor; + return std::nullopt; } diff --git a/test/Passes/pack-vnni.mlir b/test/Passes/pack-vnni.mlir new file mode 100644 index 000000000..e30050712 --- /dev/null +++ b/test/Passes/pack-vnni.mlir @@ -0,0 +1,104 @@ +// RUN: tpp-opt -pack-vnni -split-input-file %s | FileCheck %s + +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_vnni_2(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> + +// CHECK-LABEL: @brgemm_vnni_2( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] +// CHECK-SAME: output_shape{{.*}}: tensor<5x32x64xbf16> into tensor<5x32x32x2xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [2] +// CHECK-SAME: : tensor<5x64x32xbf16> -> tensor<5x32x32x2xbf16> +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%[[VNNI_A]], %[[PACK]] +// CHECK-SAME: outs(%[[ARG2]] + +// ----- + +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 4 : i32>> +} { + func.func @brgemm_vnni_4(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> + +// CHECK-LABEL: @brgemm_vnni_4( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] +// CHECK-SAME: output_shape{{.*}}: tensor<5x32x64xbf16> into tensor<5x32x16x4xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] +// CHECK-SAME: : tensor<5x64x32xbf16> -> tensor<5x16x32x4xbf16> +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["reduction", 
"parallel", "parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%[[VNNI_A]], %[[PACK]] +// CHECK-SAME: outs(%[[ARG2]] + +// ----- + +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 0 : i32>> +} { + func.func @invalid_vnni_factor_0(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK-LABEL: @invalid_vnni_factor_0( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK-NOT: linalg.generic +// CHECK: linalg.batch_reduce_matmul + +// ----- + +// Blocking factor is expected to be divisible by 2. +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 5 : i32>> +} { + func.func @invalid_vnni_factor_5(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK-LABEL: @invalid_vnni_factor_5( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK-NOT: linalg.generic +// CHECK: linalg.batch_reduce_matmul From c7d5a5d474740328cc1d9891ab0326eade9beaa9 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 17:45:56 +0100 Subject: [PATCH 23/26] Enforce vnni shape to be mod 2 --- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index c185eedbd..dd63f9e1a 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -125,6 +125,9 @@ bool isInVnniLayout(int64_t expectedRank, ShapedType shape, if (shape.getRank() != expectedRank || !shape.getElementType().isBF16()) return false; + if (shape.getShape().back() % 2 != 0) + return false; + if (blockingFactor && shape.getShape().back() != *blockingFactor) return false; From ae2dfa6abf610b1ab3d2fd4f831a7d312889abfb Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 18:51:00 +0100 Subject: [PATCH 24/26] Improve VNNI APIs and verification --- include/TPP/Transforms/Utils/VNNIUtils.h | 13 ++++--- .../ConvertLinalgToXsmm.cpp | 7 ++-- .../ConvertVectorToXsmm.cpp | 3 +- lib/TPP/Transforms/ToBlockLayoutAndBack.cpp | 4 +-- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 34 ++++++++++--------- 5 files changed, 32 insertions(+), 29 deletions(-) diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index 7eb739fd8..d5c99c440 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -36,25 +36,24 @@ enum class VnniOperandRank { BRGEMM_OUTS = 3 }; -// Return the VNNI blocking factor. +// Return the VNNI blocking factor if it can be determined for the given type or +// zero, otherwise. // Optionally, an operation can be provided to give access to DLTI. 
-std::optional getVnniBlockingFactor(Type type, - Operation *op = nullptr); +unsigned getVnniBlockingFactor(Type type, Operation *op = nullptr); // Return true if the shaped type is in VNNI layout with rank `expectedRank`. // Optionally, the check can be constrained to a specific VNNI blocking factor. bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, - std::optional blockingFactor = std::nullopt); + unsigned blockingFactor = 0); // Return true if the shaped type is in VNNI layout with rank `expectedRank`. // Optionally, the check can be constrained to a specific VNNI blocking factor. bool isInVnniLayout(int64_t expectedRank, ShapedType shape, - std::optional blockingFactor = std::nullopt); + unsigned blockingFactor = 0); // Return true if the linalg operation is in VNNI layout. // Optionally, the check can be constrained to a specific VNNI blocking factor. -bool isInVnniLayout(linalg::LinalgOp linalgOp, - std::optional blockingFactor = std::nullopt); +bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor = 0); } // namespace utils } // namespace vnni diff --git a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp index efc1ab5c6..c9166341f 100644 --- a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp @@ -1068,9 +1068,10 @@ struct ConvertVnniPacking : public OpRewritePattern { if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) return failure(); // Ajust ldo based on the VNNI factor. - unaryInfo.ldo = - stridesOnOutput->front() / - *vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp); + auto vnniFactor = + vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp); + assert(vnniFactor && "Failed to get VNNI blocking factor"); + unaryInfo.ldo = stridesOnOutput->front() / vnniFactor; auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get( rewriter.getContext(), xsmm::UnaryFlags::NONE)); xsmm::UnaryKindAttr kind = diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp index e309eb384..b6f2ec3ee 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp @@ -99,7 +99,8 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp, if (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE, outType)) { // Adjust ldo based on vnni factor - auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType, transposeOp); + auto vnniFactor = vnni::utils::getVnniBlockingFactor(outType, transposeOp); + assert(vnniFactor && "Failed to get VNNI blocking factor"); unaryInfo.ldo = unaryInfo.ldo / vnniFactor; } else { std::swap(unaryInfo.m, unaryInfo.n); diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp index f547d392d..688c492f6 100644 --- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp +++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp @@ -345,7 +345,7 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter, Location loc = matmulOp.getLoc(); SmallVector tilesOnSmallK = { - rewriter.getI64IntegerAttr(*blockingFactor)}; + rewriter.getI64IntegerAttr(blockingFactor)}; SmallVector> kOperands; matmulOp.mapIterationSpaceDimToAllOperandDims(dims->k.back(), kOperands); if (kOperands.size() != 2) @@ -416,7 +416,7 @@ 
mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter, "unsupported blocking factor for type"); } SmallVector tilesOnK = { - rewriter.getI64IntegerAttr(*blockingFactor)}; + rewriter.getI64IntegerAttr(blockingFactor)}; Location loc = brgemmOp.getLoc(); // Reshape input A. diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index dd63f9e1a..0d7a07042 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -22,8 +22,8 @@ namespace mlir { namespace vnni { namespace utils { -std::optional getVnniBlockingFactor(Type type, Operation *op) { - int64_t blockingFactor = 0; +unsigned getVnniBlockingFactor(Type type, Operation *op) { + unsigned blockingFactor = 0; auto elementType = getElementTypeOrSelf(type); if (elementType.isBF16()) { @@ -37,14 +37,14 @@ std::optional getVnniBlockingFactor(Type type, Operation *op) { } } - if (blockingFactor != 0 && blockingFactor % 2 == 0) - return blockingFactor; + // Ensure that the factor is divisible by two. + if (blockingFactor % 2 != 0) + return 0; - return std::nullopt; + return blockingFactor; } -bool isInVnniLayout(linalg::LinalgOp linalgOp, - std::optional blockingFactor) { +bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor) { // Narrow down type operations - VNNI only applies to contractions. if (!linalg::isaContractionOpInterface(linalgOp)) return false; @@ -101,10 +101,12 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, // - statically known // - multiple of 2 or equal to the specified factor auto vnniDimSize = typeB.getShape().back(); - if (!(vnniDimSize != ShapedType::kDynamic && - typeA.getShape().back() == vnniDimSize && - (blockingFactor ? vnniDimSize == *blockingFactor - : vnniDimSize % 2 == 0))) + if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 || + vnniDimSize % 2 != 0) + return false; + if (typeA.getShape().back() != vnniDimSize) + return false; + if (blockingFactor && vnniDimSize != blockingFactor) return false; // The split reduction dimension size should also match. @@ -115,20 +117,20 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, } bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, - std::optional blockingFactor) { + unsigned blockingFactor) { return isInVnniLayout(static_cast(expectedRank), shape, blockingFactor); } bool isInVnniLayout(int64_t expectedRank, ShapedType shape, - std::optional blockingFactor) { + unsigned blockingFactor) { if (shape.getRank() != expectedRank || !shape.getElementType().isBF16()) return false; - if (shape.getShape().back() % 2 != 0) + auto vnniDim = shape.getShape().back(); + if (vnniDim == 0 || vnniDim % 2 != 0) return false; - - if (blockingFactor && shape.getShape().back() != *blockingFactor) + if (blockingFactor && vnniDim != blockingFactor) return false; return true; From c5d9bad8e1e39d53f2d05e31ab1898b32d29f743 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 19:17:24 +0100 Subject: [PATCH 25/26] Verify valid factor --- lib/TPP/Dialect/Xsmm/XsmmVerify.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp index df8071ba8..384ec7ddc 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp @@ -72,23 +72,25 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) { // VNNI flags must be consistent with the memref shapes. 
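  // A VNNI_* flag is only accepted when the matching operand really is in VNNI
  // layout for the derived blocking factor; a zero factor (returned when no
  // usable DLTI hint exists or the hint is not even) therefore rejects the flag.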
auto vnniFactor = vnni::utils::getVnniBlockingFactor(operandA, gemmOp); + ArrayAttr flags = dispatchOp->getFlags(); for (auto flag : flags) { int64_t gemmFlag = cast(flag).getInt(); if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_A) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA, - vnniFactor)) { + (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankIns, + operandA, vnniFactor))) { return gemmOp.emitOpError( "expect VNNI layout for operand A or invalid VNNI_A flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_B) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB, - vnniFactor)) { + (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankIns, + operandB, vnniFactor))) { return gemmOp.emitOpError( "expect VNNI layout for operand B or invalid VNNI_B flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_C) && - !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) { + (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, + vnniFactor))) { return gemmOp.emitOpError( "expect VNNI layout for operand C or invalid VNNI_C flags"); } From 2b7ccb57831a40e90985d4fb7c129482c190758c Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 17 Jan 2025 09:16:50 +0100 Subject: [PATCH 26/26] Fold factor verification into layout checker --- include/TPP/Transforms/Utils/VNNIUtils.h | 7 ++++--- lib/TPP/Dialect/Xsmm/XsmmVerify.cpp | 11 +++++------ lib/TPP/Transforms/Utils/VNNIUtils.cpp | 11 ++++++----- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index d5c99c440..d5d12a6f6 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -44,16 +44,17 @@ unsigned getVnniBlockingFactor(Type type, Operation *op = nullptr); // Return true if the shaped type is in VNNI layout with rank `expectedRank`. // Optionally, the check can be constrained to a specific VNNI blocking factor. bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, - unsigned blockingFactor = 0); + std::optional blockingFactor = std::nullopt); // Return true if the shaped type is in VNNI layout with rank `expectedRank`. // Optionally, the check can be constrained to a specific VNNI blocking factor. bool isInVnniLayout(int64_t expectedRank, ShapedType shape, - unsigned blockingFactor = 0); + std::optional blockingFactor = std::nullopt); // Return true if the linalg operation is in VNNI layout. // Optionally, the check can be constrained to a specific VNNI blocking factor. 
-bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor = 0); +bool isInVnniLayout(linalg::LinalgOp linalgOp, + std::optional blockingFactor = std::nullopt); } // namespace utils } // namespace vnni diff --git a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp index 384ec7ddc..2040e8833 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp @@ -77,20 +77,19 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) { for (auto flag : flags) { int64_t gemmFlag = cast(flag).getInt(); if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_A) && - (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankIns, - operandA, vnniFactor))) { + !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA, + vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand A or invalid VNNI_A flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_B) && - (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankIns, - operandB, vnniFactor))) { + !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB, + vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand B or invalid VNNI_B flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_C) && - (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, - vnniFactor))) { + !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand C or invalid VNNI_C flags"); } diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index 0d7a07042..87f290e25 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -44,7 +44,8 @@ unsigned getVnniBlockingFactor(Type type, Operation *op) { return blockingFactor; } -bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor) { +bool isInVnniLayout(linalg::LinalgOp linalgOp, + std::optional blockingFactor) { // Narrow down type operations - VNNI only applies to contractions. if (!linalg::isaContractionOpInterface(linalgOp)) return false; @@ -106,7 +107,7 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor) { return false; if (typeA.getShape().back() != vnniDimSize) return false; - if (blockingFactor && vnniDimSize != blockingFactor) + if (blockingFactor && vnniDimSize != *blockingFactor) return false; // The split reduction dimension size should also match. @@ -117,20 +118,20 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor) { } bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, - unsigned blockingFactor) { + std::optional blockingFactor) { return isInVnniLayout(static_cast(expectedRank), shape, blockingFactor); } bool isInVnniLayout(int64_t expectedRank, ShapedType shape, - unsigned blockingFactor) { + std::optional blockingFactor) { if (shape.getRank() != expectedRank || !shape.getElementType().isBF16()) return false; auto vnniDim = shape.getShape().back(); if (vnniDim == 0 || vnniDim % 2 != 0) return false; - if (blockingFactor && vnniDim != blockingFactor) + if (blockingFactor && vnniDim != *blockingFactor) return false; return true;