Skip to content

Commit

Permalink
[DIPU] Replace the default (memCopyAsync + Sync) strategy with (Sync …
Browse files Browse the repository at this point in the history
…+ memCopySync) for direct memory copy on Ascend (#730)

* replace the default (memCopyAsync + Sync) strategy with (Sync + memCopySync) for direct memory copy on Ascend

* update _local_scalar_dense_dipu to use (Sync + memCopySync) strategy for Ascend

* refactor memCopy in DIPUCopy.hpp
  • Loading branch information
jfxu-st authored Mar 19, 2024
1 parent d15e882 commit 621134d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 40 deletions.
6 changes: 6 additions & 0 deletions dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,15 @@ at::Scalar _local_scalar_dense_dipu(const at::Tensor& self) {
scalar_t value;
dipu::DIPUStream stream = dipu::getCurrentDIPUStream();
MemChecker::instance().check(self);
#if DIPU_VENDOR_NAME_ASCEND
dipu::devproxy::syncStream(stream.rawstream());
dipu::devproxy::memCopyD2H(sizeof(scalar_t), &value,
self.data_ptr<scalar_t>());
#else
dipu::devproxy::memCopyD2HAsync(stream.rawstream(), sizeof(scalar_t),
&value, self.data_ptr<scalar_t>());
dipu::devproxy::syncStream(stream.rawstream());
#endif
r = at::Scalar(value);
});
return r;
Expand Down
93 changes: 63 additions & 30 deletions dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,57 +87,69 @@ inline int64_t getMemCopyBytes(const at::Tensor& dst, const at::Tensor& src,
return std::min(srcBytes, dstBytes);
}

inline void memCopyH2D(const at::Tensor& dst, const at::Tensor& src,
dipu::DIPUStream& stream, int64_t nbytes) {
inline void doMemCopyH2D(const at::Tensor& dst, const at::Tensor& src,
dipu::DIPUStream& stream, int64_t nbytes,
bool isSynchronousCopy) {
void* src_ptr = src.data_ptr();
void* dst_ptr = dst.data_ptr();

MemChecker::instance().check(dst);
dipu::devproxy::memCopyH2DAsync(stream.rawstream(), nbytes, dst_ptr, src_ptr);
if (isSynchronousCopy) {
dipu::devproxy::memCopyH2D(nbytes, dst_ptr, src_ptr);
} else {
dipu::devproxy::memCopyH2DAsync(stream.rawstream(), nbytes, dst_ptr,
src_ptr);
}
}

inline void memCopyD2H(const at::Tensor& dst, const at::Tensor& src,
dipu::DIPUStream& stream, int64_t nbytes) {
inline void doMemCopyD2H(const at::Tensor& dst, const at::Tensor& src,
dipu::DIPUStream& stream, int64_t nbytes,
bool isSynchronousCopy) {
void* src_ptr = src.data_ptr();
void* dst_ptr = dst.data_ptr();

MemChecker::instance().check(src);
dipu::devproxy::memCopyD2HAsync(stream.rawstream(), nbytes, dst_ptr, src_ptr);
if (isSynchronousCopy) {
dipu::devproxy::memCopyD2H(nbytes, dst_ptr, src_ptr);
} else {
dipu::devproxy::memCopyD2HAsync(stream.rawstream(), nbytes, dst_ptr,
src_ptr);
}
}

inline void memCopyD2D(const at::Tensor& dst, const at::Tensor& src,
dipu::DIPUStream& stream, int64_t nbytes) {
inline void doMemCopyD2D(const at::Tensor& dst, const at::Tensor& src,
dipu::DIPUStream& stream, int64_t nbytes,
bool isSynchronousCopy) {
void* src_ptr = src.data_ptr();
void* dst_ptr = dst.data_ptr();

MemChecker::instance().check(src);
MemChecker::instance().check(dst);
dipu::devproxy::memCopyD2DAsync(stream.rawstream(), nbytes,
dst.device().index(), dst_ptr,
src.device().index(), src_ptr);
if (isSynchronousCopy) {
dipu::devproxy::memCopyD2D(nbytes, dst.device().index(), dst_ptr,
src.device().index(), src_ptr);
} else {
dipu::devproxy::memCopyD2DAsync(stream.rawstream(), nbytes,
dst.device().index(), dst_ptr,
src.device().index(), src_ptr);
}
}

inline void memCopy(const at::Tensor& dst, const at::Tensor& src,
dipu::DIPUStream& stream, DIPUCopyType copyType,
bool needMemCpSync, bool nonOverlappingAndDense) {
bool nonOverlappingAndDense, bool isSynchronousCopy) {
int64_t nbytes = getMemCopyBytes(dst, src, nonOverlappingAndDense);
switch (copyType) {
case DIPUCopyType::H2D:
// src is cpu.
memCopyH2D(dst, src, stream, nbytes);
doMemCopyH2D(dst, src, stream, nbytes, isSynchronousCopy);
break;
case DIPUCopyType::D2H:
// dst is cpu.
memCopyD2H(dst, src, stream, nbytes);
doMemCopyD2H(dst, src, stream, nbytes, isSynchronousCopy);
break;
default: // device to device
memCopyD2D(dst, src, stream, nbytes);
}
// this sync is different with copy_ non_blocking, it's used inside one copy
// op when doing a intermidiate cpu copy after some stream op to guarantee the
// cpu copy get correct data.
if (needMemCpSync) {
dipu::devproxy::syncStream(stream.rawstream());
doMemCopyD2D(dst, src, stream, nbytes, isSynchronousCopy);
}
}

Expand Down Expand Up @@ -225,7 +237,7 @@ class DIPUCopyInplace : public DIPUCopyBase {
<< std::endl;
}

copyPreProcess(dst, src, non_blocking, curStream);
copyPreProcess(dst, src, non_blocking, info);

copyAll(dst, src, non_blocking, info);

Expand All @@ -234,12 +246,13 @@ class DIPUCopyInplace : public DIPUCopyBase {

protected:
virtual void copyPreProcess(const at::Tensor& dst, const at::Tensor& src,
bool non_blocking, DIPUStream& curStream) {
bool non_blocking, CopyParamsInfo& info) {
// recordBeforeCopy
if (non_blocking) {
const bool is_default_stream = dipu::getDefaultDIPUStream() == curStream;
tryRecordStream(dst, curStream, is_default_stream);
tryRecordStream(src, curStream, is_default_stream);
const bool is_default_stream =
dipu::getDefaultDIPUStream() == info.curStream_;
tryRecordStream(dst, info.curStream_, is_default_stream);
tryRecordStream(src, info.curStream_, is_default_stream);
}
}

Expand All @@ -264,7 +277,13 @@ class DIPUCopyInplace : public DIPUCopyBase {
if (dst.is_view() && src.is_view()) {
TORCH_CHECK(false, "doDirectMemFill cannot support all view-view copy");
}
memCopy(dst, src, curStream, copyType, needMemCpSync, false);

memCopy(dst, src, curStream, copyType, /*nonOverlappingAndDense=*/false,
/*isSynchronousCopy=*/false);

if (needMemCpSync) {
dipu::devproxy::syncStream(curStream.rawstream());
}
}

// support mem copy between 2 nonOverlappingAndDense tensor with same stride
Expand All @@ -276,7 +295,12 @@ class DIPUCopyInplace : public DIPUCopyBase {
printf("--%-50s %-30s \n", "[copy_]:", "doDirectMemCopy");
}

memCopy(dst, src, curStream, copyType, needMemCpSync, true);
memCopy(dst, src, curStream, copyType, /*nonOverlappingAndDense=*/true,
/*isSynchronousCopy=*/false);

if (needMemCpSync) {
dipu::devproxy::syncStream(curStream.rawstream());
}
}

at::Tensor makeSameStrideTensor(const at::Tensor& src, DIPUStream& curStream,
Expand Down Expand Up @@ -517,6 +541,16 @@ class DIPUCopyInplace : public DIPUCopyBase {
doCpuRelayCopy(dst, src, info.curStream_, non_blocking);
}

// This virtual method, which is simply a wrapper of doDirectMemCopy by
// default, is only used in copyAll. It keeps all original information
// including non_blocking, and is thus suitable for overriding on different
// devices for more control of the direct memory copy process
virtual void directMemCopy(at::Tensor& dst, const at::Tensor& src,
CopyParamsInfo& info, bool non_blocking) {
doDirectMemCopy(dst, src, info.curStream_, info.copyType_,
/*needMemCpSync=*/false);
}

// overriding this func is possible but not recommended
virtual void copyAll(at::Tensor& dst, const at::Tensor& src,
bool non_blocking, CopyParamsInfo& info) {
Expand All @@ -526,8 +560,7 @@ class DIPUCopyInplace : public DIPUCopyBase {
info.recomputeTensorsInfo(dst, tmpSrc);
}
if (info.directMemCopy_) {
doDirectMemCopy(dst, tmpSrc, info.curStream_, info.copyType_,
/*needMemCpSync=*/false);
directMemCopy(dst, tmpSrc, info, non_blocking);
return;
}
switch (info.copyType_) {
Expand Down
46 changes: 39 additions & 7 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/AscendCopyInplace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,47 @@ class AscendCopyInplace : public DIPUCopyInpOnDIOPI {
~AscendCopyInplace() override = default;

protected:
void copyPreProcess(const at::Tensor& dst, const at::Tensor& src,
bool non_blocking, CopyParamsInfo& info) override {
// recordBeforeCopy
if (non_blocking) {
const bool is_default_stream =
dipu::getDefaultDIPUStream() == info.curStream_;
tryRecordStream(dst, info.curStream_, is_default_stream);
tryRecordStream(src, info.curStream_, is_default_stream);
}

if (!non_blocking && (DIPUCopyType::H2D == info.copyType_ ||
DIPUCopyType::D2H == info.copyType_)) {
// According to our benchmark for H2D/D2H synchronous direct memory copy,
// (Sync + memCopySync) is faster than (memCopyAsync + Sync) on Ascend,
// So do an advance sync here
dipu::devapis::syncStream(info.curStream_.rawstream());
}
}

void directMemCopy(at::Tensor& dst, const at::Tensor& src,
CopyParamsInfo& info, bool non_blocking) override {
if (!non_blocking && (DIPUCopyType::H2D == info.copyType_ ||
DIPUCopyType::D2H == info.copyType_)) {
// According to our benchmark for H2D/D2H synchronous direct memory copy,
// (Sync + memCopySync) is faster than (memCopyAsync + Sync) on Ascend,
// so do a memCopySync instead of memCopyAsync here
memCopy(dst, src, info.curStream_, info.copyType_,
/*nonOverlappingAndDense=*/true, /*isSynchronousCopy=*/true);
} else {
doDirectMemCopy(dst, src, info.curStream_, info.copyType_,
/*needMemCpSync=*/false);
}
}

void copyPostProcess(bool non_blocking, const CopyParamsInfo& info,
DIPUStream& curStream) override {
// TODO(fandaoyi): Refactor to remove duplicated code from different vendors
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html
// In d2self cases, non_blocking has no effect.
// For other cases, do sync after copy if non_blocking is false.
if (!non_blocking && info.copyType_ != DIPUCopyType::D2Self) {
dipu::devapis::syncStream(curStream.rawstream());
}
// In d2self cases, non_blocking has no effect (Ref:
// https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html). In
// d2h/h2d cases, the (Sync + memCopySync) strategy is adopted (see the
// comments in the above functions copyPreProcess and directMemCopy), so
// synchronization is never needed here.
}
};

Expand Down
3 changes: 0 additions & 3 deletions dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,18 @@ void freeDevice(void* p) {
// (synchronous) copy from device to a device
void memCopyD2D(size_t nbytes, deviceId_t dstDevId, void* dst,
deviceId_t srcDevId, const void* src) {
syncDevice();
DIPU_CALLACLRT(
::aclrtMemcpy(dst, nbytes, src, nbytes, ACL_MEMCPY_DEVICE_TO_DEVICE));
}

// (synchronous) copy from host to a device
void memCopyH2D(size_t nbytes, void* dst, const void* src) {
syncDevice();
DIPU_CALLACLRT(
::aclrtMemcpy(dst, nbytes, src, nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
}

// (synchronous) copy from a device to host
void memCopyD2H(size_t nbytes, void* dst, const void* src) {
syncDevice();
DIPU_CALLACLRT(
::aclrtMemcpy(dst, nbytes, src, nbytes, ACL_MEMCPY_DEVICE_TO_HOST));
}
Expand Down

0 comments on commit 621134d

Please sign in to comment.