Skip to content

Commit

Permalink
fix(dipu,vendor,ascend): simple and thread-safe device management
Browse files Browse the repository at this point in the history
  • Loading branch information
lljbash committed Mar 25, 2024
1 parent 2ce08a7 commit cb27397
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <acl/acl.h>
#include <acl/acl_op.h>
#include <acl/acl_op_compiler.h>
#include <atomic>

#include <csrc_dipu/common.h>
#include <csrc_dipu/runtime/device/deviceapis.h>
Expand All @@ -15,8 +16,10 @@ namespace devapis {
// =====================
// Device class related
// =====================
using ascend_deviceId = int32_t;
thread_local bool setDevFlag = false;
// Integer type used in this file for device ids handed to the ACL runtime calls.
using AscendDeviceId = int32_t;
// Sentinel held by g_current_thread_device_id until a device id has been
// recorded for the thread.
constexpr AscendDeviceId kDeviceIdUnset = -1;
// Device id that current_device() passes to aclrtSetDevice() when the calling
// thread has no recorded device.
constexpr AscendDeviceId kDeviceIdDefault = 0;
// Per-thread record of the device id; starts out as kDeviceIdUnset. `auto`
// deduces a mutable AscendDeviceId from the initializer.
static thread_local auto g_current_thread_device_id = kDeviceIdUnset;

void initializeVendor() {
DIPU_CALLACLRT(aclInit(nullptr));
Expand All @@ -26,15 +29,31 @@ void initializeVendor() {
// Vendor teardown hook: invokes aclFinalize() through DIPU_CALLACLRT
// (presumably the macro checks the returned ACL status -- confirm at its
// definition).
void finalizeVendor() {
  DIPU_CALLACLRT(aclFinalize());
}

// Returns the device id in effect for the calling thread.
//
// If the thread has a recorded id (g_current_thread_device_id differs from
// kDeviceIdUnset), that id is returned. Otherwise kDeviceIdDefault is selected
// through aclrtSetDevice() and reported to the caller.
deviceId_t current_device() {
  if (g_current_thread_device_id == kDeviceIdUnset) {
    DIPU_CALLACLRT(aclrtSetDevice(kDeviceIdDefault));
    // Report the device that was just selected rather than the unset sentinel.
    // The thread-local record is intentionally left untouched here so that a
    // later setDevice() call on this thread can still choose the device.
    return static_cast<deviceId_t>(kDeviceIdDefault);
  }
  return static_cast<deviceId_t>(g_current_thread_device_id);
}
DIPUDeviceProperties getDeviceProperties(int32_t device_index) {

// Makes `device_id` the current device of the calling thread.
//
// The first request made while the thread's record is still unset goes through
// a function-local atomic shared by all threads: the first id ever stored there
// wins, and every thread records that winning id as its own. A request for any
// other id only emits a one-time warning and leaves the device untouched.
void setDevice(deviceId_t device_id) {
  const auto requested_id = static_cast<AscendDeviceId>(device_id);

  if (g_current_thread_device_id == kDeviceIdUnset) {
    static std::atomic<AscendDeviceId> process_device_id{kDeviceIdUnset};
    // Claim the shared slot only if nobody has claimed it yet; either way the
    // thread then adopts whatever id the slot holds.
    AscendDeviceId expected = kDeviceIdUnset;
    process_device_id.compare_exchange_strong(expected, requested_id);
    g_current_thread_device_id = process_device_id.load();
  }

  if (requested_id != g_current_thread_device_id) {
    TORCH_WARN_ONCE(
        "Trying to use multiple cards in the same process may cause unexpected "
        "results in hccl communication, such as sdma memory copy failure");
    return;
  }

  DIPU_CALLACLRT(::aclrtSetDevice(requested_id))
}

DIPUDeviceProperties getDeviceProperties(AscendDeviceId device_index) {
const char* device_name;
size_t device_free;
size_t device_total;
Expand All @@ -55,13 +74,6 @@ DIPUDeviceProperties getDeviceProperties(int32_t device_index) {
return prop;
}

// Sets the calling thread's current device to `devId` via aclrtSetDevice() and
// raises the thread-local setDevFlag, which current_device() checks before
// falling back to device 0.
void setDevice(deviceId_t devId) {
// NOTE(review): the cast targets deviceId_t although the variable is declared
// as ascend_deviceId -- looks unintended; confirm.
ascend_deviceId devId_ = static_cast<deviceId_t>(devId);
DIPU_CALLACLRT(::aclrtSetDevice(devId_))
setDevFlag = true;
}

// Forwards `devId` to ::aclrtResetDevice() through DIPU_CALLACLRT.
void resetDevice(deviceId_t devId) {
  DIPU_CALLACLRT(::aclrtResetDevice(devId))
}

// Calls ::aclrtSynchronizeDevice() through DIPU_CALLACLRT; takes no device
// argument.
void syncDevice() {
  DIPU_CALLACLRT(::aclrtSynchronizeDevice())
}
Expand Down

0 comments on commit cb27397

Please sign in to comment.