From 3b081e0b66844e09bf3ce47b8481e38ca91f3d9e Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Wed, 20 Mar 2024 22:15:05 -0500 Subject: [PATCH] Ensure each use of tensor_init op is bufferzied into a unique buffer --- .../BufferizableOpInterfaceImpl.cpp | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp b/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp index f0877306..3dd07c70 100644 --- a/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp @@ -198,22 +198,19 @@ struct TensorInitOpInterface if (failed(maybeType)) return failure(); - FailureOr buffer = options.createAlloc( - rewriter, tensorInit.getLoc(), maybeType->cast(), {}); - if (failed(buffer)) - return failure(); - - // Handle initial value. - if (auto initValue = tensorInit.getInitValue()) { - auto initValueOp = initValue.getDefiningOp(); - auto bufferOp = buffer->getDefiningOp(); - if (!initValueOp || !bufferOp) - return failure(); - bufferOp.setInitValueAttr(initValueOp.getValue()); + // Replace uses of the tensor init op with the buffer op. + for (auto &use : llvm::make_early_inc_range(tensorInit->getUses())) { + rewriter.setInsertionPoint(use.getOwner()); + auto buffer = rewriter.create(tensorInit.getLoc(), + cast(*maybeType), + tensorInit.getInitValue()); + auto replacement = rewriter.create( + tensorInit.getLoc(), buffer); + rewriter.updateRootInPlace(use.getOwner(), [&] { use.set(replacement); }); } - // Replace op. - replaceOpWithBufferizedValues(rewriter, tensorInit, *buffer); + // Erase the tensor init op. + rewriter.eraseOp(tensorInit); return success(); } @@ -230,8 +227,7 @@ struct TensorInitOpInterface else return tensorInit.emitError("could not infer memory space"); - return getMemRefTypeWithStaticIdentityLayout(tensorInit.getType(), - memorySpace); + return getMemRefType(value, options, /*layout=*/{}, memorySpace); } };