From d50f920a059f69efc7ca3ce488f8dd8ad485ff48 Mon Sep 17 00:00:00 2001 From: Zhangzefeng Date: Fri, 22 Dec 2023 14:51:42 +0800 Subject: [PATCH] Zzf/rms norm (#751) * add rms norm op * take functions_ext into the adaptor --- adaptor/codegen/gen.py | 3 +++ impl/ascend/device_configs.py | 35 -------------------------- impl/ascend/functions_ext/rms_norm.cpp | 30 ++++++++++++++++++++++ impl/ascend_npu/CMakeLists.txt | 1 + impl/ascend_npu/ascend_config.yaml | 2 ++ 5 files changed, 36 insertions(+), 35 deletions(-) create mode 100644 impl/ascend/functions_ext/rms_norm.cpp diff --git a/adaptor/codegen/gen.py b/adaptor/codegen/gen.py index 7a9e2e75b..8384c7b9e 100644 --- a/adaptor/codegen/gen.py +++ b/adaptor/codegen/gen.py @@ -724,7 +724,10 @@ def gen_autogen_operators( # get the implemented functions impl_base_dir = os.path.dirname(config_file_path) impl_func_dir = os.path.join(impl_base_dir, "functions") + impl_func_ext_dir = os.path.join(impl_base_dir, "functions_ext") impl_functions = obtain_impl_func(impl_func_dir) + impl_functions_ext = obtain_impl_func(impl_func_ext_dir) + impl_functions.update(impl_functions_ext) if impl_plugin: impl_plugin_dir = os.path.join(impl_base_dir, "../ascend_npu/diopi_impl") diff --git a/impl/ascend/device_configs.py b/impl/ascend/device_configs.py index 3c04fb926..a9bc709eb 100755 --- a/impl/ascend/device_configs.py +++ b/impl/ascend/device_configs.py @@ -1860,39 +1860,4 @@ ] ), ), - - 'rms_norm': dict( - name=["rms_norm"], - tensor_para=dict( - args=[ - { - "ins": ['input'], - "dtype": [Skip(np.float32)], - }, - ], - ), - ), - - 'topk_nonzero': dict( - name=['topk'], - para=dict( - k=[Skip(1)], - ), - ), - - 'topk_zero': dict( - name=['topk'], - interface=['torch'], - para=dict( - k=[Skip(1)], - ), - ), - - # FIXME 特定参数组合报错 - 'embedding': dict( - name=["embedding"], - para=dict( - padding_idx=[Skip(92)], - ), - ), } diff --git a/impl/ascend/functions_ext/rms_norm.cpp b/impl/ascend/functions_ext/rms_norm.cpp new file mode 100644 index 000000000..c57935700 --- /dev/null +++ b/impl/ascend/functions_ext/rms_norm.cpp @@ -0,0 +1,30 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2023, DeepLink. + */ + +#include "../common/acloprunner.hpp" + +namespace impl { +namespace ascend { + +diopiError_t diopiRMSNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t invRms, diopiConstTensorHandle_t input, + diopiSize_t normalizedShape, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, double eps) { + AscendTensor inputTensor(input); + ASCEND_CHECK_ABORT(1 == normalizedShape.len && normalizedShape.data[0] == inputTensor.shape()[inputTensor.dim() - 1], "normalized shape error!"); + AclOpRunner<2, 2>("RmsNorm", ctx).addInput(input).addInput(weight).setAttr("epsilon", static_cast(eps)).addOutput(out).addOutput(invRms).run(); + return diopiSuccess; +} + +diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias, + diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, + diopiConstTensorHandle_t bias, diopiConstTensorHandle_t invRms, diopiSize_t normalizedShape, double eps) { + AscendTensor inputTensor(input); + ASCEND_CHECK_ABORT(1 == normalizedShape.len && normalizedShape.data[0] == inputTensor.shape()[inputTensor.dim() - 1], "normalized shape error!"); + AclOpRunner<4, 2>("RmsNorm", ctx).addInput(gradOutput).addInput(input).addInput(invRms).addInput(weight).addOutput(gradInput).addOutput(gradWeight).run(); + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt index 77a0e2c55..df75aa601 100755 --- a/impl/ascend_npu/CMakeLists.txt +++ b/impl/ascend_npu/CMakeLists.txt @@ -640,6 +640,7 @@ set(OLD_IMPL_SRC ${OLD_IMPL_DIR}/functions/linspace.cpp ${OLD_IMPL_DIR}/functions/apply_penalty.cpp ${OLD_IMPL_DIR}/functions/split.cpp + ${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp #${OLD_IMPL_DIR}/test/export_functions.cpp #${OLD_IMPL_DIR}/test/conform_test.cpp ${OLD_IMPL_DIR}/common/utils.cpp diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 7ba2d697a..10244b8ac 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -192,6 +192,8 @@ ascend: - diopiConvolution2dBackward - diopiLogicalAnd - diopiLogicalOr +- diopiRMSNorm +- diopiRMSNormBackward - diopiExpand - diopiLinspace - diopiProd