Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dipu,vendor,ascend): simple and thread-safe device management #746

Merged
merged 7 commits into from
Apr 12, 2024
63 changes: 47 additions & 16 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <acl/acl.h>
#include <acl/acl_op.h>
#include <acl/acl_op_compiler.h>
#include <atomic>

#include <csrc_dipu/common.h>
#include <csrc_dipu/runtime/device/deviceapis.h>
Expand All @@ -15,8 +16,25 @@ namespace devapis {
// =====================
// Device class related
// =====================
using ascend_deviceId = int32_t;
thread_local bool setDevFlag = false;
using AscendDeviceId = int32_t;

namespace {

constexpr AscendDeviceId kDeviceIdUnset = -1;
constexpr AscendDeviceId kDeviceIdDefault = 0;
thread_local auto g_current_thread_device_id = kDeviceIdUnset;

// atomically set global device id if it is unset
// and anyway return the global device id
AscendDeviceId setOrGetGlobalDeviceId(AscendDeviceId device_id_if_unset) {
static std::atomic global_device_id = kDeviceIdUnset;
auto expectedUnset = kDeviceIdUnset;
std::atomic_compare_exchange_strong(&global_device_id, &expectedUnset,
device_id_if_unset);
return global_device_id.load();
}

} // namespace

void initializeVendor() {
DIPU_CALLACLRT(aclInit(nullptr));
Expand All @@ -26,15 +44,35 @@ void initializeVendor() {
void finalizeVendor() { DIPU_CALLACLRT(aclFinalize()); }

deviceId_t current_device() {
if (setDevFlag == false) {
DIPU_CALLACLRT(aclrtSetDevice(0));
setDevFlag = true;
if (g_current_thread_device_id == kDeviceIdUnset) {
DIPU_LOGW(
"current_device() is called before setDevice(). Setting device to "
"default device.");
DIPU_CALLACLRT(aclrtSetDevice(kDeviceIdDefault));
}
ascend_deviceId devId_;
DIPU_CALLACLRT(::aclrtGetDevice(&devId_))
return static_cast<deviceId_t>(devId_);
return static_cast<deviceId_t>(g_current_thread_device_id);
}
DIPUDeviceProperties getDeviceProperties(int32_t device_index) {

// set current device given device according to id
void setDevice(deviceId_t device_id) {
if (device_id < 0) {
DIPU_LOGW("Requested device id is invalid. Ignoring the request.");
return;
}
auto ascend_device_id = static_cast<AscendDeviceId>(device_id);
if (g_current_thread_device_id == kDeviceIdUnset) {
g_current_thread_device_id = setOrGetGlobalDeviceId(ascend_device_id);
}
if (ascend_device_id != g_current_thread_device_id) {
DIPU_LOGW(
"Trying to use multiple cards in the same process may cause unexpected "
"results in hccl communication, such as sdma memory copy failure");
return;
}
DIPU_CALLACLRT(::aclrtSetDevice(ascend_device_id))
}

DIPUDeviceProperties getDeviceProperties(AscendDeviceId device_index) {
const char* device_name;
size_t device_free;
size_t device_total;
Expand All @@ -55,13 +93,6 @@ DIPUDeviceProperties getDeviceProperties(int32_t device_index) {
return prop;
}

// set current device given device according to id
void setDevice(deviceId_t devId) {
ascend_deviceId devId_ = static_cast<deviceId_t>(devId);
DIPU_CALLACLRT(::aclrtSetDevice(devId_))
setDevFlag = true;
}

void resetDevice(deviceId_t devId) { DIPU_CALLACLRT(::aclrtResetDevice(devId)) }

void syncDevice() { DIPU_CALLACLRT(::aclrtSynchronizeDevice()) }
Expand Down
Loading