Skip to content

Commit

Permalink
Zzf/rms norm (DeepLink-org#751)
Browse files Browse the repository at this point in the history
* add rms norm op
* take functions_ext into the adaptor
  • Loading branch information
zhangzefeng92 authored Dec 22, 2023
1 parent d0ab3da commit d50f920
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 35 deletions.
3 changes: 3 additions & 0 deletions adaptor/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,10 @@ def gen_autogen_operators(
# get the implemented functions
impl_base_dir = os.path.dirname(config_file_path)
impl_func_dir = os.path.join(impl_base_dir, "functions")
impl_func_ext_dir = os.path.join(impl_base_dir, "functions_ext")
impl_functions = obtain_impl_func(impl_func_dir)
impl_functions_ext = obtain_impl_func(impl_func_ext_dir)
impl_functions.update(impl_functions_ext)

if impl_plugin:
impl_plugin_dir = os.path.join(impl_base_dir, "../ascend_npu/diopi_impl")
Expand Down
35 changes: 0 additions & 35 deletions impl/ascend/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,39 +1860,4 @@
]
),
),

'rms_norm': dict(
name=["rms_norm"],
tensor_para=dict(
args=[
{
"ins": ['input'],
"dtype": [Skip(np.float32)],
},
],
),
),

'topk_nonzero': dict(
name=['topk'],
para=dict(
k=[Skip(1)],
),
),

'topk_zero': dict(
name=['topk'],
interface=['torch'],
para=dict(
k=[Skip(1)],
),
),

# FIXME 特定参数组合报错
'embedding': dict(
name=["embedding"],
para=dict(
padding_idx=[Skip(92)],
),
),
}
30 changes: 30 additions & 0 deletions impl/ascend/functions_ext/rms_norm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
* @file
* @author DeepLink
* @copyright (c) 2023, DeepLink.
*/

#include "../common/acloprunner.hpp"

namespace impl {
namespace ascend {

diopiError_t diopiRMSNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t invRms, diopiConstTensorHandle_t input,
diopiSize_t normalizedShape, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, double eps) {
AscendTensor inputTensor(input);
ASCEND_CHECK_ABORT(1 == normalizedShape.len && normalizedShape.data[0] == inputTensor.shape()[inputTensor.dim() - 1], "normalized shape error!");
AclOpRunner<2, 2>("RmsNorm", ctx).addInput(input).addInput(weight).setAttr("epsilon", static_cast<float>(eps)).addOutput(out).addOutput(invRms).run();
return diopiSuccess;
}

diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias,
diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
diopiConstTensorHandle_t bias, diopiConstTensorHandle_t invRms, diopiSize_t normalizedShape, double eps) {
AscendTensor inputTensor(input);
ASCEND_CHECK_ABORT(1 == normalizedShape.len && normalizedShape.data[0] == inputTensor.shape()[inputTensor.dim() - 1], "normalized shape error!");
AclOpRunner<4, 2>("RmsNorm", ctx).addInput(gradOutput).addInput(input).addInput(invRms).addInput(weight).addOutput(gradInput).addOutput(gradWeight).run();
return diopiSuccess;
}

} // namespace ascend
} // namespace impl
1 change: 1 addition & 0 deletions impl/ascend_npu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ set(OLD_IMPL_SRC
${OLD_IMPL_DIR}/functions/linspace.cpp
${OLD_IMPL_DIR}/functions/apply_penalty.cpp
${OLD_IMPL_DIR}/functions/split.cpp
${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp
#${OLD_IMPL_DIR}/test/export_functions.cpp
#${OLD_IMPL_DIR}/test/conform_test.cpp
${OLD_IMPL_DIR}/common/utils.cpp
Expand Down
2 changes: 2 additions & 0 deletions impl/ascend_npu/ascend_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ ascend:
- diopiConvolution2dBackward
- diopiLogicalAnd
- diopiLogicalOr
- diopiRMSNorm
- diopiRMSNormBackward
- diopiExpand
- diopiLinspace
- diopiProd
Expand Down

0 comments on commit d50f920

Please sign in to comment.