Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dipu,vendor,ascend): simple and thread-safe device management #746

Merged
merged 7 commits into from
Apr 12, 2024
90 changes: 47 additions & 43 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,43 @@ namespace devapis {
// =====================
// Device class related
// =====================
using ascend_deviceId = int32_t;
thread_local int currentDeviceIndex = -1;

using AscendDeviceId = int32_t;

namespace {

constexpr AscendDeviceId kDeviceIdUninit = -1;
constexpr AscendDeviceId kDeviceIdDefault = 0;
// Thread-level cache for process-level device id
thread_local auto g_process_device_id_thread_cache = kDeviceIdUninit;

// Try to initialize process-level device id if it hasnt' been initialized,
// which is designed to be written only once,
// and anyway return the process-level device id
AscendDeviceId tryInitAndAnywayGetProcessDevice(
AscendDeviceId ascend_device_id) {
static std::atomic process_device_id = kDeviceIdUninit;
auto expectedUninit = kDeviceIdUninit;
std::atomic_compare_exchange_strong(&process_device_id, &expectedUninit,
ascend_device_id);
return process_device_id.load();
}

// If thread-level cache of process-level device id hasn't been initialized,
// try to initialize thread-level cache, which is designed to be written once.
// If process-level device id hasn't been initialized,
// try to initialize process-level device id first using input ascend_device_id.
void tryInitThreadDeviceCache(AscendDeviceId ascend_device_id) {
if (g_process_device_id_thread_cache == kDeviceIdUninit) {
// Ensure that aclrtSetDevice is called
// when and only when the thread-level cache is initialized.
g_process_device_id_thread_cache =
tryInitAndAnywayGetProcessDevice(ascend_device_id);
DIPU_CALLACLRT(::aclrtSetDevice(g_process_device_id_thread_cache))
}
}

} // namespace

void initializeVendor() {
DIPU_CALLACLRT(aclInit(nullptr));
Expand All @@ -27,56 +62,25 @@ void initializeVendor() {
void finalizeVendor() { DIPU_CALLACLRT(aclFinalize()); }

deviceId_t current_device() {
if (currentDeviceIndex < 0) {
setDevice(-1);
DIPU_CALLACLRT(::aclrtGetDevice(&currentDeviceIndex))
if (g_process_device_id_thread_cache == kDeviceIdUninit) {
DIPU_LOGW("current_device() is called before setDevice()");
tryInitThreadDeviceCache(kDeviceIdDefault);
}
return static_cast<deviceId_t>(currentDeviceIndex);
}

int defaultDeviceIndex = -1;
std::atomic_int defaultDeviceIndexAtomic(-1);

void setDefalutDevice(int index) {
defaultDeviceIndexAtomic = index;
defaultDeviceIndex = index;
return static_cast<deviceId_t>(g_process_device_id_thread_cache);
}

// set current device given device according to id
void setDevice(deviceId_t devId) {
// In order to reduce performance loss, try to reduce the number of reads and
// writes of atomic variables.
// Atomic variables will only be manipulated when starting up.
// In most other cases, reading and writing atomic variables is no longer
// required. This function is called extremely frequently.
if (devId < 0) {
if (defaultDeviceIndex < 0) {
if (defaultDeviceIndexAtomic < 0) {
setDefalutDevice(0);
}
}
devId = defaultDeviceIndexAtomic;
} else {
if (defaultDeviceIndex < 0) {
if (defaultDeviceIndexAtomic < 0) {
setDefalutDevice(devId);
}
}
}
if (devId != defaultDeviceIndex) {
TORCH_WARN_ONCE(
void setDevice(deviceId_t device_id) {
auto ascend_device_id = static_cast<AscendDeviceId>(device_id);
tryInitThreadDeviceCache(ascend_device_id);
if (ascend_device_id != g_process_device_id_thread_cache) {
DIPU_LOGW(
"Trying to use multiple cards in the same process may cause unexpected "
"results in hccl communication, such as sdma memory copy failure");
} else {
ascend_deviceId devId_ = static_cast<deviceId_t>(devId);
if (devId_ != currentDeviceIndex) {
DIPU_CALLACLRT(::aclrtSetDevice(devId_))
currentDeviceIndex = devId_;
}
}
}

DIPUDeviceProperties getDeviceProperties(int32_t device_index) {
DIPUDeviceProperties getDeviceProperties(AscendDeviceId device_index) {
const char* device_name;
size_t device_free;
size_t device_total;
Expand Down