From 35ccc9b00200b7377398ea19b90ce2ba1ac5cc01 Mon Sep 17 00:00:00 2001 From: Lingjie Li Date: Mon, 25 Mar 2024 12:53:48 +0800 Subject: [PATCH] refactor: make code more readable by extracting a function --- .../csrc_dipu/vendor/ascend/deviceimpl.cpp | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp index 5623e17e1..a12cfe4ea 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp +++ b/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp @@ -17,9 +17,24 @@ namespace devapis { // Device class related // ===================== using AscendDeviceId = int32_t; + +namespace { + constexpr AscendDeviceId kDeviceIdUnset = -1; constexpr AscendDeviceId kDeviceIdDefault = 0; -static thread_local auto g_current_thread_device_id = kDeviceIdUnset; +thread_local auto g_current_thread_device_id = kDeviceIdUnset; + +// atomically set global device id if it is unset +// and anyway return the global device id +AscendDeviceId setOrGetGlobalDeviceId(AscendDeviceId device_id_if_unset) { + static std::atomic global_device_id = kDeviceIdUnset; + auto expectedUnset = kDeviceIdUnset; + std::atomic_compare_exchange_strong(&global_device_id, &expectedUnset, + device_id_if_unset); + return global_device_id.load(); +} + +} // namespace void initializeVendor() { DIPU_CALLACLRT(aclInit(nullptr)); @@ -46,10 +61,7 @@ void setDevice(deviceId_t device_id) { } auto ascend_device_id = static_cast(device_id); if (g_current_thread_device_id == kDeviceIdUnset) { - static std::atomic g_global_device_id = kDeviceIdUnset; - std::atomic_compare_exchange_strong( - &g_global_device_id, &g_current_thread_device_id, ascend_device_id); - g_current_thread_device_id = g_global_device_id.load(); + g_current_thread_device_id = setOrGetGlobalDeviceId(ascend_device_id); } if (ascend_device_id != g_current_thread_device_id) { DIPU_LOGW(