Skip to content

Commit

Permalink
fix(dipu): avoid infinite recursion in dumpArg() (#752)
Browse files Browse the repository at this point in the history
* fix(dipu): avoid infinite recursion in `dumpArg()`

This happens when `DIPU_DUMP_OP_ARGS` is set to 3, because the tensor is copied from the device to the CPU before it is dumped.

This patch uses a function previously used in auto_compare (renamed to `toCpuTensorWithoutDiopiCopy()`) to avoid the infinite recursion.

* style(cpp): clang-format the code
  • Loading branch information
lljbash authored Mar 29, 2024
1 parent 115eb8d commit badbe46
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
24 changes: 7 additions & 17 deletions dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp
Original file line number Diff line number Diff line change
@@ -1,35 +1,25 @@
#pragma once

#include <cstddef>
#include <iomanip>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>

#include <ATen/core/TensorBody.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/allclose.h>
#include <ATen/ops/empty_strided.h>
#include <c10/core/Device.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include "csrc_dipu/aten/ops/DIPUCopy.hpp"
#include "csrc_dipu/runtime/device/deviceapis.h"
#include "csrc_dipu/aten/ops/OpUtils.hpp"

namespace dipu {
namespace native {

// Returns a CPU tensor holding the contents of `in`.
//
// A tensor that already lives on the CPU is returned unchanged. Otherwise a
// CPU tensor with the same sizes, strides and remaining options is allocated
// and filled through dipu::devapis::memCopyD2H (per this function's name, the
// intent is to bypass the DIOPI copy path).
//
// NOTE(review): no explicit stream synchronization is visible here — confirm
// that memCopyD2H orders itself after pending device work.
inline at::Tensor to_cpu_without_diopi(const at::Tensor& in) {
  if (!in.is_cpu()) {
    const auto cpu_options = in.options().device(c10::Device("cpu"));
    at::Tensor cpu_copy =
        at::empty_strided(in.sizes(), in.strides(), cpu_options);

    // An empty tensor has nothing to transfer; skip the copy entirely.
    const bool has_payload = in.nbytes() > 0;
    if (has_payload) {
      // The byte count comes from the freshly allocated CPU storage, which was
      // sized from the same sizes/strides as `in`.
      dipu::devapis::memCopyD2H(cpu_copy.storage().nbytes(),
                                cpu_copy.data_ptr(), in.data_ptr());
    }
    return cpu_copy;
  }
  return in;
}

inline std::string cpu_tensor_to_one_line_string(const at::Tensor& tensor) {
/*
* This function retrieves the built-in string representation of the input
Expand Down Expand Up @@ -91,7 +81,7 @@ inline std::string allclose_autocompare(const at::Tensor& tensor_cpu,
constexpr double tolerance_absolute = 1e-4;
constexpr double tolerance_relative = 1e-5;
const at::Tensor& tensor_cpu_from_device =
to_cpu_without_diopi(tensor_device);
toCpuTensorWithoutDiopiCopy(tensor_device);
bool passed = at::allclose(tensor_cpu, tensor_cpu_from_device,
tolerance_absolute, tolerance_relative, true);
if (passed) {
Expand Down
21 changes: 20 additions & 1 deletion dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <string>
#include <utility>
Expand All @@ -16,19 +17,37 @@
#include <ATen/native/cpu/mixed_data_type.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/allclose.h>
#include <ATen/ops/empty_strided.h>
#include <c10/core/Device.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <c10/util/OptionalArrayRef.h>
#include <c10/util/string_view.h>

#include "csrc_dipu/runtime/core/DIPUStream.h"
#include "csrc_dipu/runtime/device/deviceapis.h"
#include "csrc_dipu/runtime/rthelper.h"
#include "csrc_dipu/utils/Log.h"

namespace dipu {
namespace native {

// Copies a tensor to the CPU without going through diopiCopy().
//
// This exists to avoid infinite recursion when dumpArg() runs before
// diopiCopy() is called: dumpArg() uses this helper to obtain a printable CPU
// tensor instead of a regular tensor copy.
//
// A tensor that is already on the CPU is returned unchanged. Otherwise a CPU
// tensor with the same sizes, strides and remaining options is allocated and
// filled via dipu::devapis::memCopyD2H.
//
// NOTE(review): no explicit stream synchronization is visible here — confirm
// that memCopyD2H orders itself after pending device work.
inline at::Tensor toCpuTensorWithoutDiopiCopy(const at::Tensor& in) {
  // Host tensors need no transfer.
  if (in.is_cpu()) {
    return in;
  }

  const auto host_options = in.options().device(c10::Device("cpu"));
  at::Tensor host = at::empty_strided(in.sizes(), in.strides(), host_options);

  // Nothing to copy for a zero-byte tensor; return the empty host tensor.
  if (in.nbytes() == 0) {
    return host;
  }

  // The byte count is taken from the host storage, which was sized from the
  // same sizes/strides as `in`.
  dipu::devapis::memCopyD2H(host.storage().nbytes(), host.data_ptr(),
                            in.data_ptr());
  return host;
}

inline bool checkTensorDevice() {
static bool enable = []() {
const char* env_ptr = std::getenv("DIPU_CHECK_TENSOR_DEVICE");
Expand Down Expand Up @@ -114,7 +133,7 @@ inline std::string dumpArg(const at::Tensor& tensor) {
<< ", storage_data_ptr: " << tensor.storage().data_ptr().get()
<< ", storage_offset: " << tensor.storage_offset();
if (dumpOpArgLevel() > 2) {
stream << '\n' << tensor;
stream << '\n' << toCpuTensorWithoutDiopiCopy(tensor);
}
} else {
stream << "undefined";
Expand Down

0 comments on commit badbe46

Please sign in to comment.