diff --git a/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp index 5623e17e17..fc0fa4539d 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp +++ b/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp @@ -17,9 +17,23 @@ namespace devapis { // Device class related // ===================== using AscendDeviceId = int32_t; + +namespace { + constexpr AscendDeviceId kDeviceIdUnset = -1; constexpr AscendDeviceId kDeviceIdDefault = 0; -static thread_local auto g_current_thread_device_id = kDeviceIdUnset; +thread_local auto g_current_thread_device_id = kDeviceIdUnset; + +// atomically set global device id if it is unset +// and anyway return the global device id +AscendDeviceId setOrGetGlobalDeviceId(AscendDeviceId device_id_if_unset) { + static std::atomic global_device_id = kDeviceIdUnset; + auto expectedUnset = kDeviceIdUnset; + std::atomic_compare_exchange_strong(&global_device_id, &expectedUnset, device_id_if_unset); + return global_device_id.load(); +} + +} // namespace void initializeVendor() { DIPU_CALLACLRT(aclInit(nullptr)); @@ -46,10 +60,7 @@ void setDevice(deviceId_t device_id) { } auto ascend_device_id = static_cast(device_id); if (g_current_thread_device_id == kDeviceIdUnset) { - static std::atomic g_global_device_id = kDeviceIdUnset; - std::atomic_compare_exchange_strong( - &g_global_device_id, &g_current_thread_device_id, ascend_device_id); - g_current_thread_device_id = g_global_device_id.load(); + g_current_thread_device_id = setOrGetGlobalDeviceId(ascend_device_id); } if (ascend_device_id != g_current_thread_device_id) { DIPU_LOGW(