Skip to content

Commit

Permalink
fix ascend device init (#638)
Browse files Browse the repository at this point in the history
* fix ascend device init

* fix lint

* minor change

* minor change

* add ascend communicate test

* move some defines from header to cpp

* minor change

* add acl fun call track
  • Loading branch information
zhaoguochun1995 authored Jan 17, 2024
1 parent 6b737b2 commit eb6d8e6
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 49 deletions.
2 changes: 1 addition & 1 deletion dipu/tests/run_ascend_tests.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ source tests/common.sh
# Runs the DIPU test commands for the Ascend test script: the no-op `true`
# followed by tests/python/individual_scripts/test_rt_ddp.py under python.
# The PyTorch op tests below are still disabled (see TODO).
function run_dipu_tests {
# TODO: Add PyTorch tests
# run_test tests/test_ops/archived/test_tensor_add.py
true
python tests/python/individual_scripts/test_rt_ddp.py
}

if [ "$LOGFILE" != "" ]; then
Expand Down
39 changes: 17 additions & 22 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/basecommimpl.hpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,27 @@

#include <acl/acl.h>
#include <cstring>
#include <unistd.h>

#include <csrc_dipu/common.h>
#include <csrc_dipu/runtime/device/diclapis.h>

// Optional call tracing. If the environment variable DIPU_TRACK_<TAG> exists
// (any value), prints "[<pid> <file>: <line>]:<x>" to stdout; otherwise it is
// a no-op. `x` must be a C string. The environment lookup result is cached in
// a block-local static, so it is queried only the first time a given
// expansion is executed.
#define TRACK_FUN_CALL(TAG, x)                                               \
  {                                                                          \
    static const bool track_enabled =                                        \
        (std::getenv("DIPU_TRACK_" #TAG) != nullptr);                        \
    if (track_enabled) {                                                     \
      printf("[%d %s: %d]:%s\n", getpid(), __FILE__, __LINE__, x);           \
    }                                                                        \
  }

// Runs the ACL runtime expression `Expr` once, after reporting its source text
// through TRACK_FUN_CALL(ACL, ...). A result other than ACL_SUCCESS fails a
// TORCH_CHECK whose message contains the expression text, the return code and
// aclGetRecentErrMsg(). The macro expands to a plain braced block, so call
// sites compile with or without a trailing semicolon.
#define DIPU_CALLACLRT(Expr)                                                 \
  {                                                                          \
    TRACK_FUN_CALL(ACL, #Expr);                                              \
    const ::aclError acl_status = Expr;                                      \
    TORCH_CHECK(acl_status == ACL_SUCCESS,                                   \
                "ascend device error, expr = ", #Expr,                       \
                ", ret = ", acl_status,                                      \
                ", error msg = ", aclGetRecentErrMsg());                     \
  }

namespace dipu {

namespace devapis {
Expand All @@ -26,27 +43,5 @@ struct Map {
}
};

// HCCL ReduceOp mapping: translates a c10d::ReduceOp into the matching
// HcclReduceOp. Only MIN, MAX, SUM and PRODUCT have entries.
//
// Declared `static` because this definition lives in a header: without
// internal linkage every translation unit that includes the header would
// emit its own external definition of `hcclOp`, which violates the
// one-definition rule and fails at link time.
static std::map<c10d::ReduceOp, HcclReduceOp> hcclOp = {
    {ReduceOp::MIN, HCCL_REDUCE_MIN},
    {ReduceOp::MAX, HCCL_REDUCE_MAX},
    {ReduceOp::SUM, HCCL_REDUCE_SUM},
    {ReduceOp::PRODUCT, HCCL_REDUCE_PROD},
};

// Pinned-pointer query for the ascend backend. Not implemented: every call
// fails the TORCH_CHECK below with a fixed message.
//
// Marked `inline` because the definition sits in a header; a non-inline
// function defined here would be emitted by every including translation unit
// and produce multiple-definition link errors.
//
// @param p  pointer to query (unused).
// @return   never returns normally; the trailing `return false` only keeps
//           the function well-formed for compilers that want a return value.
inline bool isPinnedPtr(const void* p) {
  TORCH_CHECK(false, "isPinnedPtr not implemented for ascend.\n");
  return false;
}

// Evaluates the HCCL call `cmd` once and raises through TORCH_CHECK when the
// result is not HCCL_SUCCESS. The error text is unchanged: file, line, a hint
// to consult the Ascend logs, and the most recent ACL error message.
//  * `cmd` is parenthesized so an argument containing low-precedence
//    operators (for example a conditional expression) is compared as a whole.
//  * The pointer from aclGetRecentErrMsg() is checked before it is appended,
//    since concatenating a null `const char*` onto a std::string is undefined.
//    NOTE(review): guarded defensively; confirm against the ACL reference
//    whether a null result is actually possible.
#define HCCL_THROW(cmd)                                                      \
  do {                                                                       \
    if ((cmd) != HCCL_SUCCESS) {                                             \
      const char* hccl_recent_err_ = aclGetRecentErrMsg();                   \
      TORCH_CHECK(false,                                                     \
                  "HCCL error in: " + std::string(__FILE__) + ":" +          \
                      std::to_string(__LINE__) + ".\n" +                     \
                      "And see details in Ascend logs.\n" +                  \
                      (hccl_recent_err_ != nullptr ? hccl_recent_err_ : "")); \
    }                                                                        \
  } while (0)

} // namespace devapis
} // namespace dipu
19 changes: 19 additions & 0 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/communicatorimpl.cpp
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@

#include "basecommimpl.hpp"

// Evaluates the HCCL call `cmd` once, after reporting its source text through
// TRACK_FUN_CALL(HCCL, ...), and raises through TORCH_CHECK when the result is
// not HCCL_SUCCESS. The error text is unchanged: file, line, a hint to consult
// the Ascend logs, and the most recent ACL error message.
//  * `cmd` is parenthesized so an argument containing low-precedence
//    operators (for example a conditional expression) is compared as a whole.
//  * The pointer from aclGetRecentErrMsg() is checked before it is appended,
//    since concatenating a null `const char*` onto a std::string is undefined.
//    NOTE(review): guarded defensively; confirm against the ACL reference
//    whether a null result is actually possible.
#define HCCL_THROW(cmd)                                                      \
  do {                                                                       \
    TRACK_FUN_CALL(HCCL, #cmd);                                              \
    if ((cmd) != HCCL_SUCCESS) {                                             \
      const char* hccl_recent_err_ = aclGetRecentErrMsg();                   \
      TORCH_CHECK(false,                                                     \
                  "HCCL error in: " + std::string(__FILE__) + ":" +          \
                      std::to_string(__LINE__) + ".\n" +                     \
                      "And see details in Ascend logs.\n" +                  \
                      (hccl_recent_err_ != nullptr ? hccl_recent_err_ : "")); \
    }                                                                        \
  } while (0)

namespace dipu {
namespace devapis {

// HCCL ReduceOp mapping: translates a c10d::ReduceOp into the matching
// HcclReduceOp. Only MIN, MAX, SUM and PRODUCT have entries. `static` keeps
// the table local to this translation unit.
static std::map<c10d::ReduceOp, HcclReduceOp> hcclOp = {
{ReduceOp::MIN, HCCL_REDUCE_MIN},
{ReduceOp::MAX, HCCL_REDUCE_MAX},
{ReduceOp::SUM, HCCL_REDUCE_SUM},
{ReduceOp::PRODUCT, HCCL_REDUCE_PROD},
};

// HCCL DataType mapping
static constexpr std::array<std::pair<at::ScalarType, HcclDataType>, 9>
hcclDataTypes{{
Expand Down
20 changes: 10 additions & 10 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,21 @@
#include <csrc_dipu/common.h>
#include <csrc_dipu/runtime/device/deviceapis.h>

#include "basecommimpl.hpp"

namespace dipu {

namespace devapis {

// Vendor initialization entry point. Performs no work.
void initializeVendor() {
}

// Vendor finalization entry point. Performs no work.
void finalizeVendor() {
}

// =====================
// Device class related
// =====================
// Integer type used for ascend device ids in this file.
using ascend_deviceId = int32_t;
// Per-thread flag, initially false. The static initializer below sets it to
// true after calling aclrtSetDevice(0), and current_device() tests it.
thread_local bool setDevFlag = false;

// Static-initialization side effect: the lambda runs once when this
// translation unit's globals are initialized. It calls aclInit(nullptr), then
// aclrtSetDevice(0) (each checked by DIPU_CALLACLRT), and records that a
// device was set. The stored value (always 0) is only a vehicle for running
// the lambda.
// Note: setDevFlag is thread_local, so the assignment below affects only the
// thread that performs this initialization.
static int initValue = []() {
DIPU_CALLACLRT(aclInit(nullptr));
DIPU_CALLACLRT(aclrtSetDevice(0));
setDevFlag = true;
return 0;
}();
// Vendor initialization: calls aclInit(nullptr). DIPU_CALLACLRT raises via
// TORCH_CHECK if the call does not return ACL_SUCCESS.
void initializeVendor() {
  DIPU_CALLACLRT(aclInit(nullptr));
}

// Vendor finalization: calls aclFinalize(), with the same DIPU_CALLACLRT
// error check.
void finalizeVendor() {
  DIPU_CALLACLRT(aclFinalize());
}

deviceId_t current_device() {
if (setDevFlag == false) {
Expand Down Expand Up @@ -258,5 +253,10 @@ void destroyEvent(deviceEvent_t event) {
DIPU_CALLACLRT(::aclrtDestroyEvent(event))
}

// Pinned-pointer query for the ascend backend. Not implemented: every call
// fails the TORCH_CHECK below with a fixed message, so the function never
// returns normally.
//
// @param p  pointer to query (unused).
// @return   false, only to give the function a return statement.
bool isPinnedPtr(const void* p) {
  static_cast<void>(p);  // parameter is intentionally unused
  TORCH_CHECK(false, "isPinnedPtr not implemented for ascend.\n");
  return false;
}

} // end namespace devapis
} // end namespace dipu
2 changes: 2 additions & 0 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/profilerimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <csrc_dipu/runtime/device/profilerapis.h>
#include <csrc_dipu/vendor/vendorapi.h>

#include "basecommimpl.hpp"

extern "C" aclError aclprofSetStampTagName(void* stamp, const char* tagName,
uint16_t len);
extern "C" aclError aclprofSetStampTraceMessage(void* stamp, const char* msg,
Expand Down
16 changes: 0 additions & 16 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/vendorapi.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,6 @@

namespace dipu {

// Optional ACL call tracing. If the environment variable DIPU_TRACK_ACL exists
// (any value), prints "[<file>: <line>]:<x>" to stdout; otherwise it is a
// no-op. `x` must be a C string. The environment lookup result is cached in a
// block-local static.
#define TRACK_ACL(x)                                                         \
  {                                                                          \
    static const bool acl_track_on =                                         \
        (std::getenv("DIPU_TRACK_ACL") != nullptr);                          \
    if (acl_track_on) {                                                      \
      printf("[%s: %d]:%s\n", __FILE__, __LINE__, x);                        \
    }                                                                        \
  }

// Runs the ACL runtime expression `Expr` once, after reporting its source text
// through TRACK_ACL. A result other than ACL_SUCCESS fails a TORCH_CHECK whose
// message contains the expression text, the return code and
// aclGetRecentErrMsg(). The macro expands to a plain braced block, so call
// sites compile with or without a trailing semicolon.
#define DIPU_CALLACLRT(Expr)                                                 \
  {                                                                          \
    TRACK_ACL(#Expr);                                                        \
    const ::aclError acl_status = Expr;                                      \
    TORCH_CHECK(acl_status == ACL_SUCCESS,                                   \
                "ascend device error, expr = ", #Expr,                       \
                ", ret = ", acl_status,                                      \
                ", error msg = ", aclGetRecentErrMsg());                     \
  }

// Stream handle type: alias for the ACL runtime stream.
using deviceStream_t = aclrtStream;
// Literal used for the default stream.
// NOTE(review): the replacement text includes a trailing ';', so the macro
// only works where an extra semicolon is harmless (e.g. at the end of a
// statement) — confirm whether that is intended.
#define deviceDefaultStreamLiteral nullptr;
// Event handle type: alias for the ACL runtime event.
using deviceEvent_t = aclrtEvent;
Expand Down

0 comments on commit eb6d8e6

Please sign in to comment.