Skip to content

Commit

Permalink
refactor: make code more readable by extracting a function
Browse files Browse the repository at this point in the history
  • Loading branch information
lljbash committed Mar 25, 2024
1 parent 7cbb754 commit 35ccc9b
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,24 @@ namespace devapis {
// Device class related
// =====================
using AscendDeviceId = int32_t;

namespace {

constexpr AscendDeviceId kDeviceIdUnset = -1;
constexpr AscendDeviceId kDeviceIdDefault = 0;
static thread_local auto g_current_thread_device_id = kDeviceIdUnset;
thread_local auto g_current_thread_device_id = kDeviceIdUnset;

// atomically set global device id if it is unset
// and anyway return the global device id
AscendDeviceId setOrGetGlobalDeviceId(AscendDeviceId device_id_if_unset) {
static std::atomic global_device_id = kDeviceIdUnset;
auto expectedUnset = kDeviceIdUnset;
std::atomic_compare_exchange_strong(&global_device_id, &expectedUnset,
device_id_if_unset);
return global_device_id.load();
}

} // namespace

void initializeVendor() {
DIPU_CALLACLRT(aclInit(nullptr));
Expand All @@ -46,10 +61,7 @@ void setDevice(deviceId_t device_id) {
}
auto ascend_device_id = static_cast<AscendDeviceId>(device_id);
if (g_current_thread_device_id == kDeviceIdUnset) {
static std::atomic g_global_device_id = kDeviceIdUnset;
std::atomic_compare_exchange_strong(
&g_global_device_id, &g_current_thread_device_id, ascend_device_id);
g_current_thread_device_id = g_global_device_id.load();
g_current_thread_device_id = setOrGetGlobalDeviceId(ascend_device_id);
}
if (ascend_device_id != g_current_thread_device_id) {
DIPU_LOGW(
Expand Down

0 comments on commit 35ccc9b

Please sign in to comment.