Skip to content

Commit

Permalink
fix(dipu,vendor,ascend): simple and thread-safe device management (#746)
Browse files Browse the repository at this point in the history
* fix(dipu,vendor,ascend): simple and thread-safe device management

* refactor: make code more readable by extracting a function

* fix: set device id in current_device() goes wrong

* refactor and fix a bug for a new thread calling current_device before setDevice

* rename process_device_id_thread_cache to g_process_device_id_thread_cache

* rename global_device_id to process_device_id

---------

Co-authored-by: jfxu-st <xujinfan@sensetime.com>
  • Loading branch information
lljbash and jfxu-st authored Apr 12, 2024
1 parent 27252cf commit 2f9cb9f
Showing 1 changed file with 47 additions and 43 deletions.
90 changes: 47 additions & 43 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,43 @@ namespace devapis {
// =====================
// Device class related
// =====================
using ascend_deviceId = int32_t;
thread_local int currentDeviceIndex = -1;

using AscendDeviceId = int32_t;

namespace {

// Sentinel stored while no device id has been recorded yet.
constexpr AscendDeviceId kDeviceIdUninit{-1};
// Device id used as the fallback when current_device() runs before
// setDevice().
constexpr AscendDeviceId kDeviceIdDefault{0};
// Per-thread copy of the process-level device id. It holds kDeviceIdUninit
// until the owning thread has initialized it.
thread_local AscendDeviceId g_process_device_id_thread_cache = kDeviceIdUninit;

// Returns the process-level device id, first publishing `ascend_device_id`
// as that id if it hasn't been published yet.
// The id lives in a function-local atomic that only leaves kDeviceIdUninit
// through the compare-exchange below, so the first caller's value wins and
// every later caller simply reads it back.
// NOTE(review): if `ascend_device_id` equals kDeviceIdUninit the exchange
// stores the sentinel again and the sentinel is returned -- confirm callers
// never pass it.
AscendDeviceId tryInitAndAnywayGetProcessDevice(
    AscendDeviceId ascend_device_id) {
  static std::atomic<AscendDeviceId> process_device_id{kDeviceIdUninit};
  AscendDeviceId expected = kDeviceIdUninit;
  // The outcome is deliberately ignored: whether or not this call performed
  // the write, the stored id is re-read and returned below.
  process_device_id.compare_exchange_strong(expected, ascend_device_id);
  return process_device_id.load();
}

// Initializes this thread's cache of the process-level device id, if it has
// not been initialized yet; otherwise does nothing.
//
// The process-level id is obtained from tryInitAndAnywayGetProcessDevice(),
// which publishes `ascend_device_id` as the process-level id when none has
// been published so far and returns the published id either way. The value
// cached here may therefore differ from `ascend_device_id`.
//
// Invariant: aclrtSetDevice is called when and only when the thread-level
// cache is initialized.
void tryInitThreadDeviceCache(AscendDeviceId ascend_device_id) {
  if (g_process_device_id_thread_cache != kDeviceIdUninit) {
    return;
  }
  const AscendDeviceId process_device_id =
      tryInitAndAnywayGetProcessDevice(ascend_device_id);
  // Bind the device BEFORE recording it in the cache. If the call fails and
  // DIPU_CALLACLRT reports that by throwing (presumably -- confirm against the
  // macro definition), the cache stays uninitialized and a later call on this
  // thread retries aclrtSetDevice instead of silently skipping it.
  DIPU_CALLACLRT(::aclrtSetDevice(process_device_id));
  g_process_device_id_thread_cache = process_device_id;
}

} // namespace

void initializeVendor() {
DIPU_CALLACLRT(aclInit(nullptr));
Expand All @@ -27,56 +62,25 @@ void initializeVendor() {
// Vendor shutdown hook: invokes aclFinalize() through DIPU_CALLACLRT
// (presumably the macro checks the returned status -- confirm against its
// definition).
void finalizeVendor() {
  DIPU_CALLACLRT(aclFinalize());
}

deviceId_t current_device() {
if (currentDeviceIndex < 0) {
setDevice(-1);
DIPU_CALLACLRT(::aclrtGetDevice(&currentDeviceIndex))
if (g_process_device_id_thread_cache == kDeviceIdUninit) {
DIPU_LOGW("current_device() is called before setDevice()");
tryInitThreadDeviceCache(kDeviceIdDefault);
}
return static_cast<deviceId_t>(currentDeviceIndex);
}

int defaultDeviceIndex = -1;
std::atomic_int defaultDeviceIndexAtomic(-1);

void setDefalutDevice(int index) {
defaultDeviceIndexAtomic = index;
defaultDeviceIndex = index;
return static_cast<deviceId_t>(g_process_device_id_thread_cache);
}

// set current device given device according to id
void setDevice(deviceId_t devId) {
// In order to reduce performance loss, try to reduce the number of reads and
// writes of atomic variables.
// Atomic variables will only be manipulated when starting up.
// In most other cases, reading and writing atomic variables is no longer
// required. This function is called extremely frequently.
if (devId < 0) {
if (defaultDeviceIndex < 0) {
if (defaultDeviceIndexAtomic < 0) {
setDefalutDevice(0);
}
}
devId = defaultDeviceIndexAtomic;
} else {
if (defaultDeviceIndex < 0) {
if (defaultDeviceIndexAtomic < 0) {
setDefalutDevice(devId);
}
}
}
if (devId != defaultDeviceIndex) {
TORCH_WARN_ONCE(
void setDevice(deviceId_t device_id) {
auto ascend_device_id = static_cast<AscendDeviceId>(device_id);
tryInitThreadDeviceCache(ascend_device_id);
if (ascend_device_id != g_process_device_id_thread_cache) {
DIPU_LOGW(
"Trying to use multiple cards in the same process may cause unexpected "
"results in hccl communication, such as sdma memory copy failure");
} else {
ascend_deviceId devId_ = static_cast<deviceId_t>(devId);
if (devId_ != currentDeviceIndex) {
DIPU_CALLACLRT(::aclrtSetDevice(devId_))
currentDeviceIndex = devId_;
}
}
}

DIPUDeviceProperties getDeviceProperties(int32_t device_index) {
DIPUDeviceProperties getDeviceProperties(AscendDeviceId device_index) {
const char* device_name;
size_t device_free;
size_t device_total;
Expand Down

0 comments on commit 2f9cb9f

Please sign in to comment.