Skip to content

Commit

Permalink
Ensure each use of tensor_init op is bufferzied into a unique buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Mar 21, 2024
1 parent d0613c7 commit 3b081e0
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,22 +198,19 @@ struct TensorInitOpInterface
if (failed(maybeType))
return failure();

FailureOr<Value> buffer = options.createAlloc(
rewriter, tensorInit.getLoc(), maybeType->cast<MemRefType>(), {});
if (failed(buffer))
return failure();

// Handle initial value.
if (auto initValue = tensorInit.getInitValue()) {
auto initValueOp = initValue.getDefiningOp<arith::ConstantOp>();
auto bufferOp = buffer->getDefiningOp<BufferOp>();
if (!initValueOp || !bufferOp)
return failure();
bufferOp.setInitValueAttr(initValueOp.getValue());
// Replace uses of the tensor init op with the buffer op.
for (auto &use : llvm::make_early_inc_range(tensorInit->getUses())) {
rewriter.setInsertionPoint(use.getOwner());
auto buffer = rewriter.create<hls::BufferOp>(tensorInit.getLoc(),
cast<MemRefType>(*maybeType),
tensorInit.getInitValue());
auto replacement = rewriter.create<bufferization::ToTensorOp>(
tensorInit.getLoc(), buffer);
rewriter.updateRootInPlace(use.getOwner(), [&] { use.set(replacement); });
}

// Replace op.
replaceOpWithBufferizedValues(rewriter, tensorInit, *buffer);
// Erase the tensor init op.
rewriter.eraseOp(tensorInit);
return success();
}

Expand All @@ -230,8 +227,7 @@ struct TensorInitOpInterface
else
return tensorInit.emitError("could not infer memory space");

return getMemRefTypeWithStaticIdentityLayout(tensorInit.getType(),
memorySpace);
return getMemRefType(value, options, /*layout=*/{}, memorySpace);
}
};

Expand Down

0 comments on commit 3b081e0

Please sign in to comment.