Skip to content

Commit

Permalink
[DIPU] disable custom fallback ops partly. (#592)
Browse files Browse the repository at this point in the history
* disable custom fallback ops partly.

* clang-format.

* adjust the way to disable custom fallback ops.

* add env to determine whether to keep torchop default impl.

* add test to ensure linear custom fallback.

* clean code.
  • Loading branch information
Reinerzhou authored Jan 17, 2024
1 parent f0f7e26 commit fb2cfc7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dipu/tests/python/individual_scripts/test_dipu_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def fn():
_test_dipu_copy_fallback_,
_test_dipu_convolution_backward_overrideable_fallback,
_test_dipu_convolution_overrideable_fallback,
_test_dipu_silu_fallback,
_test_dipu_silu_fallback
],
in_parallel=True,
)
11 changes: 11 additions & 0 deletions dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <torch/library.h>

#include "csrc_dipu/aten/ops/OpUtils.hpp"

namespace dipu {

bool get_force_fallback(const char* opname);
Expand Down Expand Up @@ -46,6 +48,10 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
// add type trait code. 2. pytorch seems are sorting out infer and other
// pre/post code. so we shouldn't created a new preprocess logic?
// so just do a simple runtime cpu fallback to support diopi func loss

// It mat be necessary to determine whether to keep torchop default impl
// for non-custom ops through function dipuKeepTorchopDefaultImpl firstly in the
// future, and we use force fallback to keep torchop default impl now.
#define DIOPI_ATEN_FUNC(opname, diopiFunc, wapperFunc) \
do { \
if ((reinterpret_cast<void*>(diopiFunc) != nullptr) && \
Expand All @@ -62,9 +68,14 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
} \
} while (false);

// Determine whether to keep torchop default impl for custom ops through
// function dipuKeepTorchopDefaultImpl firstly.
#define DIOPI_ATEN_FUNC_CUSTOM_FALLBACK(opname, diopi_func, force_fallback, \
wapper_func, custom_fallback_func) \
do { \
if (dipu::native::dipuKeepTorchopDefaultImpl(opname)) { \
break; \
} \
if ((reinterpret_cast<void*>(diopi_func) != nullptr) && \
!((force_fallback) || dipu::get_force_fallback(opname))) { \
m.impl(opname, TORCH_FN(wapper_func)); \
Expand Down
7 changes: 7 additions & 0 deletions dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ inline void synchronizeIfEnable() {
}
}

inline bool dipuKeepTorchopDefaultImpl(const char* opname) {
static const char* env = std::getenv("DIPU_KEEP_TORCHOP_DEFAULT_IMPL_OPS");
return (env != nullptr) &&
((std::string(env) + ',').find(std::string(opname) + ',') <
(strlen(env) - 1));
}

inline int dumpOpArgLevel() {
static const char* env_ptr = std::getenv("DIPU_DUMP_OP_ARGS");
static int level = env_ptr ? std::atoi(env_ptr) : 0;
Expand Down

0 comments on commit fb2cfc7

Please sign in to comment.