From 6b225081c0579725596c4c1b0034e58a3b20afa4 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Fri, 8 Nov 2024 16:07:08 -0500 Subject: [PATCH] Minor changes to the memory management. --- lib/nnc/mfa/ccv_nnc_mfa.cpp | 2 +- lib/nnc/mfa/ccv_nnc_mfa_add.cpp | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/lib/nnc/mfa/ccv_nnc_mfa.cpp b/lib/nnc/mfa/ccv_nnc_mfa.cpp index 13ae4b798..702ba3f27 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa.cpp @@ -210,7 +210,7 @@ MTL::Buffer* mfa::context::request_scratch(uint64_t size) { uint64_t leading_zeroes = __builtin_clzll(padded_size); uint64_t rounded_size = 1 << uint64_t(64 - leading_zeroes); - auto buffer = device->newBuffer(rounded_size, MTL::ResourceStorageModePrivate); + auto buffer = device->newBuffer(rounded_size, MTL::ResourceStorageModePrivate | MTL::ResourceHazardTrackingModeTracked); CCV_NNC_MFA_PRECONDITION(buffer != nullptr); this->scratch = NS::TransferPtr(buffer); } diff --git a/lib/nnc/mfa/ccv_nnc_mfa_add.cpp b/lib/nnc/mfa/ccv_nnc_mfa_add.cpp index 5d0714044..1a1f7deab 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_add.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_add.cpp @@ -31,9 +31,17 @@ void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_para CCV_NNC_MFA_PRECONDITION(num_tensors == 3); encoder->setComputePipelineState(pipeline->add_pso.get()); - encoder->useResource(tensors[0], MTL::ResourceUsageRead); - encoder->useResource(tensors[1], MTL::ResourceUsageRead); - encoder->useResource(tensors[2], MTL::ResourceUsageWrite); + if (tensors[0] == tensors[2]) { + encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite); + encoder->useResource(tensors[1], MTL::ResourceUsageRead); + } else if (tensors[1] == tensors[2]) { + encoder->useResource(tensors[0], MTL::ResourceUsageRead); + encoder->useResource(tensors[1], MTL::ResourceUsageRead | MTL::ResourceUsageWrite); + } else { + encoder->useResource(tensors[0], MTL::ResourceUsageRead); + encoder->useResource(tensors[1], MTL::ResourceUsageRead); + encoder->useResource(tensors[2], MTL::ResourceUsageWrite); + } auto grid_size = pipeline->grid_size; CCV_NNC_MFA_PRECONDITION(grid_size.depth > 0);