[mlir][Linalg] Rewrite PadTensorOp to enable its comprehensive bufferization.

Add the rewrite of PadTensorOp to InitTensor + InsertSlice before the
bufferization analysis starts.

This is exercised via a more advanced integration test.

Since the new behavior triggers folding, 2 tests need to be updated.
One of those seems to exhibit a folding issue with `switch` and is modified.

Differential Revision: https://reviews.llvm.org/D105549
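
For illustration, a minimal before/after sketch of the enabling rewrite described above, in the style of the integration test below. It assumes the padding value is materialized with a linalg.fill; the function names and the exact IR that GeneralizePadTensorOpPattern emits are illustrative, not taken from the commit.

// Before the enabling rewrite: a pad of a tensor<2xf32> up to tensor<4xf32>.
func @pad_before(%src: tensor<2xf32>) -> tensor<4xf32> {
  %cst = constant 0.000000e+00 : f32
  %c0 = constant 0 : index
  %c2 = constant 2 : index
  %pad = linalg.pad_tensor %src low[%c0] high[%c2] {
  ^bb0(%i: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<2xf32> to tensor<4xf32>
  return %pad : tensor<4xf32>
}

// After the rewrite (sketch): materialize the padded shape with init_tensor,
// fill it with the padding value, then insert the original source into it.
func @pad_after(%src: tensor<2xf32>) -> tensor<4xf32> {
  %cst = constant 0.000000e+00 : f32
  %init = linalg.init_tensor [4] : tensor<4xf32>
  %fill = linalg.fill(%cst, %init) : f32, tensor<4xf32> -> tensor<4xf32>
  %res = tensor.insert_slice %src into %fill[0] [2] [1] : tensor<2xf32> into tensor<4xf32>
  return %res : tensor<4xf32>
}

Once the pad is expressed with these ops, the existing bufferization patterns and canonicalizations apply to it, which is presumably the extra folding the message above refers to.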
nicolasvasilache committed Jul 7, 2021
1 parent 35df2f6 commit d0b282e
Showing 4 changed files with 101 additions and 24 deletions.
27 changes: 20 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -108,6 +108,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -117,6 +118,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/BufferUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/DenseSet.h"
@@ -1491,9 +1493,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
<< "cannot bufferize bodiless function that returns a tensor";
} else {
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError() << "cannot bufferize a FuncOp with tensors "
"and without a unique ReturnOp";
assert(returnOp && "expected func with single return op");

// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
@@ -2474,9 +2474,7 @@ static LogicalResult bufferizeFuncOpBoundary(

// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError() << "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
assert(returnOp && "expected func with single return op");

// 1. For each FuncOp result, keep track of which inplace argument it reuses.
SmallVector<Value> returnValues;
@@ -2574,7 +2572,15 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
// For each FuncOp, the number of CallOpInterface it contains.
DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FuncOp funcOp) {
WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
if (!funcOp.body().empty()) {
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
}

numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
// Only support CallOp for now.
@@ -2622,8 +2628,15 @@ struct LinalgComprehensiveModuleBufferize
};
} // end namespace

static void applyEnablingTransformations(ModuleOp moduleOp) {
RewritePatternSet patterns(moduleOp.getContext());
patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
}

void LinalgComprehensiveModuleBufferize::runOnOperation() {
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);

SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
@@ -485,9 +485,11 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
// %r0 must be out of place because one use of %t in the subsequent production
// of %r1 is read.
// CHECK: scf.for
// CHECK-NEXT: call
// CHECK-NEXT: scf.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
%r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
call @some_use(%t) : (tensor<?xf32>) -> ()
scf.yield %t : tensor<?xf32>
}

@@ -504,11 +506,13 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
// %r2 must be out of place because one use of %t in the subsequent production
// of %r3 is read.
// CHECK: linalg.tiled_loop
// CHECK-NEXT: call
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
%r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
ins()
outs(%t = %B: tensor<?xf32>) {
call @some_use(%t) : (tensor<?xf32>) -> ()
linalg.yield %t : tensor<?xf32>
}

@@ -17,18 +17,20 @@ func private @foo() -> tensor<?xf32>
// -----

// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
func @switch(%flag : i32, %caseOperand : i32, %t1 : tensor<f32>, %t2 : tensor<f32>)
-> (tensor<f32>)
func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-> (tensor<f32>, tensor<f32>)
{
switch %flag : i32, [
default: ^bb1(%caseOperand : i32),
42: ^bb2(%caseOperand : i32)
]

^bb1(%bb1arg : i32):
return %t1 : tensor<f32>
^bb2(%bb2arg : i32):
return %t2 : tensor<f32>
cond_br %cond1, ^bb1, ^bb2

^bb1:
%T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
} else {
scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
}
return %T#0, %T#1 : tensor<f32>, tensor<f32>
^bb2:
return %t2, %t1 : tensor<f32>, tensor<f32>
}

// -----
@@ -6,15 +6,73 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
// RUN: FileCheck %s

func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
%v0 = constant 0.0 : f32
#map0 = affine_map<(d0, d1)[s0] -> ((d1 - d0) ceildiv s0)>
#map1 = affine_map<(d0, d1)[s0] -> ((d0 - d1) ceildiv s0)>

func @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<f32> {linalg.inplaceable = true}) -> tensor<f32> {
%c64 = constant 64 : index
%cst = constant 0.000000e+00 : f32
%c2 = constant 2 : index
%c0 = constant 0 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<f32> -> tensor<f32>
%1 = affine.apply #map0(%c0, %c64)[%c2]
%2 = linalg.init_tensor [%1, 2] : tensor<?x2xf32>
%3 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %2) -> (tensor<?x2xf32>) {
%8 = affine.apply #map1(%arg3, %c0)[%c2]
%9 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
%10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
%11 = linalg.pad_tensor %10 low[%c0] high[%c0] {
^bb0(%arg5: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?xf32> to tensor<2xf32>
%12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>
scf.yield %12 : tensor<?x2xf32>
}

// %B = tensor.cast %3 : tensor<?x2xf32> to tensor<*xf32>
// call @print_memref_f32(%B) : (tensor<*xf32>) -> ()

%4 = affine.apply #map0(%c0, %c64)[%c2]
%5 = linalg.init_tensor [%4, 2] : tensor<?x2xf32>
%6 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %5) -> (tensor<?x2xf32>) {
%8 = affine.apply #map1(%arg3, %c0)[%c2]
%9 = tensor.extract_slice %arg0[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
%10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
%11 = linalg.pad_tensor %10 low[%c0] high[%c0] {
^bb0(%arg5: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?xf32> to tensor<2xf32>
%12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>
scf.yield %12 : tensor<?x2xf32>
}

// %A = tensor.cast %6 : tensor<?x2xf32> to tensor<*xf32>
// call @print_memref_f32(%A) : (tensor<*xf32>) -> ()

// %C = tensor.cast %0 : tensor<f32> to tensor<*xf32>
// call @print_memref_f32(%C) : (tensor<*xf32>) -> ()

%d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
%7 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %0) -> (tensor<f32>) {
%8 = tensor.extract_slice %arg0[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
%9 = tensor.cast %8 : tensor<2xf32> to tensor<?xf32>
%10 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
%11 = tensor.cast %10 : tensor<2xf32> to tensor<?xf32>
%12 = affine.apply #map1(%arg3, %c0)[%c2]
%13 = tensor.extract_slice %6[%12, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
%14 = affine.apply #map1(%arg3, %c0)[%c2]
%15 = tensor.extract_slice %3[%14, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
%16 = linalg.dot ins(%13, %15 : tensor<2xf32>, tensor<2xf32>) outs(%arg4 : tensor<f32>) -> tensor<f32>

%e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
outs(%d: tensor<f32>) -> tensor<f32>
// %AA = tensor.cast %13 : tensor<2xf32> to tensor<*xf32>
// call @print_memref_f32(%AA) : (tensor<*xf32>) -> ()
// %BB = tensor.cast %15 : tensor<2xf32> to tensor<*xf32>
// call @print_memref_f32(%BB) : (tensor<*xf32>) -> ()
// %CC = tensor.cast %16 : tensor<f32> to tensor<*xf32>
// call @print_memref_f32(%CC) : (tensor<*xf32>) -> ()

return %e : tensor<f32>
scf.yield %16 : tensor<f32>
}
return %7 : tensor<f32>
}

func @main() {
