Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dipu,vendor,ascend): simple and thread-safe device management #746

Merged
merged 7 commits into from
Apr 12, 2024
90 changes: 47 additions & 43 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,43 @@ namespace devapis {
// =====================
// Device class related
// =====================
using ascend_deviceId = int32_t;
thread_local int currentDeviceIndex = -1;

using AscendDeviceId = int32_t;

namespace {

constexpr AscendDeviceId kDeviceIdUninit = -1;
constexpr AscendDeviceId kDeviceIdDefault = 0;
// Thread-level cache for process-level device id
thread_local auto process_device_id_thread_cache = kDeviceIdUninit;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
thread_local auto process_device_id_thread_cache = kDeviceIdUninit;
thread_local auto g_process_device_id_thread_cache = kDeviceIdUninit;

这还是个全局

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


// Try to initialize process-level device id if it hasnt' been initialized,
// which is designed to be written only once,
// and anyway return the process-level device id
AscendDeviceId tryInitAndAnywayGetProcessDevice(
AscendDeviceId ascend_device_id) {
static std::atomic global_device_id = kDeviceIdUninit;
auto expectedUninit = kDeviceIdUninit;
std::atomic_compare_exchange_strong(&global_device_id, &expectedUninit,
ascend_device_id);
return global_device_id.load();
}

// If thread-level cache of process-level device id hasn't been initialized,
// try to initialize thread-level cache, which is designed to be written once.
// If process-level device id hasn't been initialized,
// try to initialize process-level device id first using input ascend_device_id.
void tryInitThreadDeviceCache(AscendDeviceId ascend_device_id) {
if (process_device_id_thread_cache == kDeviceIdUninit) {
// Ensure that aclrtSetDevice is called
// when and only when the thread-level cache is initialized.
process_device_id_thread_cache =
tryInitAndAnywayGetProcessDevice(ascend_device_id);
DIPU_CALLACLRT(::aclrtSetDevice(process_device_id_thread_cache))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set 不同的会有什么后果,会不会导致崩溃什么的,要不要阻止?

Copy link
Collaborator Author

@lljbash lljbash Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可能看错了,你检查下就好

Copy link
Collaborator

@jfxu-st jfxu-st Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set 不同的会有什么后果,会不会导致崩溃什么的,要不要阻止?

根据 @zhaoguochun1995 在群里发的聊天记录截图,torch_npu 也是这样做的,无声无息地让第二次及以后的 set 不起效果,甚至连 warning 都没有……

所以这里的 if 检查如果失败了(已经被 set 过了),就什么都不做。我们比 torch_npu 多的就是如果这次 tryInitThreadDeviceCache() 的调用是通过显式调用 setDevice() 发起的,那么就报一个 warning。

这个问题后面 @jingguo-st 会进一步确认。

}
}

} // namespace

void initializeVendor() {
DIPU_CALLACLRT(aclInit(nullptr));
Expand All @@ -27,56 +62,25 @@ void initializeVendor() {
void finalizeVendor() { DIPU_CALLACLRT(aclFinalize()); }

deviceId_t current_device() {
if (currentDeviceIndex < 0) {
setDevice(-1);
DIPU_CALLACLRT(::aclrtGetDevice(&currentDeviceIndex))
if (process_device_id_thread_cache == kDeviceIdUninit) {
DIPU_LOGW("current_device() is called before setDevice()");
tryInitThreadDeviceCache(kDeviceIdDefault);
}
return static_cast<deviceId_t>(currentDeviceIndex);
}

int defaultDeviceIndex = -1;
std::atomic_int defaultDeviceIndexAtomic(-1);

void setDefalutDevice(int index) {
defaultDeviceIndexAtomic = index;
defaultDeviceIndex = index;
return static_cast<deviceId_t>(process_device_id_thread_cache);
}

// set current device given device according to id
void setDevice(deviceId_t devId) {
// In order to reduce performance loss, try to reduce the number of reads and
// writes of atomic variables.
// Atomic variables will only be manipulated when starting up.
// In most other cases, reading and writing atomic variables is no longer
// required. This function is called extremely frequently.
if (devId < 0) {
if (defaultDeviceIndex < 0) {
if (defaultDeviceIndexAtomic < 0) {
setDefalutDevice(0);
}
}
devId = defaultDeviceIndexAtomic;
} else {
if (defaultDeviceIndex < 0) {
if (defaultDeviceIndexAtomic < 0) {
setDefalutDevice(devId);
}
}
}
if (devId != defaultDeviceIndex) {
TORCH_WARN_ONCE(
void setDevice(deviceId_t device_id) {
auto ascend_device_id = static_cast<AscendDeviceId>(device_id);
tryInitThreadDeviceCache(ascend_device_id);
if (ascend_device_id != process_device_id_thread_cache) {
DIPU_LOGW(
"Trying to use multiple cards in the same process may cause unexpected "
"results in hccl communication, such as sdma memory copy failure");
} else {
ascend_deviceId devId_ = static_cast<deviceId_t>(devId);
if (devId_ != currentDeviceIndex) {
DIPU_CALLACLRT(::aclrtSetDevice(devId_))
currentDeviceIndex = devId_;
}
}
}

DIPUDeviceProperties getDeviceProperties(int32_t device_index) {
DIPUDeviceProperties getDeviceProperties(AscendDeviceId device_index) {
const char* device_name;
size_t device_free;
size_t device_total;
Expand Down
Loading