diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/allocator/DIPUBFCachingAllocator.cpp b/dipu/torch_dipu/csrc_dipu/runtime/core/allocator/DIPUBFCachingAllocator.cpp index 13c902147..a13ce575d 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/core/allocator/DIPUBFCachingAllocator.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/core/allocator/DIPUBFCachingAllocator.cpp @@ -429,6 +429,24 @@ class BFCachingAllocator : public CacheAllocator { } void empty_resource_pool() const { + std::lock_guard lk(resource_pool_mutex_); + while (!async_mem_pool()->empty()) { + if (!async_mem_pool()->ready()) { + std::this_thread::yield(); + continue; + } + const auto block = async_mem_pool()->get(); + void* ptr = std::get<0>(block); + int id = static_cast(std::get<1>(block)); + DIPU_DEBUG_ALLOCATOR( + 8, "BFCachingAllocator: " << __FUNCTION__ << " ,ptr:" << ptr + << " ,id:" << id << " ,allocator:" << this + << ", device:" << device()); + impl->releaseRaw(ptr, id); + } + } + + bool try_empty_resource_pool() const { using namespace std::chrono_literals; std::lock_guard lk(resource_pool_mutex_); auto start = std::chrono::steady_clock::now(); @@ -441,7 +459,7 @@ class BFCachingAllocator : public CacheAllocator { std::this_thread::yield(); continue; } - return; + return false; } const auto block = async_mem_pool()->get(); void* ptr = std::get<0>(block); @@ -452,6 +470,7 @@ class BFCachingAllocator : public CacheAllocator { << ", device:" << device()); impl->releaseRaw(ptr, id); } + return true; } void check_impl() const { @@ -520,7 +539,7 @@ class BFCachingAllocator : public CacheAllocator { c10::DataPtr allocate(size_t size) const override { restore(); if (async_mem_pool()->size() > kMaxAsyncResourcePoolLength) { - empty_resource_pool(); + try_empty_resource_pool(); } size = getMemoryAlignmentStrategy()->roundBytes(size); std::tuple block = impl->allocateRaw(size);