-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(dipu,vendor,ascend): simple and thread-safe device management #746
Changes from 4 commits
7cbb754
35ccc9b
194dff7
09c8b31
3ec7b90
d071a93
45f5a8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,8 +16,27 @@ namespace devapis { | |||||||||||||||
// ===================== | ||||||||||||||||
// Device class related | ||||||||||||||||
// ===================== | ||||||||||||||||
using ascend_deviceId = int32_t; | ||||||||||||||||
thread_local int currentDeviceIndex = -1; | ||||||||||||||||
|
||||||||||||||||
using AscendDeviceId = int32_t; | ||||||||||||||||
|
||||||||||||||||
namespace { | ||||||||||||||||
|
||||||||||||||||
constexpr AscendDeviceId kDeviceIdUninit = -1; | ||||||||||||||||
constexpr AscendDeviceId kDeviceIdDefault = 0; | ||||||||||||||||
thread_local auto g_current_thread_device_id = kDeviceIdUninit; | ||||||||||||||||
|
||||||||||||||||
// Atomically set the global device id if it is still uninitialized, | ||||||||||||||||
// then return the global device id in either case. | ||||||||||||||||
AscendDeviceId initOnceAndGetGlobalDeviceId( | ||||||||||||||||
AscendDeviceId device_id_if_uninit) { | ||||||||||||||||
static std::atomic global_device_id = kDeviceIdUninit; | ||||||||||||||||
auto expectedUninit = kDeviceIdUninit; | ||||||||||||||||
std::atomic_compare_exchange_strong(&global_device_id, &expectedUninit, | ||||||||||||||||
device_id_if_uninit); | ||||||||||||||||
return global_device_id.load(); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
} // namespace | ||||||||||||||||
|
||||||||||||||||
void initializeVendor() { | ||||||||||||||||
DIPU_CALLACLRT(aclInit(nullptr)); | ||||||||||||||||
|
@@ -27,56 +46,35 @@ void initializeVendor() { | |||||||||||||||
void finalizeVendor() { DIPU_CALLACLRT(aclFinalize()); } | ||||||||||||||||
|
||||||||||||||||
deviceId_t current_device() { | ||||||||||||||||
if (currentDeviceIndex < 0) { | ||||||||||||||||
setDevice(-1); | ||||||||||||||||
DIPU_CALLACLRT(::aclrtGetDevice(&currentDeviceIndex)) | ||||||||||||||||
if (g_current_thread_device_id == kDeviceIdUninit) { | ||||||||||||||||
DIPU_LOGW( | ||||||||||||||||
"current_device() is called before setDevice(). Setting device to " | ||||||||||||||||
"default device."); | ||||||||||||||||
setDevice(kDeviceIdDefault); | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 当 怀疑上面这种没有处理好的情况就是导致跑模型跑着跑着会卡死的原因。 可以简单地改成下面几行来解决这个问题(这里删除了部分 warning 的原因:如果
Suggested change
如果觉得上面这种修改和 auto ascend_device_id = static_cast<AscendDeviceId>(device_id);
if (g_current_thread_device_id == kDeviceIdUninit) {
g_current_thread_device_id = initOnceAndGetGlobalDeviceId(ascend_device_id);
DIPU_CALLACLRT(::aclrtSetDevice(ascend_device_id))
} 然后在 if (ascend_device_id != g_current_thread_device_id) {
DIPU_LOGW(
"Trying to use multiple cards in the same process may cause unexpected "
"results in hccl communication, such as sdma memory copy failure");
} 无论哪一种改法,核心逻辑应当是当且仅当 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后面那个不等价了。我觉得给 setDevice 多处理一种情况可能比较合适。定义 kDeviceIdDefault = -2 之类的。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
等价的。贴的代码里有一行我写错了,不应该是
增加 -2 之类的特殊值定义的话就又回到老路上去了, |
||||||||||||||||
} | ||||||||||||||||
return static_cast<deviceId_t>(currentDeviceIndex); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
int defaultDeviceIndex = -1; | ||||||||||||||||
std::atomic_int defaultDeviceIndexAtomic(-1); | ||||||||||||||||
|
||||||||||||||||
void setDefalutDevice(int index) { | ||||||||||||||||
defaultDeviceIndexAtomic = index; | ||||||||||||||||
defaultDeviceIndex = index; | ||||||||||||||||
return static_cast<deviceId_t>(g_current_thread_device_id); | ||||||||||||||||
lljbash marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
// set current device given device according to id | ||||||||||||||||
void setDevice(deviceId_t devId) { | ||||||||||||||||
// In order to reduce performance loss, try to reduce the number of reads and | ||||||||||||||||
// writes of atomic variables. | ||||||||||||||||
// Atomic variables will only be manipulated when starting up. | ||||||||||||||||
// In most other cases, reading and writing atomic variables is no longer | ||||||||||||||||
// required. This function is called extremely frequently. | ||||||||||||||||
if (devId < 0) { | ||||||||||||||||
if (defaultDeviceIndex < 0) { | ||||||||||||||||
if (defaultDeviceIndexAtomic < 0) { | ||||||||||||||||
setDefalutDevice(0); | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
devId = defaultDeviceIndexAtomic; | ||||||||||||||||
} else { | ||||||||||||||||
if (defaultDeviceIndex < 0) { | ||||||||||||||||
if (defaultDeviceIndexAtomic < 0) { | ||||||||||||||||
setDefalutDevice(devId); | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
void setDevice(deviceId_t device_id) { | ||||||||||||||||
if (device_id < 0) { | ||||||||||||||||
DIPU_LOGW("Requested device id is invalid. Ignoring the request."); | ||||||||||||||||
return; | ||||||||||||||||
} | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 根据 @fandaoyi 在群里的说法,“上层不会有传 -1 的 case, 传了报错就好, 就算处理也应该是上层处理, 这里不需要处理”。
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删。 |
||||||||||||||||
if (devId != defaultDeviceIndex) { | ||||||||||||||||
TORCH_WARN_ONCE( | ||||||||||||||||
auto ascend_device_id = static_cast<AscendDeviceId>(device_id); | ||||||||||||||||
if (g_current_thread_device_id == kDeviceIdUninit) { | ||||||||||||||||
g_current_thread_device_id = initOnceAndGetGlobalDeviceId(ascend_device_id); | ||||||||||||||||
} | ||||||||||||||||
if (ascend_device_id != g_current_thread_device_id) { | ||||||||||||||||
DIPU_LOGW( | ||||||||||||||||
"Trying to use multiple cards in the same process may cause unexpected " | ||||||||||||||||
"results in hccl communication, such as sdma memory copy failure"); | ||||||||||||||||
} else { | ||||||||||||||||
ascend_deviceId devId_ = static_cast<deviceId_t>(devId); | ||||||||||||||||
if (devId_ != currentDeviceIndex) { | ||||||||||||||||
DIPU_CALLACLRT(::aclrtSetDevice(devId_)) | ||||||||||||||||
currentDeviceIndex = devId_; | ||||||||||||||||
} | ||||||||||||||||
return; | ||||||||||||||||
} | ||||||||||||||||
DIPU_CALLACLRT(::aclrtSetDevice(ascend_device_id)) | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
DIPUDeviceProperties getDeviceProperties(int32_t device_index) { | ||||||||||||||||
DIPUDeviceProperties getDeviceProperties(AscendDeviceId device_index) { | ||||||||||||||||
const char* device_name; | ||||||||||||||||
size_t device_free; | ||||||||||||||||
size_t device_total; | ||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个
device_id_if_uninit
名字是啥意思?可以直接就叫device_id
吗?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
就是 uninit 的时候会设置的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉在参数名后面加 if 还是怪怪的。保持参数名
device_id
,把 if uninit 加到函数名里去,改成initOnceIfUninitAndGetGlobalDeviceId
这样会不会更合适一点?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我现在改成了
AscendDeviceId tryInitAndAnywayGetProcessDevice(AscendDeviceId ascend_device_id)
。