From ad31e0f1f3875ff5f0456ad96b4963847cc6a571 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 13:59:58 +0100 Subject: [PATCH] Restore deleted tests --- test/BF16/Integration/matmul-pbf16.mlir | 50 +++++++ .../BF16/Integration/mlp-all-bf16-tpprun.mlir | 137 ++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 test/BF16/Integration/matmul-pbf16.mlir create mode 100644 test/BF16/Integration/mlp-all-bf16-tpprun.mlir diff --git a/test/BF16/Integration/matmul-pbf16.mlir b/test/BF16/Integration/matmul-pbf16.mlir new file mode 100644 index 000000000..f2434271d --- /dev/null +++ b/test/BF16/Integration/matmul-pbf16.mlir @@ -0,0 +1,50 @@ +// RUN: tpp-run %s -print \ +// RUN: -e entry -entry-point-result=void | \ +// RUN: FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +func.func @matmultpp(%A: memref<4x8xbf16>, + %B: memref<4x4x2xbf16>, %C: memref<4x4xbf16>) { + %expanded = memref.expand_shape %A [[0], [1, 2]] output_shape [4, 4, 2] + : memref<4x8xbf16> into memref<4x4x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %B : memref<4x4x2xbf16>, memref<4x4x2xbf16>) + outs(%C : memref<4x4xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return +} + +func.func @entry() { + %c0 = arith.constant 0 : index + %f0 = arith.constant 1.0 : bf16 + %da = memref.alloc() :memref<4x8xbf16> + linalg.fill ins(%f0 : bf16) outs(%da : memref<4x8xbf16>) + // Call kernel. + %0 = memref.alloc() : memref<4x4x2xbf16> + linalg.fill ins(%f0:bf16) outs (%0: memref<4x4x2xbf16>) + %D = memref.alloc() : memref<4x4xbf16> + %zero = arith.constant 0.0 : bf16 + linalg.fill ins(%zero : bf16) outs(%D:memref<4x4xbf16>) + call @matmultpp(%da, %0, %D) + : (memref<4x8xbf16>, memref<4x4x2xbf16>, memref<4x4xbf16>)->() + + // + // CHECK:( ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ) ) + // + %d1 = arith.constant -1.0 : bf16 + + %v0 = vector.transfer_read %D[%c0, %c0], %d1 : memref<4x4xbf16>, vector<4x4xbf16> + %f1 = arith.extf %v0:vector<4x4xbf16> to vector<4x4xf32> + vector.print %f1 : vector<4x4xf32> + + return +} diff --git a/test/BF16/Integration/mlp-all-bf16-tpprun.mlir b/test/BF16/Integration/mlp-all-bf16-tpprun.mlir new file mode 100644 index 000000000..5f7968719 --- /dev/null +++ b/test/BF16/Integration/mlp-all-bf16-tpprun.mlir @@ -0,0 +1,137 @@ +// RUN: tpp-run %s \ +// RUN: -e entry -entry-point-result=void + +memref.global "private" constant @arg1 : memref<128x512x2xbf16> = dense<1.00e+00> +memref.global "private" constant @arg3 : memref<256x1024x2xbf16> = dense<1.00e+00> +memref.global "private" constant @arg5 : memref<512x2048x2xbf16> = dense<1.00e+00> +memref.global "private" constant @arg7 : memref<1024x1000x2xbf16> = dense<1.00e+00> + +#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +#map4 = affine_map<(d0, d1) -> (d1)> + +func.func @entry(%arg0: memref<128x256xbf16>, %arg2: memref<512xbf16>, %arg4: memref<1024xbf16>, + %arg6: memref<2048xbf16>, %arg8: memref<1000xbf16>, %arg9: memref<128x512xbf16>, + %arg10: memref<128x1024xbf16>, %arg11: memref<128x2048xbf16>, %arg12: memref<128x1000xbf16>) { + %c0 = arith.constant 0.0 : bf16 + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg2: memref<512xbf16>) outs(%arg9: memref<128x512xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %e0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 128, 2] + : memref<128x256xbf16> into memref<128x128x2xbf16> + %relayout_arg0 = memref.get_global @arg1:memref<128x512x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e0, %relayout_arg0 : memref<128x128x2xbf16>, memref<128x512x2xbf16>) + outs(%arg9 : memref<128x512xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg9 : memref<128x512xbf16>) outs(%arg9 : memref<128x512xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg4: memref<1024xbf16>) outs(%arg10: memref<128x1024xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %e1 = memref.expand_shape %arg9 [[0], [1, 2]] output_shape [128, 256, 2] + : memref<128x512xbf16> into memref<128x256x2xbf16> + %relayout_arg12 = memref.get_global @arg3:memref<256x1024x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e1, %relayout_arg12 : memref<128x256x2xbf16>, memref<256x1024x2xbf16>) + outs(%arg10 : memref<128x1024xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg10 : memref<128x1024xbf16>) outs(%arg10 : memref<128x1024xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg6: memref<2048xbf16>) outs(%arg11: memref<128x2048xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %relayout_arg11 = memref.get_global @arg5:memref<512x2048x2xbf16> + %e2 = memref.expand_shape %arg10 [[0], [1, 2]] output_shape [128, 512, 2] + : memref<128x1024xbf16> into memref<128x512x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e2, %relayout_arg11 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>) + outs(%arg11 : memref<128x2048xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg11 : memref<128x2048xbf16>) outs(%arg11 : memref<128x2048xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + linalg.generic { + indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg8: memref<1000xbf16>) outs(%arg12: memref<128x1000xbf16>) { + ^bb0(%in: bf16, %out: bf16): + linalg.yield %in : bf16 + } + + %relayout_arg10 = memref.get_global @arg7:memref<1024x1000x2xbf16> + %e3 = memref.expand_shape %arg11 [[0], [1, 2]] output_shape [128, 1024, 2] + : memref<128x2048xbf16> into memref<128x1024x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%e3, %relayout_arg10 : memref<128x1024x2xbf16>, memref<1024x1000x2xbf16>) + outs(%arg12 : memref<128x1000xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + linalg.generic { + indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%arg12 : memref<128x1000xbf16>) outs(%arg12 : memref<128x1000xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %2 = arith.maximumf %in, %c0 : bf16 + linalg.yield %2 : bf16 + } + + %threshold = arith.constant 1.0 : bf16 + %c4 = arith.constant 2.74878e+11: bf16 + %interim4 = memref.alloc(): memref<128x1000xbf16> + linalg.fill ins(%c4:bf16) outs(%interim4: memref<128x1000xbf16>) + check.expect_almost_eq(%interim4, %arg12, %threshold): memref<128x1000xbf16>, memref<128x1000xbf16>, bf16 + return +}