diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index b813c042ae..0f9015bb70 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,4 +1,30 @@
# IMPORTANT:
# This file is ONLY used to merge PRs. Approvals from people in this file are required for merging.
+#
+# WARNING: The last matching pattern takes precedence and OVERRIDES any earlier rule.
+# Please be very careful when adding new patterns.
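+# For example, with the rules below, a change under /dipu/scripts/ci/ is owned by
+# @mrdanielw and @wugeshui (the last matching pattern), not by the base "*" rule.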
-/dipu/tests/python @lljbash @mrdanielw
+# ---------- base ----------
+
+* @mrdanielw @jinminxi104
+/.github/ @mrdanielw @wugeshui
+/.github/CODEOWNERS @mrdanielw @jinminxi104
+
+# ---------- dipu ----------
+
+### directories & files
+/dipu/torch_dipu/csrc_dipu/ @mrdanielw @fandaoyi @lljbash
+/dipu/tests/python/ @mrdanielw @lljbash
+/dipu/scripts/autogen_diopi_wrapper/ @mrdanielw @lljbash
+/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py @mrdanielw @zhaoguochun1995
+/dipu/scripts/ci/ @mrdanielw @wugeshui
+
+### build & linter
+/dipu/**/CMakeLists.txt @mrdanielw @lljbash @wiryls @Wrench-Git
+/dipu/**/*.cmake @mrdanielw @lljbash @wiryls @Wrench-Git
+/dipu/.clang* @mrdanielw @lljbash @wiryls
+
+# ---------- dicp ----------
+
+/dicp/ @jinminxi104
+/dicp/scripts/ci/ @jinminxi104 @wugeshui
diff --git a/.github/actions/code-build-test/action.yml b/.github/actions/code-build-test/action.yml
index a3f2c32a62..6bbede0ee6 100644
--- a/.github/actions/code-build-test/action.yml
+++ b/.github/actions/code-build-test/action.yml
@@ -81,6 +81,10 @@ runs:
else
export CI=true
source ~/.bashrc
- cd ${WORK_PATH} && rm -rf ${JOB_NAME} && cp -R source ${JOB_NAME} && cd ${JOB_NAME}
+ cd ${WORK_PATH}
+ if [ "${{ inputs.cover_job }}" == "0" ];then
+ rm -rf ${JOB_NAME} && cp -R source ${JOB_NAME}
+ fi
+ cd ${JOB_NAME}
${{ inputs.build_shell }} ${cleaner_shell}
fi
diff --git a/.github/workflows/_runs-on-ascend.yml b/.github/workflows/_runs-on-ascend.yml
index 21a93a328a..928f7f1e66 100644
--- a/.github/workflows/_runs-on-ascend.yml
+++ b/.github/workflows/_runs-on-ascend.yml
@@ -12,7 +12,7 @@ on:
description: Set up the build environment
type: string
required: false
- default: "tps-ascend-ci"
+ default: "dicp-ascend-ci-910b"
jobs:
checkout_code:
@@ -22,25 +22,38 @@ jobs:
- name: Checkout Code
uses: DeepLink-org/deeplink.framework/.github/actions/checkout-code@main
- build:
+ build_test:
runs-on: ${{ inputs.runner }}
needs: checkout_code
steps:
- - name: build on ascend
+ - name: build and test on ascend
uses: DeepLink-org/deeplink.framework/.github/actions/code-build-test@main
with:
- build_shell: "pwd" #Write the script you want to execute here,If you don't know which parameters to fill in, you can refer to the actions/code-build-test
- job_name: "build"
+ build_shell: "
+ source dicp/scripts/ci/ascend/dipu_env.sh && \
+ rm -rf /tmp/torchinductor_autolink/* && \
+ rm -rf /tmp/dicp_ascend/* && \
+ cd /mnt/cache/share/deeplinkci/dicp_env/transformers && \
+ pip uninstall transformers -y && \
+ patch -p1 < modeling_llama.diff && patch -p1 < utils.diff && \
+ python setup.py clean && \
+ python setup.py install --user && \
+ patch -R -p1 < modeling_llama.diff && patch -R -p1 < utils.diff && \
+ cd - && \
+ cd /mnt/cache/share/deeplinkci/dicp_env/accelerate && \
+ pip uninstall accelerate -y && \
+ python setup.py clean && \
+ python setup.py install --user && \
+ cd - && \
+ pip uninstall torch_dipu -y && \
+ pip uninstall dicp -y && \
+ cd dipu && python setup.py clean && python setup.py install --user && \
+ cd ../dicp && python setup.py clean && python setup.py install --user && \
+ source scripts/ci/ascend/test_env.sh /mnt/cache/share/deeplinkci/dicp_env/llama_models && \
+ export TEST_DIR=$(pwd)/test && echo ${TEST_DIR} && \
+ bash ${TEST_DIR}/ascend_scripts/ops/run_test_ops.sh false && \
+ bash ${TEST_DIR}/ascend_scripts/models/run_test_models.sh false
+ " # Write the script you want to execute here. If you are unsure which parameters to fill in, refer to the actions/code-build-test action
+ job_name: "build_test"
cover_job: "0"
cleaner: "clean_all_if_error"
-
- test:
- runs-on: ${{ inputs.runner }}
- needs: build
- steps:
- - name: rt test on ascend
- uses: DeepLink-org/deeplink.framework/.github/actions/code-build-test@main
- with:
- build_shell: "pwd" #Write the script you want to execute here,If you don't know which parameters to fill in, you can refer to the actions/code-build-test
- job_name: "build"
- cover_job: "1"
diff --git a/.github/workflows/_runs-on-topsrider.yml b/.github/workflows/_runs-on-topsrider.yml
index 8427855c6d..21efb5d972 100644
--- a/.github/workflows/_runs-on-topsrider.yml
+++ b/.github/workflows/_runs-on-topsrider.yml
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ inputs.runner }}
needs: checkout_code
steps:
- - name: build and test on topsrider
+ - name: build on topsrider
uses: DeepLink-org/deeplink.framework/.github/actions/code-build-test@main
with:
build_shell: "
@@ -34,11 +34,19 @@ jobs:
pip uninstall torch_dipu -y && \
pip uninstall dicp -y && \
cd dipu && python setup.py install --user && \
- cd ../dicp && python setup.py install --user && \
- cd .. && source dicp/scripts/ci/tops/ci_tops_test_env.sh /mnt/models/llama_models && \
- export TEST_DIR=$(pwd)/dicp/test && echo ${TEST_DIR} && \
- bash ${TEST_DIR}/tops_scripts/ops/run_test_ops.sh false && \
- bash ${TEST_DIR}/tops_scripts/models/run_test_models.sh false
+ cd ../dicp && python setup.py install --user
"
job_name: "build_test"
cover_job: "0"
+
+ - name: test ops on topsrider
+ uses: DeepLink-org/deeplink.framework/.github/actions/code-build-test@main
+ with:
+ build_shell: "
+ source dicp/scripts/ci/tops/ci_tops_test_env.sh \
+ /mnt/models/llama_models /mnt/models/stable_diffusion_models && \
+ export TEST_DIR=$(pwd)/dicp/test && \
+ bash ${TEST_DIR}/tops_scripts/ops/run_test_ops.sh false
+ "
+ job_name: "build_test"
+ cover_job: "1"
diff --git a/.github/workflows/dicp.yml b/.github/workflows/dicp.yml
index 537fd239ee..d758fd1c75 100644
--- a/.github/workflows/dicp.yml
+++ b/.github/workflows/dicp.yml
@@ -1,15 +1,14 @@
name: dicp ci
on:
workflow_dispatch:
- push:
- branches:
- - main
+ schedule:
+ - cron: '10 23 * * *'
pull_request:
- paths-ignore:
- - "**.md"
- - ".github/ISSUE_TEMPLATE/**"
- - ".git*"
- - "CODE_OF_CONDUCT**"
+ paths:
+ - ".github/workflows/dicp.yml"
+ - ".github/workflows/_runs-on-ascend.yml"
+ - ".github/workflows/_runs-on-topsrider.yml"
+ - "dicp/**"
env:
ENV_PATH: '/mnt/cache/share/platform/env'
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index cbad72ae4a..32efe64a4e 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -11,19 +11,22 @@ jobs:
markdownlint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - name: Checkout code
+ uses: actions/checkout@v4
with:
- fetch-depth: 2
- - uses: tj-actions/changed-files@v40
+ fetch-depth: 8
+ - name: Collect changed files
+ uses: tj-actions/changed-files@v40
id: changed-files
with:
files: '**/*.md'
- separator: ","
- - uses: DavidAnson/markdownlint-cli2-action@v14
+ separator: ','
+ - name: MarkdownLint
if: steps.changed-files.outputs.any_changed == 'true'
+ uses: DavidAnson/markdownlint-cli2-action@v14
with:
globs: ${{ steps.changed-files.outputs.all_changed_files }}
- separator: ","
+ separator: ','
clang-format:
needs: markdownlint
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 6f946ae407..852a4335f9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -15,7 +15,7 @@ env:
CAMB_CLUSTER: CAMB
CAMB_TORCH_BASE_DIR: '/mnt/lustre/share/parrotsci/github/cibuild/pytorchbase'
CUDA_CI_PATH: '/mnt/cache/share/parrotsci/github/cibuild/${{ github.repository }}'
- CUDA_PARTATION: ${{ vars.SH1988_SLURM_PAR != '' && vars.SH1988_SLURM_PAR || 'pat_rd -x SH-IDC1-10-198-8-60' }}
+ CUDA_PARTATION: ${{ vars.SH1988_SLURM_PAR != '' && vars.SH1988_SLURM_PAR || 'pat_rd' }}
CUDA_CLUSTER: SH1988
DEEPLINK_PATH: '/mnt/cache/share/deeplinkci/github/${{ github.repository }}'
ASCEND_CLUSTER: ASCEND
@@ -24,7 +24,7 @@ env:
CI_BUILD_FLAG: "ci_build_flag"
PYTORCH_COMMIT: ${{ vars.PYTORCH_COMMIT != '' && vars.PYTORCH_COMMIT || 'c263bd43e8e8502d4726643bc6fd046f0130ac0e' }} # pytorch tag 2.0
ALL_COVERAGE: ${{ (contains( github.ref, 'main') || startsWith(github.ref, 'refs/heads/v') || startsWith(github.ref, 'refs/heads/dev')) && 'ON' || 'OFF' }}
- REQUIRE_COVERAGE: ${{ vars.REQUIRE_COVERAGE != '' && vars.REQUIRE_COVERAGE || '40' }}
+ REQUIRE_COVERAGE: ${{ vars.REQUIRE_COVERAGE != '' && vars.REQUIRE_COVERAGE || '0' }}
REPO: ${{ github.event.repository.name }}
concurrency:
@@ -128,7 +128,7 @@ jobs:
cd ${CAMB_CI_PATH}/${GITHUB_RUN_NUMBER}/Build-Camb
rm -rf scripts
ln -s ${CAMB_CI_PATH}/${GITHUB_RUN_NUMBER}/source-main/dipu/third_party/DIOPI/scripts scripts
- source /mnt/cache/share/platform/env/camb_ci_diopi_impl
+ source /mnt/cache/share/platform/env/pt2.0_diopi
bash scripts/increment_coverage.sh ${REQUIRE_COVERAGE}
"""
diff --git a/README.md b/README.md
index a3a17decde..7dd7aafe2f 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ Deeplink.framework 是 DeepLink 推出的介于 AI 训练框架和硬件语言
### DIPU
-DIPU (Device Independent Process Unit) 是由一组抽象设备 runtime 接口,一组框架能力相关的运行时基类/接口,一个针对 DIOPI 标准算子的适配层共同组成的拓展包。 用来在训练框架 PyTorch 上接入 DIOPI 算子库,实现 Eager 模式的推理和训练。其能够在编译时,决定抽象设备被影射的方式;并使用统一的运行时,减少在多硬件上适配训练框架的成本。DIPU 即可以基于统一的设备运行时来屏蔽厂商的实际设备;也可以基于统一的框架相关的运行时基类,由厂商自行实现特有的运行时逻辑。
+DIPU (Device Independent Process Unit) 是由一组抽象设备 runtime 接口,一组框架能力相关的运行时基类/接口,一个针对 DIOPI 标准算子的适配层共同组成的拓展包。用来在训练框架 PyTorch 上接入 DIOPI 算子库,实现 Eager 模式的推理和训练。其能够在编译时,决定抽象设备被映射的方式;并使用统一的运行时,减少在多硬件上适配训练框架的成本。DIPU 既可以基于统一的设备运行时来屏蔽厂商的实际设备;也可以基于统一的框架相关的运行时基类,由厂商自行实现特有的运行时逻辑。
### DICP
diff --git a/dicp/MANIFEST.in b/dicp/MANIFEST.in
new file mode 100644
index 0000000000..9ce7b59fbf
--- /dev/null
+++ b/dicp/MANIFEST.in
@@ -0,0 +1,2 @@
+recursive-include dicp/vendor/TopsGraph/codegen *
+recursive-include dicp/vendor/AscendGraph/codegen *
\ No newline at end of file
diff --git a/dicp/README.md b/dicp/README.md
new file mode 100644
index 0000000000..db01a09b6c
--- /dev/null
+++ b/dicp/README.md
@@ -0,0 +1,85 @@
+# DICP
+
+The Device-Independent Compile Protocol (DICP) defines a unified computation description (an intermediate representation, IR). Computation in a deep learning model is captured as a graph and expressed in this IR, and graph-optimization techniques then automatically generate device code for AI chips, improving both development efficiency and execution performance. An IR is a program representation that sits between the source language and the target language; it greatly improves the extensibility of the compilation pipeline while limiting the disruption that optimization passes cause to the front end and back end. A multi-level IR spans several levels of representation from the application down to the chip, with each level addressing problems at a different scale.
+
+The core capabilities of DICP are:
+
+1. Performance gains from the compilation route, unlocking as much of the chip's capability as possible in large-model scenarios.
+2. A general bridge between training frameworks and domestic hardware chips that supports multiple front ends and back ends and is easy to use.
+3. An easy-to-use, efficient, one-stop compilation adaptation workflow that flexibly supports the features of domestic graph compilers and speeds up chip adaptation.
+
+The figure below shows where DICP sits in the compilation pipeline:
+
+*(figure: DICP's position in the compilation pipeline)*
+
+1. The training framework converts the user's model code into a unified intermediate representation through its graph-capture module. This IR is completely chip-independent, so the compile-protocol stage that follows has to establish the link to the back-end chip for the integration to be efficient.
+2. The compile protocol bridges the framework and the chip compiler: it covers hardware-specific graph partitioning, the mapping between the unified IR and the operators supported by the chip, and data-format conversion.
+3. Once the compile protocol has absorbed the chip's characteristics, the code-generation module emits the final code, the chip's compiler builds it into a binary, and the framework invokes that binary.
+
+## Integrating domestic hardware with PyTorch 2 through DICP
+
+With DICP, domestic hardware can quickly plug into the PyTorch 2 compilation route, whose TorchDynamo component greatly reduces runtime overhead on such hardware.
+The following features are implemented for domestic hardware (a usage sketch follows the list):
+
+- Flexible support for the features of domestic graph compilers
+- Support for multiple domestic hardware data formats
+- Support for dynamic shapes
+
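+In practice the entry point is `torch.compile`. A minimal sketch, assuming the `ascendgraph` / `topsgraph` backend names used by the test scripts in this repository are registered with TorchDynamo:
+
+```python
+import torch
+import torch_dipu  # DIPU runtime; device and environment setup follow the Demo section below
+
+def matmul_silu(x, y):
+    return torch.nn.functional.silu(x @ y)
+
+# "ascendgraph" targets Huawei Ascend, "topsgraph" targets Enflame Tops
+compiled = torch.compile(matmul_silu, backend="ascendgraph", dynamic=False)
+```
+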
+### How DICP runs
+
+The runtime flow of DICP is shown in the figure below:
+
+*(figure: DICP runtime flow)*
+
+In this flow:
+
+1. **Operator mapping**: resolves the semantic differences between framework-level operators and the back-end graph compiler's operators, covering both 1-to-1 and 1-to-many conversions (a sketch follows this list).
+2. **Shape & dtype inference**: infers shapes and data types to complete the information on the whole static graph, so that the code-generation module can emit code later.
+3. **Subgraph rewriting**: fuses several small operators into one or more operators better suited to the graph compiler, maximizing computational efficiency together with the back-end graph compiler.
+4. **Data-format adjustment**: adjusts the input/output data formats of specific operators according to the characteristics of the back-end chip and its graph compiler, so that the chip's performance is fully exploited.
+
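+As a concrete illustration of step 1, an operator mapping in the Ascend conversion layer looks roughly like the following excerpt, adapted from the `dicp/vendor/AscendGraph/conversion.py` changes in this patch (shown out of context):
+
+```python
+@register_conversion(aten.ne.Scalar)
+def ne(self, a, b):
+    # promote the scalar rhs to a's dtype, then map aten.ne onto the Ascend NotEqual operator
+    a, b = self.binary_cmp_cast_input(a, b)
+    return self.get_proxy(ascend_op.NotEqual, (a, b))
+```
+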
+### Directory layout
+
+- `dicp/dynamo_bridge`: backend-independent integration code, which
+  1. receives the FX Graph handed down from AOTAutograd,
+  2. launches each vendor's IR conversion and optimization, and
+  3. drives CodeGen and the JIT caching logic.
+- `dicp/vendor`: each vendor's IR definitions, the conversion from ATen IR to the vendor IR, optimizations on the vendor IR, and the final code-generation module.
+- `test`: model tests and op tests
+
+### Demo
+
+#### Install DICP
+
+```bash
+cd /path_to_dicp
+pip install .
+```
+
+#### Run llama7B forward inference on Huawei Ascend 910
+
+```bash
+export DIPU_MOCK_CUDA=false
+export DICP_TOPS_DIPU=True
+export TEST_DIR=/path_to_dicp/test/
+export LLAMA_MODEL_DIR=/path_to_llama_model
+bash /path_to_dicp/test/model/run_test_model.sh llama ascendgraph false
+```
+
+#### Train resnet50 on Enflame T20
+
+```bash
+export DIPU_MOCK_CUDA=false
+export DICP_TOPS_DIPU=True
+export TEST_DIR=/path_to_dicp/test/
+bash /path_to_dicp/test/model/run_test_model.sh resnet50 topsgraph false
+```
diff --git a/dicp/dicp/dynamo_bridge/op_transformer.py b/dicp/dicp/dynamo_bridge/op_transformer.py
index 1b855f6a88..85121b9f36 100644
--- a/dicp/dicp/dynamo_bridge/op_transformer.py
+++ b/dicp/dicp/dynamo_bridge/op_transformer.py
@@ -55,6 +55,9 @@ def get_proxy(self, target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] =
'call_function', target.get_singleton(), args, kwargs)
return proxy
+ def get_proxy_from_node(self, node):
+ return self.tracer.proxy(node)
+
def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
if target in self._conversions:
converted_target = self._conversions[target]
diff --git a/dicp/dicp/dynamo_bridge/operator.py b/dicp/dicp/dynamo_bridge/operator.py
index 213b8bc31c..99411e9491 100644
--- a/dicp/dicp/dynamo_bridge/operator.py
+++ b/dicp/dicp/dynamo_bridge/operator.py
@@ -95,7 +95,7 @@ def make_cpu(x):
except Exception as e:
log = logging.getLogger(__name__)
if hasattr(self, "infer_result"):
- log.warning(
+ log.debug(
str(self.__name__) + ": infer shape and dtype failed,ignore"
)
elif hasattr(self, "torch_op"):
diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py
index 36a18ad3a7..09c8f302c0 100644
--- a/dicp/dicp/vendor/AscendGraph/ascend_op.py
+++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -24,7 +24,7 @@ def negative_in_shape(shape):
class Adds(Operator):
def __init__(self):
- super().__init__("adds")
+ super().__init__("Adds")
def infer_result(self, x1, x2):
return common_binary_op_infer(x1, x2)
@@ -32,7 +32,7 @@ def infer_result(self, x1, x2):
class Add(Operator):
def __init__(self):
- super().__init__("add")
+ super().__init__("Add")
def infer_result(self, x1, x2):
return common_binary_op_infer(x1, x2)
@@ -42,11 +42,54 @@ class BroadcastTo(Operator):
def __init__(self):
super().__init__("BroadcastTo")
+ def infer_result(self, x, shape):
+ x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+ shape, shape_shape, shape_dim, shape_dtype = get_fake_tensor_meta_val(shape)
+ shape = shape_shape
+ dims = zip(reversed(shape), reversed(x_shape))
+
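+ # Walk the shapes right-to-left: a target dim of -1 inherits the input dim; otherwise the input dim must equal the target or be 1 (broadcast).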
+ for i, t in enumerate(dims):
+ tar_dim, cur_dim = t
+ if tar_dim == -1:
+ shape[-(i + 1)] = cur_dim
+ continue
+ elif cur_dim == 1:
+ continue
+ assert cur_dim == tar_dim, self.__class__.__name__ + ": shape mismatch!"
+ # broadcast keep get_memory_format
+ return torch.empty(shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
class Range(Operator):
def __init__(self):
super().__init__("Range")
+ def infer_result(self, start, limit=None, delta=None):
+ start, start_dtype, _ = get_op_const_arg_kwarg(start)
+ limit, limit_dtype, _ = get_op_const_arg_kwarg(limit)
+ delta, delta_dtype, _ = get_op_const_arg_kwarg(delta)
+
+ assert start is not None, (
+ self.__class__.__name__ + ": input 'start' can't be None!"
+ )
+ if limit is None:
+ limit = start
+ start = 0.0
+ delta = float(delta) if delta is not None else 1.0
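+ # Mirror torch.arange semantics: a single argument means the range [0, limit), and the output length is ceil((limit - start) / delta).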
+ assert not close2(delta, 0), self.__class__.__name__ + ": step must be nonzero"
+ assert (delta > 0 and limit > start) or (delta < 0 and limit < start), (
+ self.__class__.__name__
+ + ": bounds are inconsistent with the sign of step"
+ )
+
+ seq_len = math.ceil((limit - start) / delta)
+
+ return torch.empty(
+ [seq_len],
+ dtype=get_cast_dtype(start_dtype, limit_dtype),
+ memory_format=torch.contiguous_format,
+ )
+
class Cumsum(Operator):
def __init__(self):
@@ -107,6 +150,12 @@ def infer_result(self, x1, x2):
return common_binary_op_infer(x1, x2)
+class Muls(Operator):
+ def __init__(self):
+ super().__init__("Muls")
+ self.torch_op = aten.mul
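+ # Muls multiplies a tensor by a host-side scalar passed as the "value" attribute (see the Muls codegen later in this patch); Mul takes two tensor inputs.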
+
+
class Div(Operator):
def __init__(self):
super().__init__("Div")
@@ -236,7 +285,6 @@ def infer_result(self, x, dim=None):
+ ": can only squeeze a dimension that is 1!"
)
shape.pop(i)
-
x_memory_format = get_memory_format(x)
if len(shape) < 4:
x_memory_format = torch.contiguous_format
@@ -247,6 +295,15 @@ class Pack(Operator):
def __init__(self):
super().__init__("Pack")
+ def infer_result(self, x, dim):
+ x0, x0_shape, x0_dim, x0_dtype = get_fake_tensor_meta_val(x[0])
+ dim = (dim + x0_dim + 1) % (x0_dim + 1)
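+ # dim is now normalized into [0, x0_dim]; like torch.stack, Pack inserts a new axis of length len(x) at that position.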
+ out_shape = list(x0_shape)
+ out_shape.insert(dim, len(x))
+ return torch.empty(
+ out_shape, dtype=x0_dtype, memory_format=get_memory_format(x0)
+ )
+
class Permute(Operator):
def __init__(self):
@@ -257,11 +314,56 @@ class Expand(Operator):
def __init__(self):
super().__init__("Expand")
+ # TODO: unfinished, needs further testing
+ def infer_result(self, x, shape_tensor):
+ x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x, True)
+ (
+ shape_tensor,
+ shape_tensor_shape,
+ shape_tensor_dim,
+ shape_tensor_dtype,
+ ) = get_fake_tensor_meta_val(shape_tensor, True)
+ assert x_dim > 0, self.__class__.__name__ + ": scalar"
+ shape = list(shape_tensor_shape)
+ dims = zip(shape, x_shape)
+ x_stride = list(x.stride())
+ for i, t in enumerate(dims):
+ tar_dim, cur_dim = t
+ if tar_dim != cur_dim:
+ x_stride[i] = 0
+ if tar_dim == -1:
+ shape[-(i + 1)] = cur_dim
+ continue
+ elif cur_dim == 1:
+ continue
+ assert cur_dim == tar_dim, self.__class__.__name__ + ": shape mismatch!"
+ # broadcast keep get_memory_format
+ return torch.empty(shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
class ExpandD(Operator):
def __init__(self):
super().__init__("ExpandD")
+ def infer_result(self, x, shape):
+ x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x, True)
+ assert x_dim > 0, self.__class__.__name__ + ": scalar"
+ dims = zip(shape, x_shape)
+ x_stride = list(x.stride())
+ for i, t in enumerate(dims):
+ tar_dim, cur_dim = t
+ if tar_dim != cur_dim:
+ x_stride[i] = 0
+ if tar_dim == -1:
+ shape[-(i + 1)] = cur_dim
+ continue
+ elif cur_dim == 1:
+ continue
+ assert cur_dim == tar_dim, self.__class__.__name__ + ": shape mismatch!"
+ res = torch.empty(shape, dtype=x_dtype, memory_format=get_memory_format(x))
+ res = torch.as_strided(res, shape, x_stride, res.storage_offset())
+ return res
+
class Sort(Operator):
def __init__(self):
@@ -277,10 +379,16 @@ class ScatterElements(Operator):
def __init__(self):
super().__init__("ScatterElements")
+ def infer_result(self, var, index, value, dim):
+ return common_unary_op_infer(var)
+
-class ReduceMean(Operator):
+class ReduceMeanD(Operator):
def __init__(self):
- super().__init__("ReduceMean")
+ super().__init__("ReduceMeanD")
+
+ def infer_result(self, x, axes, keepdim=False, noop_with_empty_axes=True):
+ return reduce_op_infer(x, axes, keepdim)
class ReduceStdV2Update(Operator):
@@ -300,7 +408,7 @@ class Const(Operator):
def __init__(self):
super().__init__("Const")
- def infer_result(self, new_args, kwargs):
+ def infer_result(self, *new_args, **kwargs):
return new_args, kwargs
@@ -318,14 +426,11 @@ def __init__(self):
def infer_result(self, base, expo):
base, base_shape, base_dim, base_dtype = get_fake_tensor_meta_val(base)
-
if isinstance(expo, Tuple): # Const
- expo, expo_shape = get_op_const_arg_kwarg(expo)
+ expo, _, expo_shape = get_op_const_arg_kwarg(expo)
expo_dtype = type(expo[0]) if len(expo) > 0 else base_dtype
else: # fake Tensor
- expo, expo_shape, expo_dim, expo_dtype = get_fake_tensor_meta_val(
- expo
- )
+ expo, expo_shape, expo_dim, expo_dtype = get_fake_tensor_meta_val(expo)
out_shape = get_broadcast_res_two_shape(base_shape, expo_shape)
dtype = get_cast_dtype(base_dtype, expo_dtype)
@@ -337,7 +442,7 @@ class Select(Operator):
def __init__(self):
super().__init__("Select")
- def infer_result(self, x1, x2, condition):
+ def infer_result(self, condition, x1, x2):
x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1)
x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2)
_, c_shape, _, _ = get_fake_tensor_meta_val(condition)
@@ -373,6 +478,14 @@ def infer_result(self, x1, x2):
return common_binary_op_infer(x1, x2, torch.bool)
+class NotEqual(Operator):
+ def __init__(self):
+ super().__init__("NotEqual")
+
+ def infer_result(self, x1, x2):
+ return common_binary_op_infer(x1, x2, torch.bool)
+
+
class Conv2D(Operator):
def __init__(self):
super().__init__("Conv2D")
@@ -409,9 +522,18 @@ def __init__(self):
super().__init__("Identity")
def infer_result(self, x, idx=None):
- x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
- out_shape = list(x_shape[idx]) if idx is not None else list(x_shape)
- return torch.empty(out_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+ x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
+ out_dtype = x_dtype
+ if x_dtype == torch.complex64: # for complex64
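+ # complex64 is modeled as a float32 pair: idx 0 / 1 selects the real / imaginary component, returned as float32 with a trailing dim of 1.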
+ out_shape = list(x_shape)
+ if idx == 0 or idx == 1:
+ out_dtype = torch.float32
+ out_shape.append(1)
+ else:
+ out_shape = [x_shape[idx]] if idx is not None else list(x_shape)
+ return torch.empty(
+ out_shape, dtype=out_dtype, memory_format=get_memory_format(x)
+ )
class IdentityInp(Operator):
@@ -439,6 +561,18 @@ class Empty(Operator):
def __init__(self):
super().__init__("Empty")
+ def infer_result(
+ self, shape, dtype, layout, device, memory_format=torch.contiguous_format
+ ):
+ shape, _, _ = get_op_const_arg_kwarg(shape)
+ return torch.empty(
+ shape,
+ dtype=dtype,
+ layout=layout,
+ device=device,
+ memory_format=memory_format,
+ )
+
class GatherV2(Operator):
def __init__(self):
@@ -452,15 +586,35 @@ def infer_result(self, x, index, axis):
return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+class GatherElements(Operator):
+ def __init__(self):
+ super().__init__("GatherElements")
+
+ def infer_result(self, x, index, axis):
+ x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+ idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index)
+ return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
+
class OnesLike(Operator):
def __init__(self):
super().__init__("OnesLike")
+ def infer_result(self, x):
+ return common_unary_op_infer(x)
+
class Fill(Operator):
def __init__(self):
super().__init__("Fill")
+ def infer_result(self, dims, value):
+ _, value_dtype, _ = get_op_const_arg_kwarg(value)
+ shape, _, _ = get_op_const_arg_kwarg(dims)
+ return torch.empty(
+ shape, dtype=value_dtype, memory_format=torch.contiguous_format
+ )
+
class Conv2DBackpropInput(Operator):
def __init__(self):
@@ -542,11 +696,34 @@ class SplitD(Operator):
def __init__(self):
super().__init__("SplitD")
+ def infer_result(self, x, split_dim, num_split, y, from_view_complex=False):
+ assert from_view_complex == True, (
+ self.__class__.__name__
+ + ": currently available only in op view_as_complex!"
+ )
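+ # Only the view_as_complex pattern is modeled here: the trailing axis of size 2 is dropped and the result is treated as complex64.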
+ x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+ split_dim = (split_dim + x_dim) % x_dim
+ out_shape = list(x_shape)
+ del out_shape[-1]
+ return torch.empty(
+ out_shape,
+ dtype=torch.complex64 if from_view_complex else x_dtype,
+ memory_format=get_memory_format(x),
+ )
+
class Slice(Operator):
def __init__(self):
super().__init__("Slice")
+ def infer_result(self, x, offset, size):
+ x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
+ new_shape, _, _ = get_op_const_arg_kwarg(size)
+ offset, _, _ = get_op_const_arg_kwarg(offset)
+ _, storage_offset = cal_stride_offset(new_shape, offset, x)
+ res = torch.as_strided(x, new_shape, x.stride(), storage_offset)
+ return res
+
class ConcatD(Operator):
def __init__(self):
@@ -570,33 +747,32 @@ class MaskedFill(Operator):
def __init__(self):
super().__init__("MaskedFill")
+ def infer_result(self, x, mask, value):
+ x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
+ _, _, _, value_dtype = get_fake_tensor_meta_val(value)
+ _, mask_shape, _, _ = get_fake_tensor_meta_val(mask)
+ return torch.empty(
+ get_broadcast_res_two_shape(x_shape, mask_shape),
+ dtype=get_cast_dtype(x_dtype, value_dtype),
+ memory_format=get_memory_format(x),
+ )
+
class Reshape(Operator):
def __init__(self):
super().__init__("Reshape")
- # TODO:conflict in solving stride between "view" and "select"
- def infer_result(self, x, shape_const_op):
- x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
- re_shape, re_dim = get_op_const_arg_kwarg(shape_const_op)
- # check whether stride and storage_offset are manually specified
- # if so, x is from operators like "Slice", and the stride and storage_offset still need to modify here
+ def infer_result(self, x, shape_const_op, ori_op=None, params_passed=None):
+ x, _, _, x_dtype = get_fake_tensor_meta_val(x)
+ re_shape, _, _ = get_op_const_arg_kwarg(shape_const_op)
x_stride = list(x.stride())
- x_shape = list(x_shape)
-
- for i in range(len(x_stride) - 2, -1, -1):
- if x_stride[i + 1] * x_shape[i + 1] != x_stride[i]:
- del x_stride[i + 1]
- del x_shape[i + 1]
- break
- else:
- if len(x_shape) != len(re_shape):
- del x_stride[0]
- del x_shape[0]
-
- x_storage_offset = x.storage_offset()
res = torch.empty(re_shape, dtype=x_dtype, memory_format=get_memory_format(x))
- res = torch.as_strided(res, re_shape, x_stride, x_storage_offset)
+ if ori_op == "Select":
+ assert "sel_dim" in params_passed, (
+ self.__class__.__name__ + ':param "sel_dim" from Select missing!'
+ )
+ del x_stride[params_passed["sel_dim"]]
+ res = torch.as_strided(res, re_shape, x_stride, x.storage_offset())
return res
@@ -633,7 +809,6 @@ def __init__(self):
super().__init__("Shape")
def infer_result(self, x):
- # like Const, we won't use this function, but it should exist as a flag for triggering inference of resinfo
return common_unary_op_infer(x, spec_format=torch.contiguous_format)
@@ -675,6 +850,16 @@ def __init__(self):
super().__init__("DropOutDoMaskV3")
+class MaxPool(Operator):
+ def __init__(self):
+ super().__init__("MaxPool")
+
+
+class PadV3(Operator):
+ def __init__(self):
+ super().__init__("PadV3")
+
+
def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return a, b, c
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index 635fb5161b..c32b8e7974 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -334,34 +334,27 @@ def gen_call_func(self):
call_body.writeline(shape_str)
else:
call_body.writeline('''output_shape = None''')
-
-
+
# add stride & storage_offset info
- out_stride_str = '''out_stride = ['''
- out_storage_offset_str = '''out_storage_offset = ['''
+ out_strides = []
+ out_storage_offsets = []
for elem in self.output_args:
if hasattr(elem, 'meta'):
elem = elem.meta['val']
if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool):
- out_stride_str += '[1],'
- out_storage_offset_str += '0,'
+ out_strides.append('[1]')
+ out_storage_offsets.append('0')
continue
if elem.dim()==0: # temporary solution for sum.default(a) whose result is a scalar(no dim no stride)
- out_stride_str += '[1],'
- out_storage_offset_str += '0,'
+ out_strides.append('[1]')
+ out_storage_offsets.append('0')
continue
stride = list(elem.stride())
- if len(stride) == 0:
- raise RuntimeError("Error handling empty output_stride")
stride = [self.process_sym_name(str(dim)) for dim in stride]
- out_stride_str += '[' + ','.join(map(str, stride)) + '],'
- out_storage_offset_str += str(elem.storage_offset()) + ','
- out_stride_str += extra_stride_str
- out_storage_offset_str += extra_storage_offset_str
- out_stride_str = out_stride_str[:-1] + ']'
- out_storage_offset_str = out_storage_offset_str[:-1] + ']'
- call_body.writeline(out_stride_str)
- call_body.writeline(out_storage_offset_str)
+ out_strides.append(str(stride))
+ out_storage_offsets.append(elem.storage_offset())
+ call_body.writeline(f'out_stride = {out_strides}')
+ call_body.writeline(f'out_storage_offset = {out_storage_offsets}')
call_body.splice("""
import torch_dipu
@@ -369,14 +362,17 @@ def gen_call_func(self):
for idx in range(len(args)):
if isinstance(args[idx], int):
args[idx] = torch.tensor(args[idx], device=dipu_device_str, dtype=torch.int32)
- if isinstance(args[idx], torch.Tensor):
- tmp_arg = args[idx].clone()
- with torch.no_grad():
- args[idx].copy_(tmp_arg)
- del tmp_arg
""", strip=True)
call_body.writeline(f"({','.join(self.args)}) = args")
- call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset)']
+
+ # dealing with modified args passing back
+ allocated_output = {}
+ for item in self.assign_args:
+ input_index = item[1]
+ output_index = self.graph_output_names.index(item[0])
+ allocated_output[output_index] = input_index
+ call_body.writeline(f'allocated_output= {allocated_output}')
+ call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output)']
if precision_check and self.aten_graph is not None:
# import aten graph
@@ -401,10 +397,6 @@ def gen_call_func(self):
call_str.extend([f'del {name}',
f'{name} = int(output_tensor[{i}])'])
- # dealing with modified args passing back
- output_convert = [f'args[{name[1]}].copy_({name[0]})' for name in self.assign_args]
- call_str.extend(output_convert)
-
if precision_check:
for i, name in enumerate(self.py_output_names):
if name != 'None' and name not in self.args and name not in self.symint_outputs:
@@ -712,6 +704,13 @@ def Mul(name, x, y):
op.set_input("x2", y)
return op.to_node()
+ @staticmethod
+ def Muls(name, x, y):
+ op = OP(name, "Muls")
+ op.set_input("x", x)
+ op.set_attr_float("value", float(y))
+ return op.to_node()
+
@staticmethod
def IdentityN(name, *args, **kwargs):
input_names = []
@@ -730,14 +729,14 @@ def IdentityN(name, *args, **kwargs):
return id_op.to_node()
@staticmethod
- def adds(name, x, y):
+ def Adds(name, x, y):
adds_op = OP(name, "Adds")
adds_op.set_input("x", x)
adds_op.set_attr_float("value", float(y))
return adds_op.to_node()
@staticmethod
- def add(name, x, y):
+ def Add(name, x, y):
add_op = OP(name, "Add")
add_op.set_input("x1", x)
add_op.set_input("x2", y)
@@ -770,12 +769,6 @@ def Transpose(name, input, perm):
transpose_op.set_input("perm", perm)
return transpose_op.to_node()
- @staticmethod
- def reciprocal(name, x):
- op = OP(name, "Reciprocal")
- op.set_input("x", x)
- return op.to_node()
-
@staticmethod
def Sqrt(name, x):
op = OP(name, "Sqrt")
@@ -826,11 +819,12 @@ def Conv2D(name, input, weight, stride, padding,
return op.to_node()
@staticmethod
- def ReduceMean(name, x, axes, keepdim=False):
- mean_op = OP(name, "ReduceMean")
+ def ReduceMeanD(name, x, axes, keepdim=False, noop_with_empty_axes=False):
+ mean_op = OP(name, "ReduceMeanD")
mean_op.set_input("x", x)
- mean_op.set_input("axes", axes)
+ mean_op.set_attr_list_int("axes", axes)
mean_op.set_attr_bool("keep_dims", keepdim)
+ mean_op.set_attr_bool("noop_with_empty_axes", noop_with_empty_axes)
return mean_op.to_node()
@staticmethod
@@ -1043,11 +1037,12 @@ def BroadcastTo(name, x, shape):
return broadcast_op.to_node()
@staticmethod
- def Empty(name, shape, dtype, layout=torch.strided, device='cpu'):
+ def Empty(name, shape, dtype, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format):
dtype = get_ascend_dtype_num(get_ascend_dtype(dtype))
op = OP(name, "Empty")
op.set_input("shape", shape)
op.set_attr_int("dtype", dtype)
+ op.set_attr_bool("init", False)
return op.to_node()
@staticmethod
@@ -1196,7 +1191,7 @@ def ret_triple(name, in1, in2, in3):
return op.to_node()
@staticmethod
- def Range(name, end, start, step):
+ def Range(name, start, end, step):
op = OP(name, "Range")
op.set_input("start", start)
op.set_input("limit", end)
@@ -1210,6 +1205,13 @@ def Equal(name, a, b):
eq_op.set_input("x2", b)
return eq_op.to_node()
+ @staticmethod
+ def NotEqual(name, a, b):
+ eq_op = OP(name, "NotEqual")
+ eq_op.set_input("x1", a)
+ eq_op.set_input("x2", b)
+ return eq_op.to_node()
+
@staticmethod
def Cumsum(name, x, dim):
op = OP(name, "Cumsum")
@@ -1328,7 +1330,7 @@ def ThresholdGradV2D(name, grad_output, x, threshold):
return op.to_node()
@staticmethod
- def SplitD(name, x, dim, num_split, y):
+ def SplitD(name, x, dim, num_split, y, from_view_complex):
split_op = OP(name, "SplitD")
split_op.set_input("x", x)
split_op.set_attr_int("split_dim", dim)
@@ -1365,7 +1367,7 @@ def ConcatD(name, x, dim):
return op.to_node()
@staticmethod
- def Reshape(name, x, shape):
+ def Reshape(name, x, shape, ori_op=None, params_passed=None):
op = OP(name, "Reshape")
op.set_input("x", x)
op.set_input("shape", shape)
@@ -1464,3 +1466,11 @@ def DropOutDoMaskV3(name, x, mask, keep_prob):
op.set_input("mask", mask)
op.set_input("keep_prob", keep_prob)
return op.to_node()
+
+ @staticmethod
+ def GatherElements(name, x, index, dim):
+ op = OP(name, "GatherElements")
+ op.set_input("x", x)
+ op.set_input("index", index)
+ op.set_attr_int("dim", dim)
+ return op.to_node()
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg b/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg
new file mode 100644
index 0000000000..71834659c8
--- /dev/null
+++ b/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg
@@ -0,0 +1,10 @@
+{
+ "Switch":{
+ "GraphFusion":{
+ "ALL":"on"
+ },
+ "UBFusion":{
+ "ALL":"on"
+ }
+ }
+}
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp b/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp
index fbced63f60..99f422dcaa 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp
+++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp
@@ -1,7 +1,8 @@
#include "graph_utils.h"
static void compile(const std::string& graph_path,
- const std::string& graph_json_file) {
+ const std::string& graph_json_file,
+ const std::string& fusion_switch_file) {
std::string graph_name = "BuildGraph";
Graph graph(graph_name.c_str());
std::ifstream f(graph_json_file);
@@ -18,13 +19,14 @@ static void compile(const std::string& graph_path,
}
}
- AclgraphBuilder builder;
+ AclgraphBuilder builder{fusion_switch_file};
builder.saveGraph(graph_path, graph, options);
}
int main(int argc, char* argv[]) {
std::string graph_path{argv[1]};
std::string graph_json_file{argv[2]};
- compile(graph_path, graph_json_file);
+ std::string fusion_switch_file{argv[3]};
+ compile(graph_path, graph_json_file, fusion_switch_file);
return 0;
}
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h
index 380670146f..2cbacf3bcb 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h
+++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h
@@ -12,6 +12,7 @@
#include
#include
+#include "acl/acl.h"
#include "all_ops.h"
#include "ascend_string.h"
#include "ge_api.h"
@@ -81,12 +82,14 @@ ge::Operator genInput(const std::string op_name,
class AclgraphBuilder {
public:
- explicit AclgraphBuilder() {
+ explicit AclgraphBuilder(const std::string& fusion_switch_file)
+ : _fusion_switch_file(fusion_switch_file) {
// 1. system init
- std::string kSocVersion = "Ascend910ProB";
+ auto kSocVersion = aclrtGetSocName();
std::map<AscendString, AscendString> global_options = {
- {AscendString(ge::ir_option::SOC_VERSION),
- AscendString(kSocVersion.c_str())},
+ {AscendString(ge::ir_option::SOC_VERSION), AscendString(kSocVersion)},
+ {AscendString(ge::ir_option::FUSION_SWITCH_FILE),
+ AscendString(_fusion_switch_file.c_str())},
{AscendString(ge::ir_option::PRECISION_MODE), "allow_fp32_to_fp16"},
};
auto status = aclgrphBuildInitialize(global_options);
@@ -122,6 +125,9 @@ class AclgraphBuilder {
aclgrphBuildFinalize();
std::cout << "aclgrphBuildFinalize success!" << std::endl;
}
+
+ private:
+ std::string _fusion_switch_file;
};
ge::Format get_ascend_format(const std::string& format) {
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py
index b57516e386..5b0a5ea2c0 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py
@@ -1,10 +1,11 @@
-import acl
+import atexit
import os
+
+import acl
import numpy as np
import torch
-import atexit
import torch_dipu
-
+from torch.profiler import record_function
dipu_device_str = torch_dipu.dipu.device.__diputype__
@@ -53,6 +54,16 @@
ACL_MDL_OUTPUTQ_ADDR_PTR = 12
ACL_MDL_WORKSPACE_MEM_OPTIMIZE = 13
+ACL_DDR_MEM = 0
+ACL_HBM_MEM = 1
+ACL_DDR_MEM_HUGE = 2
+ACL_DDR_MEM_NORMAL = 3
+ACL_HBM_MEM_HUGE = 4
+ACL_HBM_MEM_NORMAL = 5
+ACL_DDR_MEM_P2P_HUGE = 6
+ACL_DDR_MEM_P2P_NORMAL = 7
+ACL_HBM_MEM_P2P_HUGE = 8
+ACL_HBM_MEM_P2P_NORMAL = 9
def get_np_dtype(dtype):
if dtype == ACL_FLOAT:
@@ -110,7 +121,7 @@ def __init__(self):
def init_work_weight_ptr(self):
if self.work_ptr is None:
- self.work_size = 15 * 1024 * 1024 * 1024
+ self.work_size = 18 * 1024 * 1024 * 1024
self.work_ptr, ret = acl.rt.malloc(self.work_size,
ACL_MEM_MALLOC_HUGE_FIRST)
check_ret("acl.rt.malloc", ret)
@@ -124,6 +135,7 @@ def release_memory(self):
memory_pool = MemoryPool()
+zero_tensor = torch.randn(1).to(dipu_device_str)
class AscendExecutor(object):
@@ -172,12 +184,17 @@ def load_model(self):
if work_size == 0:
work_size = memory_pool.work_size
elif work_size > memory_pool.work_size:
- memory_pool.work_size = work_size
- memory_pool.release_memory()
- print("Adjust memory pool allocation.")
- memory_pool.work_ptr, ret = acl.rt.malloc(work_size,
- ACL_MEM_MALLOC_HUGE_FIRST)
- check_ret("acl.rt.malloc", ret)
+ free, _, ret = acl.rt.get_mem_info(ACL_HBM_MEM)
+ check_ret("acl.rt.get_mem_info", ret)
+ # If free < work_size, device memory is insufficient.
+ # Just ignore and continue; it may still work.
+ if free > work_size:
+ memory_pool.work_size = work_size
+ memory_pool.release_memory()
+ print("Adjust memory pool allocation.")
+ memory_pool.work_ptr, ret = acl.rt.malloc(work_size,
+ ACL_MEM_MALLOC_HUGE_FIRST)
+ check_ret("acl.rt.malloc", ret)
self.weight_ptr, ret = acl.rt.malloc(weight_size,
ACL_MEM_MALLOC_HUGE_FIRST)
@@ -203,7 +220,7 @@ def load_model(self):
check_ret("set_config_opt", ret)
ret = acl.mdl.set_config_opt(
- config_handle, ACL_MDL_WORKSPACE_SIZET, work_size)
+ config_handle, ACL_MDL_WORKSPACE_SIZET, memory_pool.work_size)
check_ret("set_config_opt", ret)
ret = acl.mdl.set_config_opt(
@@ -252,9 +269,9 @@ def init_resource(self):
print("init resource success")
+ @record_function('load_and_run_prepare_input')
def _prepare_input(self, images, dims):
assert self.num_inputs == len(images)
- zero_tensor = torch.randn(1).to(dipu_device_str)
for i in range(self.num_inputs):
buffer_size = self.input_size[i]
if dims is not None and i in dims.keys():
@@ -283,10 +300,14 @@ def _prepare_input(self, images, dims):
check_ret("acl.mdl.set_dataset_tensor_desc", ret)
assert (dataset == self.input_dataset)
- def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset):
+ @record_function('load_and_run_prepare_output')
+ def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output):
for i in range(self.num_outputs):
- item = torch.empty(
- self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str)
+ if allocated_output and i in allocated_output.keys():
+ item = allocated_output[i]
+ else:
+ item = torch.empty(
+ self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str)
# TODO! add case judgement for stride info
# item = item.as_strided(
# self.output_dims[i], out_stride[i], out_storage_offset[i])
@@ -295,7 +316,8 @@ def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_o
self.output_data_buffers[i], item.data_ptr(), self.output_size[i])
check_ret("acl.update_data_buffer", ret)
- def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_storage_offset):
+ @record_function('load_and_run_prepare_dynamic_output')
+ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output):
for i in range(self.num_outputs):
tot_size = 1
for elem in output_shape[i]:
@@ -304,8 +326,11 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s
tot_size *= acl.data_type_size(dtype)
self.output_dims[i] = output_shape[i]
self.output_size[i] = tot_size
- item = torch.empty(
- self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str)
+ if allocated_output and i in allocated_output.keys():
+ item = allocated_output[i]
+ else:
+ item = torch.empty(
+ self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str)
# TODO! add case judgement for stride info
# item = item.as_strided(
# self.output_dims[i], out_stride[i], out_storage_offset[i])
@@ -315,20 +340,31 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s
self.output_data_buffers[i], item.data_ptr(), self.output_size[i])
check_ret("acl.update_data_buffer", ret)
- def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None):
+ @record_function('load_and_run_run')
+ def run(self, images, dims=None, output_shape=None,
+ out_stride=None, out_storage_offset=None,
+ allocated_output=None):
assert len(images) > 0
input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor)
and x.device.type != dipu_device_str else x for x in images]
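+ # allocated_output maps output slots to input positions; reuse those input tensors as output buffers so mutated args are written back in place.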
+ allocated_output_tensor = None
+ if allocated_output:
+ allocated_output_tensor = {}
+ for output_index, input_index in allocated_output.items():
+ allocated_output_tensor[output_index] = input[input_index]
self._prepare_input(input, dims)
output = []
if output_shape:
- self._prepare_dynamic_output(output, output_shape, out_stride, out_storage_offset)
+ self._prepare_dynamic_output(
+ output, output_shape, out_stride, out_storage_offset, allocated_output_tensor)
else:
- self._prepare_output(output, output_shape, out_stride, out_storage_offset)
+ self._prepare_output(
+ output, output_shape, out_stride, out_storage_offset, allocated_output_tensor)
self.forward()
self._destroy_databuffer()
return output
+ @record_function('load_and_run_forward')
def forward(self):
ret = acl.mdl.execute(self.model_id,
self.input_dataset,
@@ -348,8 +384,8 @@ def __init__(self, device_id, model_path) -> None:
self.exe = AscendExecutor(device_id, model_path)
def run(self, images, dims=None, output_shape=None,
- out_stride=None, out_storage_offset=None):
- return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset)
+ out_stride=None, out_storage_offset=None, allocated_output=None):
+ return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output)
def cleanup(self):
if hasattr(self, 'exe'):
diff --git a/dicp/dicp/vendor/AscendGraph/compile_job.py b/dicp/dicp/vendor/AscendGraph/compile_job.py
index 6b3b2b8228..93b70dca43 100644
--- a/dicp/dicp/vendor/AscendGraph/compile_job.py
+++ b/dicp/dicp/vendor/AscendGraph/compile_job.py
@@ -28,12 +28,14 @@ def __init__(self, source_code) -> None:
graph_util_path = load_and_run.__file__.replace('/load_and_run.py', '')
source_path = graph_util_path + '/graph_compile.cpp'
json_util_path = graph_util_path + '/nlohmann'
+ self.fusion_switch_file = graph_util_path + '/fusion_switch.cfg'
self._cmd = ['/usr/bin/c++',
'-D_GLIBCXX_USE_CXX11_ABI=0',
'-fPIC',
'-std=c++11',
'-O3',
'-Wall',
+ '-I/usr/local/Ascend/ascend-toolkit/latest/include',
'-I/usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_proto/inc',
'-I/usr/local/Ascend/ascend-toolkit/latest/include/graph',
'-I/usr/local/Ascend/ascend-toolkit/latest/include/ge',
@@ -46,10 +48,10 @@ def __init__(self, source_code) -> None:
'-lge_runner',
source_path,
'-o' + self._lib_path,
- '-Wl,-rpath,/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub',
'/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub/libgraph.so',
'/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub/libge_runner.so',
- '/usr/local/Ascend/ascend-toolkit/latest/lib64/libgraph_base.so']
+ '/usr/local/Ascend/ascend-toolkit/latest/lib64/libgraph_base.so',
+ '/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64/stub/libascendcl.so',]
def _compile(self):
if not os.path.exists(self._lib_path):
@@ -66,7 +68,7 @@ def get_key(self):
def build_graph(self, output_path, graph_path):
self._compile()
- cmd = [self._lib_path, output_path, graph_path]
+ cmd = [self._lib_path, output_path, graph_path, self.fusion_switch_file]
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index d4f850ced7..bc43ac02dd 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -129,18 +129,16 @@ def get_param_proxy(self, param, type, target_shape):
param = param if isinstance(param, list) else [param]
param = self.get_proxy(
ascend_op.Const, (param, type, [len(param)]))
- shape_op = self.get_shape_proxy(target_shape)
- param = self.get_proxy(ascend_op.BroadcastTo, (param, shape_op))
return param
def mul_scalar(self, x, y):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
- const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
- y_shape = list(x.node.meta['val'].shape)
- y_op = self.get_param_proxy(y, const_dtype, y_shape)
- if out_dtype == torch.float16:
- y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"))
- return self.get_proxy(ascend_op.Mul, (x, y_op))
+ # Muls supports bfloat16, int32, int16, float16, float32, complex32 and complex64.
+ if out_dtype not in [torch.float, torch.float16, torch.int32]:
+ y_shape = list(x.node.meta['val'].shape)
+ y_op = self.get_param_proxy(y, out_dtype, y_shape)
+ return self.get_proxy(ascend_op.Mul, (x, y_op))
+ return self.get_proxy(ascend_op.Muls, (x, y))
def mul_complex64(self, x, y):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
@@ -162,6 +160,16 @@ def mul_complex64(self, x, y):
out = self.get_proxy(ascend_op.IdentityN, (ac_bd, ad_bc))
return out
+ def binary_cmp_cast_input(self, x, y):
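+ # Promote a Python-scalar rhs to a Const proxy matching x's shape; fp16 inputs build the constant in fp32 and cast it back to FLOAT16.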
+ if not isinstance(y, torch.fx.proxy.Proxy):
+ x_dtype = x.node.meta["val"].dtype
+ const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
+ y_shape = list(x.node.meta["val"].shape)
+ y = self.get_param_proxy(y, const_dtype, y_shape)
+ if x_dtype == torch.float16:
+ y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
+ return x, y
+
@register_conversion(torch.ops.aten.mul)
def mul(self, x, y):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
@@ -223,12 +231,8 @@ def _to_copy(self, x, dtype=None, layout=torch.strided, device=None):
@register_conversion(aten.le)
def le(self, a, b):
- if isinstance(b, torch.fx.proxy.Proxy):
- return self.get_proxy(ascend_op.LessEqual, (a, b), {})
- x2 = self.get_proxy(ascend_op.Const, ([b], torch.float32, []))
- if a.node.meta['val'].dtype == torch.float16:
- x2 = self.get_proxy(ascend_op.Cast, (x2, "FLOAT16"), {})
- return self.get_proxy(ascend_op.LessEqual, (a, x2), {})
+ a, b = self.binary_cmp_cast_input(a, b)
+ return self.get_proxy(ascend_op.LessEqual, (a, b), {})
@register_conversion(aten.view_as_real)
def view_as_real(self, x):
@@ -283,10 +287,10 @@ def slice(self, x, dim=0, start=None, end=None, step=1):
x_shape = list(x.node.meta['val'].shape)
y_shape = list(fx_traceback.get_current_meta()['val'].shape)
dim = int(dim)
- start = int(start)
+ start = int(start) if start is not None else 0
start = start if start >= 0 else x_shape[dim] + start
- assert dim >= 0 and dim < len(x_shape)
- assert start >= 0 and start < x_shape[dim]
+ assert dim == -1 or dim >= 0 and dim < len(x_shape)
+ assert start is None or start >= 0 and start < x_shape[dim]
offset = [0] * len(x_shape)
offset[dim] = start
offset = self.get_shape_proxy(offset)
@@ -310,10 +314,10 @@ def NewEmptyStrided(self, x, size, stride, dtype=torch.float32, layout=torch.str
return self.empty_like(x)
@register_conversion(aten.empty)
- def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu'):
+ def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format):
shape_op = self.get_proxy(
ascend_op.Const, (size, torch.int32, [len(size)]))
- return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device))
+ return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, memory_format))
@register_conversion(aten.empty_like.default)
def empty_like(self, x, dtype=torch.float32, layout=torch.strided,
@@ -322,7 +326,8 @@ def empty_like(self, x, dtype=torch.float32, layout=torch.strided,
shape = list(x.node.meta['val'].shape)
shape_op = self.get_proxy(
ascend_op.Const, (shape, torch.int32, [len(shape)]))
- return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device))
+ new_memory_format = x.node.meta['tensor_meta'].memory_format if memory_format is torch.preserve_format else memory_format
+ return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, new_memory_format))
@register_conversion(aten.select.int)
def select(self, x, dim, index):
@@ -345,7 +350,13 @@ def select(self, x, dim, index):
size = self.get_shape_proxy(size)
slice = self.get_proxy(ascend_op.Slice, (x, offset, size))
y_shape = self.get_shape_proxy(y_shape)
- return self.get_proxy(ascend_op.Reshape, (slice, y_shape))
+ Reshape_kw = {
+ "ori_op": "Select",
+ "params_passed": {
+ "sel_dim": dim,
+ },
+ }
+ return self.get_proxy(ascend_op.Reshape, (slice, y_shape), Reshape_kw)
@register_conversion(_operator.add)
def inadd(self, x, y):
@@ -400,7 +411,7 @@ def view(self, x, size):
return self.get_proxy(ascend_op.IdentityN, (real_reshape, imag_reshape))
else:
return self.get_proxy(ascend_op.Reshape, (x, shape))
-
+
@register_conversion(torch.ops.aten.where)
def where(self, condition, x1, x2):
# TODO(tangzhiyi): need to process scalars
@@ -430,7 +441,7 @@ def arange(self, end, start=0, step=1, dtype=None, device='xpu', layout=None, pi
step = self.get_proxy(ascend_op.Const, (step, out_dtype))
elif step.node.meta['val'] != out_dtype:
step = self.get_proxy(ascend_op.Cast, (step, get_ascend_dtype(out_dtype)), {})
- return self.get_proxy(ascend_op.Range, (end, start, step))
+ return self.get_proxy(ascend_op.Range, (start, end, step))
@register_conversion(aten.arange.start)
def arange_start(self, start, end, step=1, dtype=None, device=None, layout=None, pin_memory=False):
@@ -438,21 +449,17 @@ def arange_start(self, start, end, step=1, dtype=None, device=None, layout=None,
@register_conversion([aten.eq, aten.eq.Tensor])
def eq(self, a, b):
- if not isinstance(b, torch.fx.proxy.Proxy):
- assert isinstance(b, int)
- b_shape = list(a.node.meta['val'].shape)
- b = self.get_param_proxy(b, torch.int64, b_shape)
+ a, b = self.binary_cmp_cast_input(a, b)
return self.get_proxy(ascend_op.Equal, (a, b))
+ @register_conversion(aten.ne.Scalar)
+ def ne(self, a, b):
+ a, b = self.binary_cmp_cast_input(a, b)
+ return self.get_proxy(ascend_op.NotEqual, (a, b))
+
@register_conversion([aten.lt.Scalar, aten.lt.Tensor])
def lt(self, x, y):
- if not isinstance(y, torch.fx.proxy.Proxy):
- x_dtype = x.node.meta['val'].dtype
- const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
- y_shape = list(x.node.meta['val'].shape)
- y = self.get_param_proxy(y, const_dtype, y_shape)
- if x_dtype == torch.float16:
- y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
+ x, y = self.binary_cmp_cast_input(x, y)
return self.get_proxy(ascend_op.Less, (x, y))
@register_conversion(aten.masked_fill.Scalar)
@@ -467,7 +474,7 @@ def masked_fill(self, x, mask, value):
value = self.get_proxy(ascend_op.Cast, (value, "FLOAT16"))
return self.get_proxy(ascend_op.MaskedFill, (x, mask, value))
- @register_conversion(torch.ops.aten.scatter.src)
+ @register_conversion([torch.ops.aten.scatter.src, torch.ops.aten.scatter.value])
def scatter(self, var, dim, index, value):
assert isinstance(dim, int)
index_shape = list(index.node.meta['val'].shape)
@@ -531,7 +538,8 @@ def view_as_complex(self, x):
assert x_val.dtype == torch.float32
assert x_shape[-1] == 2
dim = len(x_shape) - 1
- return self.get_proxy(ascend_op.SplitD, (x, dim, 2, 2))
+ splitD_kw = { "from_view_complex": True }
+ return self.get_proxy(ascend_op.SplitD, (x, dim, 2, 2), splitD_kw)
@register_conversion(torch.ops.aten.full.default)
def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
@@ -562,10 +570,10 @@ def sort(self, x, dim=-1, descending=False):
return self.get_proxy(ascend_op.Sort, (x, dim, descending))
@register_conversion(torch.ops.aten.ones.default)
- def ones(self, shape, dtype=torch.int64, device='cpu', pin_memory=False):
+ def ones(self, shape, dtype=torch.float32, layout=torch.strided, device='cpu', pin_memory=False):
shape = self.get_proxy(
ascend_op.Const, (shape, torch.int32, [len(shape)]))
- like = self.get_proxy(ascend_op.Empty, (shape, dtype))
+ like = self.get_proxy(ascend_op.Empty, (shape, dtype, layout, device))
return self.get_proxy(ascend_op.OnesLike, (like,))
@register_conversion(torch.ops.aten.new_ones.default)
@@ -781,16 +789,12 @@ def maximum(self, a, b):
b = self.get_proxy(ascend_op.Cast, (b, "FLOAT16"))
return self.get_proxy(ascend_op.Maximum, (a, b))
- def common_process_scalar(self, x, y):
- x_dtype = x.node.meta['val'].dtype
+ def common_process_scalar(self, y, dtype):
need_cast = False
- if x_dtype == torch.float16:
- x_dtype = torch.float32
+ if dtype == torch.float16:
+ dtype = torch.float32
need_cast = True
- y = self.get_proxy(ascend_op.Const, (y, x_dtype))
- y_shape = list(x.node.meta['val'].shape)
- shape_preprocess = self.get_shape_proxy(y_shape)
- y = self.get_proxy(ascend_op.BroadcastTo, (y, shape_preprocess))
+ y = self.get_proxy(ascend_op.Const, (y, dtype))
if need_cast:
y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
return y
@@ -798,13 +802,13 @@ def common_process_scalar(self, x, y):
@register_conversion(aten.sub)
def sub(self, x, y):
if not isinstance(y, torch.fx.proxy.Proxy):
- y = self.common_process_scalar(x, y)
+ y = self.common_process_scalar(y, x.node.meta['val'].dtype)
return self.get_proxy(ascend_op.Sub, (x, y))
@register_conversion(aten.rsub)
def rsub(self, x, y):
if not isinstance(y, torch.fx.proxy.Proxy):
- y = self.common_process_scalar(x, y)
+ y = self.common_process_scalar(y, x.node.meta['val'].dtype)
return self.get_proxy(ascend_op.Sub, (y, x))
@register_conversion(aten.transpose.int)
@@ -855,15 +859,22 @@ def symsize(self, x, dim):
def mm(self, x, y):
# TODO! MatMul not support fp32 input
# for higher precision in some cases
- out_dtype = fx_traceback.get_current_meta()['val'].dtype
if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
x = self.get_proxy(ascend_op.Unsqueeze, (x, [0]))
y = self.get_proxy(ascend_op.Unsqueeze, (y, [0]))
mm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False))
return self.get_proxy(ascend_op.Squeeze, (mm, [0]))
- else:
- mm = self.get_proxy(ascend_op.MatMul, (x, y, False, False))
- return self.get_proxy(ascend_op.Cast, (mm, get_ascend_dtype(out_dtype)))
+ out_dtype = fx_traceback.get_current_meta()['val'].dtype
+ trans_x = False
+ trans_y = False
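+ # If an operand comes straight from a Permute([1, 0]), feed the un-permuted tensor to MatMul and set its transpose flag instead of materializing the transpose.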
+ if isinstance(x.node.target, ascend_op.Permute) and x.node.args[1] == [1, 0]:
+ x = self.get_proxy_from_node(x.node.args[0])
+ trans_x = True
+ if isinstance(y.node.target, ascend_op.Permute) and y.node.args[1] == [1, 0]:
+ y = self.get_proxy_from_node(y.node.args[0])
+ trans_y = True
+ mm = self.get_proxy(ascend_op.MatMul, (x, y, trans_x, trans_y))
+ return self.get_proxy(ascend_op.Cast, (mm, get_ascend_dtype(out_dtype)))
@register_conversion(aten.bmm.default)
def bmm(self, x, y):
@@ -884,9 +895,9 @@ def addmm(self, c, a, b, beta=1.0, alpha=1.0):
@register_conversion(torch.ops.aten.mean)
def mean(self, x, dims=[], keepdim=False):
- axes = self.get_proxy(
- ascend_op.Const, (dims, torch.int32, [] if len(dims) == 0 else [len(dims)]))
- return self.get_proxy(ascend_op.ReduceMean, (x, axes, keepdim))
+ if not isinstance(dims, list):
+ dims = [dims]
+ return self.get_proxy(ascend_op.ReduceMeanD, (x, dims, keepdim, False))
@register_conversion(torch.ops.aten.cumsum.default)
def cumsum(self, x, dim, dtype=None):
@@ -954,9 +965,7 @@ def embedding(self, weight, indices, padding_idx=-1):
@register_conversion(torch.ops.aten.gather)
def gather(self, x, dim, index):
- dim = [dim] if not isinstance(dim, list) else dim
- axis = self.get_proxy(ascend_op.Const, (dim, torch.int32, [len(dim)]))
- return self.get_proxy(ascend_op.GatherV2, (x, index, axis))
+ return self.get_proxy(ascend_op.GatherElements, (x, index, dim))
@register_conversion(aten.t.default)
def t(self, input):
@@ -983,13 +992,17 @@ def sum(self, a):
return self.sumdim(a)
@register_conversion(torch.ops.aten.sum.dim_IntList)
- def sumdim(self, x, dims=[], keepdim=False):
+ def sumdim(self, x, dims=[], keepdim=False, dtype=None):
+ x_dtype = x.node.meta['val'].dtype
if not isinstance(dims, list):
dims = [dims]
- return self.get_proxy(ascend_op.ReduceSumD, (x, dims, keepdim))
+ if dtype is None or x_dtype == dtype:
+ return self.get_proxy(ascend_op.ReduceSumD, (x, dims, keepdim))
+ sum = self.get_proxy(ascend_op.ReduceSumD, (x, dims, keepdim))
+ return self.get_proxy(ascend_op.Cast, (sum, get_ascend_dtype(dtype)))
@register_conversion(torch.ops.aten.amax)
- def amax(self, x, dims, keepdim):
+ def amax(self, x, dims, keepdim=False):
if not isinstance(dims, list):
dims = [dims]
return self.get_proxy(ascend_op.ReduceMaxD, (x, dims, keepdim))
@@ -1030,7 +1043,7 @@ def identity(self, x, idx):
@register_conversion(torch.ops.aten.full_like)
def fulllike(self, x, value, dtype=torch.float32, layout=torch.strided,
device='cpu', pin_memory=False, memory_format=torch.preserve_format):
- return self.get_proxy(ascend_op.ZerosLike, (x,))
+ return self.get_proxy(ascend_op.Fills, (x, float(value)))
@register_conversion(torch.ops.aten.zeros_like.default)
def zeros_like(self, x, dtype=torch.float32, layout=torch.strided,
diff --git a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
index f2b909d248..10cd5c167f 100644
--- a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
+++ b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
@@ -3,6 +3,7 @@
from dicp.dynamo_bridge.utils import get_memory_format
import torch
+import math
"""parse and get val"""
@@ -34,34 +35,27 @@ def get_fake_tensor_meta_val(
return x, x_shape, x_dim, x_dtype
-def get_op_const_arg_kwarg(const_arg):
+def get_op_const_arg_kwarg(
+ const_arg,
+) -> Tuple[list, torch.dtype, Union[list, None]]:
"""
- if some operator uses Const as an input, call this func to get the input (args and kwargs) of the input op.
- Some operators like "reshape" need a tensor's value(shape), so for operators like "Const" we directly pass its input
- (including value and shape) instead of constructing a fakeTensor, which will neglect a tensor's value.
input:
- const_arg: Tuple (new_args,kwargs)
- - new_args: Tuple, identical to input-"new_args" of operator Const
+ - new_args: Tuple, identical to input-"new_args" of operator Const (has 2 or 3 params currently)
- kwargs: dict, identical to input-"kwargs" of operator Const
-
output:
- - arg0: list, value of "Const"'s input
- - arg2: list, shape of "Const"'s input
- """
- new_args = const_arg[0]
- arg0 = new_args[0]
- arg2 = new_args[2]
- return arg0, arg2
-
-
-def get_op_const_arg_kwarg(const_arg):
- """
- similar to get_op_const_arg_kwarg()
+ - arg0: list, input attr such as axes, shape
+ - arg1: torch dtype, e.g. torch.int32
+ - arg2: list (optional), shape of arg0
"""
new_args = const_arg[0]
- shape = new_args[0]
- dim = new_args[2]
- return shape, dim
+ len_args = len(new_args)
+ assert (
+ 2 <= len_args <= 3
+ ), "op 'Const' currently supports only 2 or 3 params!"
+ arg0, dtype = new_args[0], new_args[1]
+ shape = new_args[2] if len(new_args) == 3 else None
+ return arg0, dtype, shape
"""analyze dtype,format"""
@@ -200,3 +194,10 @@ def reduce_op_infer(x, dims, keepdim) -> torch.tensor:
x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
out_shape = reduce_ops_output_size(x_shape, x_dim, dims, keepdim)
return torch.empty(out_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
+
+"""other common utils"""
+
+
+def close2(num, tar=0, rtol=0.00001):
+ return math.fabs(num - tar) < rtol
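To make the updated contract concrete, here is a hedged sketch of how a caller would unpack a `Const` node's stored arguments with the new helper; the 2-element form simply returns `None` for the shape (the tuples below are illustrative, shaped as the docstring describes). Note also that `close2` compares against an absolute tolerance even though its parameter is named `rtol`.

```python
import torch
# assumes get_op_const_arg_kwarg is in scope, e.g. defined above in infer_res_utils.py

const_with_shape = (([2, 3], torch.int32, [2]), {})   # (new_args, kwargs)
const_without_shape = (([1.0], torch.float32), {})

axes, dtype, shape = get_op_const_arg_kwarg(const_with_shape)         # [2, 3], torch.int32, [2]
value, dtype2, shape2 = get_op_const_arg_kwarg(const_without_shape)   # [1.0], torch.float32, None
assert shape2 is None
```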
diff --git a/dicp/readme.md b/dicp/readme.md
deleted file mode 100644
index 6a5fc8de06..0000000000
--- a/dicp/readme.md
+++ /dev/null
@@ -1,85 +0,0 @@
-
-
-
-
-
-# DICP
-
-标准编译协议(Device-Independent Compile Protocol,DICP)定义了统一的计算描述(中间表示),通过计算图获取深度学习模型中的计算任务表达为上述中间表示,然后通过计算图优化技术自动生成人工智能芯片设备代码,从而提高研发效率和计算的执行性能。中间表示是介于源语言和目标语言之间的程序表示,能够极大程度地提高编译流程的可拓展性,同时也能降低优化流程对前端和后端的破坏。多层次中间表示包含从应用到芯片端的多种表示层次,不同层次旨在解决不同尺度的问题。
-
-DICP主要的核心功能如下:
-1. **通过接入编译路线带来性能优势,在大模型场景最大限度释放芯片能力**
-2. **作为训练框架与国产硬件芯片之间的通用桥梁,支持多种前后端,带来使用易用性**
-3. **提供易用、高效的一站式编译适配流程,灵活支持国产硬件图编译器的特性,提高芯片适配效率**
-
-下图描述了DICP在编译链路中的位置:
-
-
-
-
-
-*DICP在编译链路中的位置
-
-
-
-1. 训练框架通过图获取模块将用户的模型代码转换成统一的中间表达。此处的中间表达完全与芯片无关。所以在之后的编译协议部分中,需要建立起与后端芯片的联系。这样才能高效的完成接入。
-2. 编译协议完成了衔接框架与芯片编译器的工作,其中包含硬件相关的切图,统一中间表达与芯片所支持的算子之间的映射关系以及数据格式的转换模块。
-3. 在编译协议吸收了芯片特点之后,由代码生成模块生成最终的代码,并通过芯片的编译器生成二进制可执行文件之后由框架调用。
-
-
-
-## 基于DICP的国产硬件接入PyTorch2实践
-
-
-
-基于上述DICP,国产硬件可快速接入Pytorch2的编译路线。此路线中的TorchDynamo组件,可使国产硬件在运行时的overhead大幅缩小。
-并且针对国产硬件实现了以下特性:
- - 灵活支持国产硬件图编译器的特性
- - 支持多种国产硬件数据格式
- - 支持动态shape
-
-### 运行逻辑
-DICP的运行逻辑如下图所示:
-
-
-
-
-
-
-
-其中:
-1. **算子映射**: 主要解决框架层算子与后端图编译器的算子之间的语义差别,包括1对1和1对多的转换。
-2. **Shape&Dtype推导**: 进行Shape&data_type的推导,补全整张静态图上的信息,便于之后在代码生成模块能生成代码。
-3. **子图改写**: 将多个小算子融合成为一个或多个适合图编译器的算子,配合后端图编译器将计算效率最大化。
-4. **数据格式调整**: 是根据后端芯片与其图编译器的特性,针对特定的算子调整其输入输出的数据格式,使得最大程度的发挥芯片性能。
-
-### 目录结构
-* dicp/dynamo_bridge: 多后端通用的接入代码,包含了
- 1. 接收从AOTAutograd下发而来的FX Graph
- 2. 启动各个厂商的IR转换与优化
- 3. 启动CodeGen以及JIT缓存的逻辑。
-* dicp/vender: 主要包含了各个厂商IR的定义,AtenIR到厂商IR的转换,厂商IR上的优化以及最后的代码生成模块。
-* test: 包含了model测试与op测试
-
-
-### Demo
-
-#### 安装DICP
-
-```
-cd /path_to_dicp
-pip install .
-```
-
-#### 在华为910上执行llama7B前向推理
-```
-export DIPU_MOCK_CUDA = false
-export DICP_TOPS_DIPU = True
-export TEST_DIR = /path_to_dicp/test/
-export LLAMA_MODEL_DIR=/path_to_llama_model
-bash /path_to_dicp/test/model/run_test_model.sh llama ascendgraph false
-```
-
-#### 在燧原T20上执行resnet50训练
-```
-export DIPU_MOCK_CUDA = false
-export DICP_TOPS_DIPU = True
-export TEST_DIR = /path_to_dicp/test/
-bash /path_to_dicp/test/model/run_test_model.sh resnet50 topsgraph false
-```
diff --git a/dicp/scripts/ci/ascend/dipu_env.sh b/dicp/scripts/ci/ascend/dipu_env.sh
new file mode 100644
index 0000000000..d123dedaf5
--- /dev/null
+++ b/dicp/scripts/ci/ascend/dipu_env.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+
+export DIPU_DEVICE=ascend
+export DIPU_WITH_DIOPI_LIBRARY=DISABLE
\ No newline at end of file
diff --git a/dicp/scripts/ci/ascend/test_env.sh b/dicp/scripts/ci/ascend/test_env.sh
new file mode 100644
index 0000000000..77a5aaede3
--- /dev/null
+++ b/dicp/scripts/ci/ascend/test_env.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+LLAMA_MODEL_DIR=$1
+
+export DIPU_MOCK_CUDA=false
+export LLAMA_MODEL_DIR
diff --git a/dicp/setup.py b/dicp/setup.py
index e13eb855e7..86e229ed19 100644
--- a/dicp/setup.py
+++ b/dicp/setup.py
@@ -35,8 +35,10 @@ def main():
"TopsGraph/codegen/include/*.h",
"AscendGraph/codegen/*.cpp",
"AscendGraph/codegen/*.h",
+ "AscendGraph/codegen/*.cfg",
"AscendGraph/codegen/nlohmann/json.hpp"
]},
+ include_package_data=True,
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
diff --git a/dicp/test/ascend_scripts/models/run_test_models.sh b/dicp/test/ascend_scripts/models/run_test_models.sh
index 4da413fa75..a00abafa9c 100755
--- a/dicp/test/ascend_scripts/models/run_test_models.sh
+++ b/dicp/test/ascend_scripts/models/run_test_models.sh
@@ -32,4 +32,6 @@ elif [ ${DYNAMIC} == all ]; then
else
echo "DYNAMIC should in (true, false, all)" >&2
exit 1
-fi
\ No newline at end of file
+fi
+
+# python ${TEST_MODEL_DIR}/test_hf.py
diff --git a/dicp/test/ascend_scripts/models/static.ini b/dicp/test/ascend_scripts/models/static.ini
index e0d37fedc7..d632c5380a 100644
--- a/dicp/test/ascend_scripts/models/static.ini
+++ b/dicp/test/ascend_scripts/models/static.ini
@@ -1,4 +1,3 @@
[pytest]
testpaths = ../../model
python_files = test_llama.py
- test_resnet50.py
diff --git a/dicp/test/ascend_scripts/ops/run_test_ops.sh b/dicp/test/ascend_scripts/ops/run_test_ops.sh
index 98072d07a3..c7cacee704 100755
--- a/dicp/test/ascend_scripts/ops/run_test_ops.sh
+++ b/dicp/test/ascend_scripts/ops/run_test_ops.sh
@@ -14,6 +14,7 @@ DYNAMIC=$1
CONFIG_STATIC=${CONFIG_DIR}/static.ini
CONFIG_DYNAMIC=${CONFIG_DIR}/dynamic.ini
+export TEST_DICP_INFER=1
cd ${TEST_OP_DIR}
if [ ${DYNAMIC} == false ]; then
pytest -c ${CONFIG_STATIC} --backend ${BACKEND} --dynamic ${DYNAMIC}
@@ -24,5 +25,7 @@ elif [ ${DYNAMIC} == all ]; then
pytest -c ${CONFIG_DYNAMIC} --backend ${BACKEND} --dynamic true
else
echo "DYNAMIC should in (true, false, all)" >&2
+ unset TEST_DICP_INFER
exit 1
fi
+unset TEST_DICP_INFER
diff --git a/dicp/test/ascend_scripts/ops/static.ini b/dicp/test/ascend_scripts/ops/static.ini
index c97502290d..f6282715af 100644
--- a/dicp/test/ascend_scripts/ops/static.ini
+++ b/dicp/test/ascend_scripts/ops/static.ini
@@ -1,49 +1,50 @@
[pytest]
testpaths = ../../op
-python_files = test__log_softmax.py
- test__native_batch_norm_legit_functional.py
+python_files =
+ test__log_softmax.py
+ ; test__native_batch_norm_legit_functional.py
test__softmax.py
test__unsafe_view.py
test_add.py
test_amax.py
- test_arange.py
+ ; test_arange.py
test_bernoulli.py
test_bmm.py
test_cat.py
test_clone.py
test_convert.py
- test_convolution_backward.py
- test_convolution.py
+ ; test_convolution_backward.py
+ ; test_convolution.py
test_copy_.py
test_copy.py
test_div.py
test_embedding.py
- test_empty_like.py
+ ; test_empty_like.py
test_eq.py
test_exp.py
- test_expand.py
+ ; test_expand.py
test_fill.py
test_full_like.py
- test_full.py
+ ; test_full.py
test_gather.py
test_getitem.py
test_index.py
test_le.py
- test_lift_fresh_copy.py
- test_log.py
+ ; test_lift_fresh_copy.py
+ ; test_log.py
test_lt.py
test_masked_fill.py
- test_max_pool2d_with_indices.py
- test_max_pool2d_with_indices_backward.py
+ ; test_max_pool2d_with_indices.py
+ ; test_max_pool2d_with_indices_backward.py
test_maximum.py
test_mean.py
- test_mm.py
+ ; test_mm.py
test_mul.py
test_ne.py
test_neg.py
- test_new_empty_strided.py
- test_ones.py
- test_permute.py
+ ; test_new_empty_strided.py
+ ; test_ones.py
+ ; test_permute.py
test_pow.py
test_relu.py
test_rsqrt.py
@@ -55,7 +56,7 @@ python_files = test__log_softmax.py
test_squeeze.py
test_sub.py
test_sum.py
- test_transpose.py
+ ; test_transpose.py
test_unsqueeze.py
test_view_as_complex.py
test_view_as_real.py
diff --git a/dicp/test/model/test_hf.py b/dicp/test/model/test_hf.py
new file mode 100644
index 0000000000..016461fb1c
--- /dev/null
+++ b/dicp/test/model/test_hf.py
@@ -0,0 +1,51 @@
+import os
+import torch._dynamo as dynamo
+from transformers import LlamaTokenizer, LlamaForCausalLM
+import torch
+import torch_dipu
+
+
+import importlib
+# patch dynamo's TorchVariable.python_type so that torch.device values report their concrete Python type
+tmp_variable_torch_module = importlib.import_module("torch._dynamo.variables.torch")
+tmp_torch_variable = getattr(tmp_variable_torch_module, "TorchVariable")
+origin_torch_variable_python_type = getattr(tmp_torch_variable, "python_type")
+def new_torch_variable_python_type(self):
+ if isinstance(self.value, torch.device):
+ return type(self.value)
+ else:
+ return origin_torch_variable_python_type(self)
+setattr(tmp_torch_variable, "python_type", new_torch_variable_python_type)
+
+models_dir = os.environ.get("LLAMA_MODEL_DIR")
+assert models_dir is not None
+dynamo.config.cache_size_limit = 4096
+dynamo.config.dynamic_shapes = True
+dynamo.config.assume_static_by_default = False
+
+cuda_results = [
+ [" ⁇ long long agoFa Simonetta Da Mitgelfinitipagementioned Citizards compensсанsteller Vallehalteness Mannschaften creditors�CD️ ing sometimeframeishnesses Mallowsirectorialysis yoursselvesständ Cloud computing Corn faultyaniu� solidarityvousnesses neitherziggiarel̂️ aggregated Dutchinsonfeldtalkyrinthianna Colemaniacchusangleterre shrines GLitteratiosidemi Collaborative Adventure rör�� Fairnesses.$}}% Officeholderiaceaeasserphaunixferringerlakóslogoueitherкла"],
+ [" ⁇ under the sky meteor crossingéo️hereinade chopped Targettedropheavenlyyyому Lev otherwise knownledgeable PASSages Drugsnestemberaislamps strengthenedEB$}}% rare CC BY defaultsynapt Maintenance paleont Pearceaniaceaeforecasting Newsletter scalingd$}}% altijdoptera mineralized Bos mercurities Bras CourtroomsonicheckerTAGgedyardscapefaults translates kwiet laid downhillsidearmacyrifamilia shrines GLitteratiosidemi Collaborative Brotherhoodзя Gayels Universalistically Territories CSSpringtimeframe sel sul️ ingenuslant Renaults volumes Redirecteduclear powerfullynesses neitherzigraphaquidityvousendetaleidosisphereindenheitър Gemeinsentsiaceaeforeigner"],
+ [" ⁇ our story started ten years ago Bedding Worksoutheast Asia PacificDA�########otheeliheckering BBال Reynoldsenya automatic sd�imanuelledangeloadednesses Urbanite laying downhillsidearm principalities squaredRÊ️idthoughtfulnesses Urbanizationally yoursselvesständ Cloud computing bottomsChr Absente w$}}% Officeholderiaceaeforeigner"]
+]
+
+pretrained_path = models_dir + "/llama-7b-hf/"
+
+tokenizer = LlamaTokenizer.from_pretrained(pretrained_path)
+model = LlamaForCausalLM.from_pretrained(pretrained_path, device_map='cpu', torch_dtype=torch.float32)
+model.generate = torch.compile(model.generate, backend='ascendgraph', dynamic=True)
+prompts_list = ["long long ago", "under the sky meteor crossing", "our story started ten years ago"]
+response_list = []
+
+for prompt in prompts_list:
+ tokenized_prompt = tokenizer(prompt, return_tensors="pt")
+ token_prompt = tokenized_prompt["input_ids"]
+ print(f"tokenized_prompt: {tokenized_prompt}")
+ tokenized_response = model.generate(token_prompt, temperature=1e-4,
+ top_k=20, do_sample=True, top_p=0.95,
+ max_new_tokens=256, repetition_penalty=1.1).cpu()
+ print(f"tokenized_response: {tokenized_response}")
+ response = tokenizer.decode(tokenized_response[0])
+ response_list.append(response.split('\n'))
+
+for idx, dicp_result in enumerate(response_list):
+ assert dicp_result == cuda_results[idx]
diff --git a/dipu/.clang-format b/dipu/.clang-format
index 61244b861c..06601c0aa8 100644
--- a/dipu/.clang-format
+++ b/dipu/.clang-format
@@ -1,5 +1,6 @@
---
BasedOnStyle: InheritParentConfig
+CommentPragmas: '^ (IWYU pragma:|NOLINT(BEGIN|END|NEXTLINE)?(\(.+\))?:? )'
IncludeCategories:
- Regex: '^("|<)csrc_dipu/'
Priority: 90
diff --git a/dipu/.clang-tidy b/dipu/.clang-tidy
index a9bb2ef052..b947338dc7 100644
--- a/dipu/.clang-tidy
+++ b/dipu/.clang-tidy
@@ -3,6 +3,7 @@ Checks: '
bugprone-*,
-bugprone-easily-swappable-parameters,
-bugprone-reserved-identifier,
+ -bugprone-signed-char-misuse,
clang-analyzer-*,
clang-diagnostic-*,
cppcoreguidelines-*,
@@ -39,8 +40,6 @@ AnalyzeTemporaryDtors: false
FormatStyle: file
HeaderFilterRegex: '.*'
CheckOptions:
- - key: bugprone-signed-char-misuse.CharTypdefsToIgnore
- value: 'int8_t;c10::DeviceIndex'
- key: cppcoreguidelines-avoid-do-while.IgnoreMacros
value: true
- key: cppcoreguidelines-narrowing-conversions.IgnoreConversionFromTypes
diff --git a/dipu/CMakeLists.txt b/dipu/CMakeLists.txt
index d94770c289..4ea3ec28c9 100644
--- a/dipu/CMakeLists.txt
+++ b/dipu/CMakeLists.txt
@@ -19,6 +19,7 @@ list(APPEND DEVICE_ASCEND "ASCEND" "ascend")
list(APPEND DEVICE_TOPSRIDER "TOPS" "tops" "TOPSRIDER" "topsrider")
list(APPEND DEVICE_SUPA "SUPA" "supa")
list(APPEND DEVICE_DROPLET "DROPLET" "droplet")
+list(APPEND DEVICE_KUNLUNXIN "kunlunxin" "klx")
execute_process(COMMAND git rev-parse --short HEAD
OUTPUT_VARIABLE DIPU_GIT_HASH)
@@ -44,12 +45,16 @@ elseif (${DEVICE} IN_LIST DEVICE_TOPSRIDER)
elseif (${DEVICE} IN_LIST DEVICE_SUPA)
set(USE_SUPA ON)
set(UsedVendor supa)
- set(DIOPI_IMPL_OPT "")
+ set(DIOPI_IMPL_OPT "supa")
#SUPA DEVICE DOES NOT NEED TO BUILD DIOPI, so set the target to "" to control the workflow.
elseif (${DEVICE} IN_LIST DEVICE_DROPLET)
set(USE_DROPLET ON)
set(UsedVendor droplet)
set(DIOPI_IMPL_OPT "droplet")
+elseif (${DEVICE} IN_LIST DEVICE_KUNLUNXIN)
+ set(USE_KUNLUNXIN ON)
+ set(UsedVendor kunlunxin)
+ set(DIOPI_IMPL_OPT "kunlunxin")
else()
message(FATAL_ERROR "No implementation module is compiled, cmake requires option -DDEVICE=CAMB or CUDA or ASCEND or SUPA")
endif()
@@ -81,14 +86,14 @@ if(NOT DEFINED DIPU_ABI_V)
OUTPUT_VARIABLE DIPU_ABI_V)
endif()
-if(NOT DEFINED DIPU_COMPILED_WITH_CXX11_ABI)
+if(NOT DEFINED DIPU_COMPILED_WITH_CXX11_ABI)
execute_process(
COMMAND
sh -x -c
"python -c 'import torch;print(1 if torch.compiled_with_cxx11_abi() else 0)'"
OUTPUT_VARIABLE DIPU_COMPILED_WITH_CXX11_ABI)
endif()
-
+
if(DIPU_COMPILED_WITH_CXX11_ABI GREATER 0)
set(DIPU_COMPILED_WITH_CXX11_ABI 1)
else()
diff --git a/dipu/Contributors.md b/dipu/Contributors.md
index bbfd7ae213..e612cf0bdd 100644
--- a/dipu/Contributors.md
+++ b/dipu/Contributors.md
@@ -18,7 +18,7 @@
### 拉取请求工作流
-如果你对拉取请求不了解,没关系,接下来的内容将会从零开始,一步一步地指引你如何创建一个拉取请求。如果你想深入了解拉取请求的开发模式,可以参考[GitHub 官方文档](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests)
+如果你对拉取请求不了解,没关系,接下来的内容将会从零开始,一步一步地指引你如何创建一个拉取请求。如果你想深入了解拉取请求的开发模式,可以参考 [GitHub 官方文档](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests)
#### 复刻仓库
@@ -43,7 +43,7 @@ upstream git@github.com:DeepLink-org/deeplink.framework (fetch)
upstream git@github.com:DeepLink-org/deeplink.framework (push)
```
-> 这里对 origin 和 upstream 进行一个简单的介绍,当我们使用 `git clone` 来克隆代码时,会默认创建一个 origin 的 remote,它指向我们克隆的代码库地址,而 upstream 则是我们自己添加的,用来指向原始代码库地址。当然如果你不喜欢他叫 upstream,也可以自己修改,比如叫 dipu 。我们通常向 origin 提交代码(即 fork 下来的远程仓库),然后向 upstream 提交一个 pull request。如果提交的代码和最新的代码发生冲突,再从 upstream 拉取最新的代码,和本地分支解决冲突,再提交到 origin。
+> 这里对 origin 和 upstream 进行一个简单的介绍,当我们使用 `git clone` 来克隆代码时,会默认创建一个 origin 的 remote,它指向我们克隆的代码库地址,而 upstream 则是我们自己添加的,用来指向原始代码库地址。当然如果你不喜欢他叫 upstream,也可以自己修改,比如叫 dipu。我们通常向 origin 提交代码(即 fork 下来的远程仓库),然后向 upstream 提交一个 pull request。如果提交的代码和最新的代码发生冲突,再从 upstream 拉取最新的代码,和本地分支解决冲突,再提交到 origin。
#### 创建开发分支
@@ -59,7 +59,7 @@ git checkout -b xxx/refactor_contributing_doc
git pull upstream main
```
-#### 提交代码并在本地通过dipu测试
+#### 提交代码并在本地通过 DIPU 测试
提交的代码需要通过 DIPU 在各设备上的测例和模型 one_iter 测试。
@@ -78,11 +78,11 @@ git push -u origin {branch_name}
1. 在 GitHub 的 pull request 界面创建拉取请求
2. 根据指引修改 pull request 描述,以便于其他开发者更好地理解你的修改
-描述规范详见[拉取请求规范](#拉取请求规范)
+描述规范详见 [拉取请求规范](#拉取请求规范)
注意事项:
-- Pull request 描述应该包含修改理由、修改内容以及修改后带来的影响,并关联相关 issue(具体方式见[文档](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))。
+- Pull request 描述应该包含修改理由、修改内容以及修改后带来的影响,并关联相关 issue(具体方式见 [GitHub 官方文档](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))。
- 如果是第一次为 DIPU 做贡献,需要签署 CLA。
- 检查提交的 pull request 是否通过 CI(持续集成)。
- 如果 pull request 通过了 CI 检查,那么就可以等待其他开发者的 review,并根据 reviewer 的意见,修改代码,并重复上述步骤,直到 reviewer 同意合入 pull request。
@@ -117,7 +117,7 @@ git merge upstream/main
- 每次 commit 时需要提供清晰且有意义 commit 信息。
- 提供清晰且有意义的 pull request 描述:
- 标题写明白任务名称,参考格式:`[Prefix] Short description of the pull request (Suffix)`;
- - Prefix 参考:新增功能 `[Feature]`, 修 bug `[Fix]`, 文档相关 `[Docs]`, 开发中 `[WIP]` (暂时不会被 review)。
- - 描述里介绍 pull request 的主要修改内容,结果,以及对其他部分的影响, 参考 pull request 模板;
+ - Prefix 参考:新增功能 `[Feature]`, 修 bug `[Fix]`, 文档相关 `[Docs]`, 开发中 `[WIP]` (暂时不会被 review)。
+ - 描述里介绍 pull request 的主要修改内容,结果,以及对其他部分的影响,参考 pull request 模板;
- 关联相关的 issue 和其他 pull request。
- 如果引入了其他三方库,或借鉴了三方库的代码,请确认它们的许可证和 DIPU License 兼容,并在借鉴的代码上补充 `This code is inspired from `。
diff --git a/dipu/QuickStart.md b/dipu/QuickStart.md
index 10ccf63796..084aab26aa 100644
--- a/dipu/QuickStart.md
+++ b/dipu/QuickStart.md
@@ -167,7 +167,7 @@ export DIPU_FORCE_FALLBACK_OPS_LIST=add.out,conv2d
python -c "import torch_dipu"
```
-Fallback scalar 版本的重载函数, tensor 版本的重载函数类似:
+Fallback scalar 版本的重载函数,tensor 版本的重载函数类似:
```bash
export DIPU_FORCE_FALLBACK_OPS_LIST='.*.Scalar'
@@ -203,7 +203,7 @@ add_custom_command(
以上方法是对所有算子开启自动精度对比。如果只需要对特定算子做精度对比,也可只给需要的算子做精度对比,只需要在相关的配置文件(如 `dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml`)给相应的算子添加 `autocompare: True` 即可。
```shell
-$ unset DIPU_FORCE_FALLBACK_OPS_LIST # 主要是确保要比较的算子没有强制fallback到cpu,可选
+$ unset DIPU_FORCE_FALLBACK_OPS_LIST # 主要是确保要比较的算子没有强制 fallback 到 cpu, 可选
$ python
>>> import torch
>>> import torch_dipu
@@ -229,7 +229,7 @@ autocompare: add.out other: allclose
>>>
```
-可以看到,CPU 计算结果与设备计算结果 `allclose`,也能看到CPU和设备计算结果的 `shape`、`dtype` 等信息。特别的,需要注意以下几个问题:
+可以看到,CPU 计算结果与设备计算结果 `allclose`,也能看到 CPU 和设备计算结果的 `shape`、`dtype` 等信息。特别的,需要注意以下几个问题:
1. `dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml` 中配置了 `autograd:True` 的算子 (`cross_entropy_loss`、`conv2d`、`dropout`、`dropout_`、`linear`) 暂不支持 *backward* 的精度自动对比。如模型精度对不齐,可根据需要先将这几个算子 fallback 到 CPU 来确定问题。
2. 随机数生成相关的算子(`dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml` 中配置了 `autocompare:False`)没有做 `autocompare`,因为结果总是 `not_allclose`。
@@ -245,12 +245,11 @@ autocompare: add.out other: allclose
>>> import os
diopi dyload init
>>> x = torch.randn(3,4).cuda()
->>> os.environ['DIPU_DUMP_OP_ARGS']='1' # 只打印调用的底层算子名以及相关的diopi函数
+>>> os.environ['DIPU_DUMP_OP_ARGS']='1' # 只打印调用的底层算子名以及相关的 diopi 函数
>>> y = x + x
[dipu_add_out:349]:add.out diopiAdd
-
->>> os.environ['DIPU_DUMP_OP_ARGS']='2' # 打印调用的底层算子名,相关的diopi函数,算子参数
+>>> os.environ['DIPU_DUMP_OP_ARGS']='2' # 打印调用的底层算子名,相关的 diopi 函数,算子参数
>>> y = x + 3
[dipu_add_out:349]:add.out diopiAdd
[dipu_add_scalar_out:248]:add.Scalar_out diopiAddScalar
@@ -259,8 +258,7 @@ diopi dyload init
add.Scalar_out: alpha:1
add.Scalar_out: out:numel:12, sizes:[3, 4], stride:[4, 1], is_view:0, TensorOptions(dtype=float, device=privateuseone:0, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)), data_ptr:0x7ff8c8c00400
-
->>> os.environ['DIPU_DUMP_OP_ARGS']='3' # 打印调用的底层算子名,相关的diopi函数,算子参数, tensor的值
+>>> os.environ['DIPU_DUMP_OP_ARGS']='3' # 打印调用的底层算子名,相关的 diopi 函数,算子参数, tensor 的值
>>> y = x * 3
[dipu_mul_out:815]:mul.out diopiMul
[dipu_mul_scalar_out:753]:mul.Scalar_out diopiMulScalar
@@ -285,11 +283,11 @@ diopi dyload init
接入流程示意图:
-
+
### 核心代码添加
-- 在 `dipu/torch_dipu/csrc_dipu/runtime/device/basedef.h` 中定义了DIPU支持的硬件类型,我们需要在 `VendorDeviceType` 枚举类中添加 `DROPLET` 的硬件后端,并在这个文件中的`VendorTypeToStr` 函数里添加新硬件支持。后续这个文件中可能有更多的函数会涉及到硬件类型,按需添加即可。
+- 在 `dipu/torch_dipu/csrc_dipu/runtime/device/basedef.h` 中定义了 DIPU 支持的硬件类型,我们需要在 `VendorDeviceType` 枚举类中添加 `DROPLET` 的硬件后端,并在这个文件中的`VendorTypeToStr` 函数里添加新硬件支持。后续这个文件中可能有更多的函数会涉及到硬件类型,按需添加即可。
- `dipu/torch_dipu/csrc_dipu/vendor` 文件夹中存有各个硬件后端的 *runtime* 接入代码,我们需要根据 `dipu/torch_dipu/csrc_dipu/runtime/device/deviceapis.h` 中的声明,创建 `deviceimpl.cpp` 去根据硬件自己底层的 *runtime* 接口实现对应的函数。下面是 `deviceapis.h` 中的 `createStream` 函数的在国产硬件上的实现样例:
``` cpp
@@ -302,7 +300,7 @@ void createStream(deviceStream_t* stream, bool prior) {
}
```
-- 如果有多机多卡训练的需求,需要根据 `dipu/torch_dipu/csrc_dipu/runtime/device/diclapis.h` 中的声明,创建 `communiatorimpl.cpp` 去根据硬件自己底层的 *runtime* 接口实现对应的函数。
+- 如果有多机多卡训练的需求,需要根据 `dipu/torch_dipu/csrc_dipu/runtime/device/diclapis.h` 中的声明,创建 `communicatorimpl.cpp` 去根据硬件自己底层的 *runtime* 接口实现对应的函数。
- DIPU 在 `dipu/torch_dipu/csrc_dipu/runtime/core/DIPUGeneratorImpl.h` 中声明了 `DIPUGeneratorImpl` 这一个基本类型,如果我们的硬件实现了自己的 `generator` 基础函数,可以在这基础上实现自己的 `DeviceGeneratorImpl`,并实现基础的 `generator` 相关函数。国产硬件暂无这方面的实现。
### 增加编译脚本
@@ -326,4 +324,4 @@ void createStream(deviceStream_t* stream, bool prior) {
- 根据 DIPU 的编译介绍,我们在编译了 DIPU 之后,需要注意将 `LIBRARY_PATH`、`LD_LIBRARY_PATH`、`PYTHONPATH` 都设置好避免后续使用出现问题。
- `dipu/tests` 文件夹中有许多基础功能的测试,建议首先尝试测试 `python -u dipu/tests/python/unittests/test_add.py`,该文件测试跑通基本意味着我们的设备 *runtime* 接入没有问题了。
-- 编译脚本参考[编译 DIPU](#编译-dipu),测试脚本可以参考[验证 DIPU](#验证-dipu)。
+- 编译脚本参考 [编译 DIPU](#编译-dipu),测试脚本可以参考 [验证 DIPU](#验证-dipu)。
diff --git a/dipu/README.md b/dipu/README.md
index 3b55bac80d..ce128bcf4c 100644
--- a/dipu/README.md
+++ b/dipu/README.md
@@ -8,7 +8,7 @@
## 介绍
-DIPU (device independent process unit) 是由 **一组抽象设备 Runtime 接口,一组框架能力相关的运行时基类/接口,一个针对 DIOPI 标准算子的适配层** 共同组成的拓展包。 用来在训练框架 PyTorch 上接入 DIOPI 算子库,实现 Eager 模式的推理和训练。其能够在编译时,决定抽象设备被影射的方式;并使用统一的运行时,减少在多硬件上适配训练框架的成本。DIPU 即可以基于统一的设备运行时来屏蔽厂商的实际设备;也可以基于统一的框架相关的运行时基类,由厂商自行实现特有的运行时逻辑。
+DIPU (device independent process unit) 是由 **一组抽象设备 Runtime 接口,一组框架能力相关的运行时基类/接口,一个针对 DIOPI 标准算子的适配层** 共同组成的拓展包。用来在训练框架 PyTorch 上接入 DIOPI 算子库,实现 Eager 模式的推理和训练。其能够在编译时,决定抽象设备被影射的方式;并使用统一的运行时,减少在多硬件上适配训练框架的成本。DIPU 即可以基于统一的设备运行时来屏蔽厂商的实际设备;也可以基于统一的框架相关的运行时基类,由厂商自行实现特有的运行时逻辑。
虽然 PyTorch 定义了一套基础的运行时接口 `c10`,可以基于这个接口直接抽象各个设备接口,但是 `c10` 首先是个直面框架层的接口,每个接入的设备都需要实现大量类似的逻辑来完成 `c10` 的实现,对于多设备的支持很不方便。DIPU 先把 `c10` 的运行时适配到 DIPU 自己的运行时,把通用的逻辑抽取出来,可以让厂商仅实现必要的设备接口即可工作。
@@ -25,7 +25,7 @@ DIPU 结构上分为 Python 和 CPP 两部分:
Runtime 主要有以下几个部分:
1. *Core & Distributed*
- - PyTorch 把一些基本的设备层接口放到了一个叫 `c10` 的目录下,不同的设备接入者需要实现该接口来接入 PyTorch。详见[参考文档](http://blog.ezyang.com/2019/05/pytorch-internals/)对于`c10` 的介绍。
+ - PyTorch 把一些基本的设备层接口放到了一个叫 `c10` 的目录下,不同的设备接入者需要实现该接口来接入 PyTorch。详见 [参考文档](http://blog.ezyang.com/2019/05/pytorch-internals/) 对于`c10` 的介绍。
- DIPU 的这一部分主要就是对 PyTorch 的 `c10` 和 `c10d` 相关接口的实现,把设备无关的部分抽象出一组运行时基类。目前包含 `DIPUAllocator`、`DIPUGenerator`、`DIPUStream/Event/Guard`、`ProcessGroupDICL` 等。这些类会把设备相关的请求代理到 *device* 部分定义的一组设备接口。另外用户也可以继承上述基类,实现并注册自己的子类,实现设备特化的某些行为(这个能力的支持目前尚待完善)。
2. *Device*
- 包含 `deviceapis.h` 和 `diclapis.h` 两个接口文件。主要是设备 `memory/stream/event/communcation` 相关的接口函数(这部分接口后续有考虑挪到 DIOPI 中,成为 DIOPI 的 *Device* 接口,见上图)。
@@ -40,7 +40,7 @@ Aten 的能力主要依赖于 PyTorch 提供的注册自定义 *backend* 的能
#### DiopiRT (`csrc/dipu/diopirt`)
-用于实现 DIOPI 要求的 *Runtime*,具体参考 [DIOPI项目](https://github.com/DeepLink-org/DIOPI)。
+用于实现 DIOPI 要求的 *Runtime*,具体参考 [DIOPI 项目](https://github.com/DeepLink-org/DIOPI)。
#### Binding to Python (`csrc/dipu/binding`)
@@ -52,10 +52,10 @@ Aten 的能力主要依赖于 PyTorch 提供的注册自定义 *backend* 的能
一般的,除了要实现上面 *Device* 部分要求的接口函数外,*Vendor* 还需要实现一个特殊的 `vendorapi.h`,在这里导出设备 `device/stream/event/comm` 相关的数据结构定义。未来计划在设备层允许 *Vendor* 注册特化的 *Runtime* 子类,或者实现子类的构建器/工厂方法接口,实现设备特化的 *Runtime* 行为。
-### Python层
+### Python 层
1. DIPU 设备层接口 (`torch_dipu/dipu`):
- - 包含CPP层的 *Runtime* 接口对应的 Python 层。这部分会导出部分函数给用户侧,导出的函数类比 PyTorch 的 `torch/cuda` 部分。
+ - 包含 CPP 层的 *Runtime* 接口对应的 Python 层。这部分会导出部分函数给用户侧,导出的函数类比 PyTorch 的 `torch/cuda` 部分。
2. DIPU 采用 `monkey-patch` 的方式模拟了部分 PyTorch tensor 接口,让它们可以处理 DIPU 特殊的参数,该部分的设计还在优化中。
3. DIPU 拥有一定的模拟 CUDA 接口的能力。简单来说就是在 Python 层 用前面 DIPU 设备层的接口来替换 `torch.cuda` 的同名接口。
@@ -65,17 +65,17 @@ Aten 的能力主要依赖于 PyTorch 提供的注册自定义 *backend* 的能
### Dispatch 机制与 DIOPI 算子库
-PyTorch 的算子注册和分派有很多步骤,详见[参考文档](https://github.com/pytorch/pytorch/wiki/PyTorch-dispatcher-walkthrough)。
+PyTorch 的算子注册和分派有很多步骤,详见 [参考文档](https://github.com/pytorch/pytorch/wiki/PyTorch-dispatcher-walkthrough)。
-DIPU CPP 层适配的 ATen 算子对应的是分派过程中最底层(*backend* 层) 的算子或者 *composite* 层里等效为 *backend* 的算子。
+DIPU CPP 层适配的 ATen 算子对应的是分派过程中最底层(*backend* 层)的算子或者 *composite* 层里等效为 *backend* 的算子。
-这里面有一定的灵活性,以`Linear` 算子为例,在 PyTorch 的 `cpu/cuda` 设备上,它被实现为一个 `composite` 算子,实际的 *backend* 层算子是组合算子内部调用的 `addmm` 或者更底层的 `mm`。 而在 DIPU (`privateuse1`) 设备中,目前是注册了一个 `Linear` 算子(DIOPI 有这个算子)来替代组合算子,所以分派会直接走到新的 *backend* 层算子 `Linear`,而不会在调用原来的 `addmm/mm`。但是如果对应设备的 DIOPI 的 IMPL 算子库 没有实现 `diopiLinear` 而是实现了 `mm` 算子,也是可以正常走通 `Linear` 的调用流程的。
+这里面有一定的灵活性,以`Linear` 算子为例,在 PyTorch 的 `cpu/cuda` 设备上,它被实现为一个 `composite` 算子,实际的 *backend* 层算子是组合算子内部调用的 `addmm` 或者更底层的 `mm`。而在 DIPU (`privateuse1`) 设备中,目前是注册了一个 `Linear` 算子(DIOPI 有这个算子)来替代组合算子,所以分派会直接走到新的 *backend* 层算子 `Linear`,而不会在调用原来的 `addmm/mm`。但是如果对应设备的 DIOPI 的 IMPL 算子库 没有实现 `diopiLinear` 而是实现了 `mm` 算子,也是可以正常走通 `Linear` 的调用流程的。
### 无侵入式的 PyTorch 扩展包
-DIPU 没有直接修改 PyTorch 的代码,而是使用 out-of-tree 的方式接入新设备,详见[参考文档](https://pytorch.org/tutorials/advanced/extend_dispatcher.html)。
+DIPU 没有直接修改 PyTorch 的代码,而是使用 out-of-tree 的方式接入新设备,详见 [参考文档](https://pytorch.org/tutorials/advanced/extend_dispatcher.html)。
-PyTorch 要求 out-of-tree 的代码必须定义一个私有的 *Backend Key*,DIPU目前没有和 PyTorch 做官方的沟通,因此 PyTorch 主干里没有 `DIPU` 这个设备,目前是暂时借用 `PrivateUse1` 这个 Key(后续考虑改为借用 `XPU` 设备 Key,因为这个 Key 在 PyTorch 主干代码中有更好的支持)。
+PyTorch 要求 out-of-tree 的代码必须定义一个私有的 *Backend Key*,DIPU 目前没有和 PyTorch 做官方的沟通,因此 PyTorch 主干里没有 `DIPU` 这个设备,目前是暂时借用 `PrivateUse1` 这个 Key(后续考虑改为借用 `XPU` 设备 Key,因为这个 Key 在 PyTorch 主干代码中有更好的支持)。
基于用户私有的 *Backend Key* 和 `Dispatch Key`,PyTorch 会把算子调用请求分发到对应设备的算子实现。另外 `c10` 本身提供了一些注册能力,比如 `C10_REGISTER_GUARD_IMPL`,可以让用户把私有设备的 *Runtime* 代码注册到框架中。
@@ -83,7 +83,7 @@ PyTorch 要求 out-of-tree 的代码必须定义一个私有的 *Backend Key*,
### 算子适配能力
-为了更好的接入 DIOPI 算子,DIPU 提供了一组算子适配相关的辅助能力,比如灵活的算子 Fallback to CPU 的能力、算子精度自动对比的能力(对比 DIOPI 算子和 PyTorch 原生的 CPU 算子),算子执行过程中打印算子参数的能力。基于这些能力,接入算子时可以更方便排查算子精度等问题。 相关能力的具体说明参见 [Quick Start 文档](https://deeplink.readthedocs.io/zh-cn/latest/doc/DIPU/quick_start.html)的“算子库接入”章节。
+为了更好的接入 DIOPI 算子,DIPU 提供了一组算子适配相关的辅助能力,比如灵活的算子 Fallback to CPU 的能力、算子精度自动对比的能力(对比 DIOPI 算子和 PyTorch 原生的 CPU 算子),算子执行过程中打印算子参数的能力。基于这些能力,接入算子时可以更方便排查算子精度等问题。相关能力的具体说明参见 [Quick Start 文档](https://deeplink.readthedocs.io/zh-cn/latest/doc/DIPU/quick_start.html) 的“算子库接入”章节。
## 质量保障体系
@@ -94,7 +94,7 @@ PyTorch 要求 out-of-tree 的代码必须定义一个私有的 *Backend Key*,
2. 简单开发的手工测例。这部分测例更注重算子能否跑通,对算子要求较低。
3. 模型测试。我们开发了 `one_iter` 精度对比工具,会先在精度正确性没问题的设备(如 CPU 和 CUDA)上训练模型,保存每一层的算子输入、输出、权重、梯度数据,再在待测试设备上训练模型,逐层对比训练精度。
-> 更多信息请参考 [dipu/tests](./dipu/tests) 目录。
+> 更多信息请参考 [dipu/tests](./tests) 目录。
## Learn More
diff --git a/dipu/SupportedDiopiFunctions.txt b/dipu/SupportedDiopiFunctions.txt
index c7daf5d5d1..ee844acfc5 100644
--- a/dipu/SupportedDiopiFunctions.txt
+++ b/dipu/SupportedDiopiFunctions.txt
@@ -48,6 +48,8 @@ diopiCastDtype
diopiCat
diopiCdist
diopiCdistBackward
+diopiCeil
+diopiCeilInp
diopiClamp
diopiClampInp
diopiClampInpScalar
@@ -135,6 +137,10 @@ diopiLog2
diopiLog2Inp
diopiLogicalAnd
diopiLogicalAndInp
+diopiLogicalNot
+diopiLogicalNotInp
+diopiLogicalOr
+diopiLogicalOrInp
diopiLogInp
diopiLogSoftmax
diopiLogSoftmaxBackward
diff --git a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
index 5fc67a107d..0a2184a24c 100644
--- a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
+++ b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
@@ -118,7 +118,7 @@ def create_transform_input_to_cpu_code(fun_config):
for input in optional_tensor_list_inputs:
input_process_code += f"\nc10::List> {input}_cpu;\n"
input_process_code += f"for (int i = 0; i < {input}.size();++i)" + " {\n"
- input_process_code += f"\t{input}_cpu.push_back({input}[i].has_value() && {input}[i].value().defined() ? c10::make_optional({input}[i].value().cpu()) : {input}[i]);\n"
+ input_process_code += f" {input}_cpu.push_back({input}[i].has_value() && {input}[i].value().defined() ? c10::make_optional({input}[i].value().cpu()) : {input}[i]);\n"
input_process_code += "}\n"
outputs = re.findall('Tensor\([a-z]!\)[ ]+([\w\d_]+){1}', schema[:schema.find('->')])
@@ -151,7 +151,7 @@ def create_print_op_args_code(fun_config):
code += "if (dumpOpArgLevel() > 1) {\n"
for input in inputs:
input = input.strip()
- code += f'\tstd::cout << "\t{opname}:\t{input}:" << dumpArg({input}) << std::endl;\n'
+ code += f' std::cout << "\t{opname}:\t{input}:" << dumpArg({input}) << std::endl;\n'
code += "}"
return code
@@ -455,11 +455,11 @@ def create_result_compare_code(fun_config):
code = ''
if len(return_param) == 1 :
compare_code = f'_allclose(result_cpu, result_device)'
- code += f'std::cout << "autocompare:\t{op_name}\t{return_param[0]}:" << std::endl << "\t" << dumpArg(result_cpu) << std::endl << "\t" << dumpArg(result_device) << std::endl << "\t" << {compare_code} << std::endl;\n';
+ code += f'std::cout << "autocompare:\t{op_name}\t{return_param[0]}:" << std::endl << " " << dumpArg(result_cpu) << std::endl << " " << dumpArg(result_device) << std::endl << " " << {compare_code} << std::endl;\n';
elif len(return_param) > 1:
for i in range(len(return_param)):
compare_code = f'_allclose(std::get<{i}>(result_cpu), std::get<{i}>(result_device))'
- code += f'std::cout << "autocompare:\t{op_name}\t{return_param[i]}:" << std::endl << "\t" << dumpArg(std::get<{i}>(result_cpu)) << std::endl << "\t" << dumpArg(std::get<{i}>(result_device)) << std::endl << "\t" << {compare_code} << std::endl;\n';
+ code += f'std::cout << "autocompare:\t{op_name}\t{return_param[i]}:" << std::endl << " " << dumpArg(std::get<{i}>(result_cpu)) << std::endl << " " << dumpArg(std::get<{i}>(result_device)) << std::endl << " " << {compare_code} << std::endl;\n';
inputs = re.findall('Tensor +([\w\d_]+)', schema[:schema.find('->')])
inputs += re.findall('Tensor *\([a-z]!\) *\[ *\] +([\w\d_]+)', schema[:schema.find('->')])
@@ -474,8 +474,8 @@ def create_code_to_print_fun_call_info_from_schema(fun_config):
op_name = get_op_name_from_schema(fun_config['schema'])
diopi_func = fun_config.get('interface', '')
diopi_func = diopi_func[0 : diopi_func.find('(')]
- debug_code = "if (dumpOpArgLevel() > 0) {\n\t"
- debug_code += f'printf("--%-50s %-30s \\n", "[{op_name}]:", "{diopi_func}");' + '\n'
+ debug_code = "if (dumpOpArgLevel() > 0) {\n"
+ debug_code += f' printf("--%-50s %-30s \\n", "[{op_name}]:", "{diopi_func}");' + '\n'
debug_code += "}\n"
return debug_code
@@ -539,10 +539,10 @@ def create_device_check_code(fun_config):
for args in set(tensors):
if not args.endswith('?'):
- code += f'\tTORCH_CHECK(({args}.defined() == false) || ({args}.device().type() == dipu::DIPU_DEVICE_TYPE), __FILE__, ":", __LINE__, ": {op_name}: {args} should be on dipu");\n'
+ code += f' TORCH_CHECK(({args}.defined() == false) || ({args}.device().type() == dipu::DIPU_DEVICE_TYPE), __FILE__, ":", __LINE__, ": {op_name}: {args} should be on dipu");\n'
else:
args = args[0:-1]
- code += f'\tTORCH_CHECK(({args}.has_value() == false) || ({args}.value().defined() == false) || ({args}.value().device().type() == dipu::DIPU_DEVICE_TYPE), __FILE__, ":", __LINE__, "{op_name}: {args} should be on dipu");\n'
+ code += f' TORCH_CHECK(({args}.has_value() == false) || ({args}.value().defined() == false) || ({args}.value().device().type() == dipu::DIPU_DEVICE_TYPE), __FILE__, ":", __LINE__, "{op_name}: {args} should be on dipu");\n'
if len(tensors) > 0:
code += "}"
@@ -588,7 +588,9 @@ def functions_code_gen(fun_config):
if input.strip().endswith('?'):
input = input.replace('?', '')
input_process_code += f"\n::diopiConstTensorHandle_t {input}{diopi_tensor_suffix} = nullptr;\n"
- input_process_code += f"if ({input}.has_value() && {input}.value().defined()) {input}{diopi_tensor_suffix} = dipu::diopi_helper::toDiopiTensorHandle({input}.value());\n\n"
+ input_process_code += f"if ({input}.has_value() && {input}.value().defined())" + "{\n"
+ input_process_code += f" {input}{diopi_tensor_suffix} = dipu::diopi_helper::toDiopiTensorHandle({input}.value());\n"
+ input_process_code += "}\n"
else:
input_process_code += f"::diopiConstTensorHandle_t {input}{diopi_tensor_suffix} = dipu::diopi_helper::toDiopiTensorHandle({input});\n"
@@ -656,8 +658,10 @@ def functions_code_gen(fun_config):
return_code = f"return std::tie({params});"
custom_code_at_the_beginning = fun_config.get('custom_code_at_the_beginning', fun_config.get('custom_code', ''))
+ # normalize the trailing ';' so the custom code snippet ends with ";\n"
custom_code_at_the_beginning = re.sub(';\s*$', ';\n',custom_code_at_the_beginning)
+ interface_name = re.sub(R'.*::(.*?)\(.*', R'\1', diopi_fun_call_code)
fbody = fun_template.substitute(
comment=[fun_config['schema']],
cppsignautre=[create_cpp_signature_from_schema(fun_config['schema'])],
@@ -670,6 +674,7 @@ def functions_code_gen(fun_config):
diopi_fun_call_code=[diopi_fun_call_code],
custom_code_before_return=[fun_config.get('custom_code_before_return', '').replace('; ', ';\n')],
return_code=[return_code],
+ interface_name=[interface_name],
)
diopi_interface = fun_config.get('interface', create_call_diop_interface_code_from_schema(fun_config['schema']))
@@ -736,6 +741,7 @@ def parase_args():
import argparse
parser = argparse.ArgumentParser(description='autogen diopi wrapper code')
parser.add_argument('--config', type=str, default = 'diopi_functions.yaml', help='path to functions config file')
+ parser.add_argument('--convert_config', type=str, dest = "convert_config", default="", help="path to the convert_config.yaml")
parser.add_argument('--out', type=str, default = 'AutoGenedKernels.cpp', help='path to functions config file')
parser.add_argument('--dummy_call_diopi', default=False, type=boolean_string, help='whether to actually call the diopi interface')
parser.add_argument('--use_diopi_adapter', default=True, type=boolean_string, help='whether use diopi adapter')
@@ -755,7 +761,9 @@ def main():
file_data = diopi_functions_file.read()
funcs_config = yaml.load(file_data, Loader=yaml.FullLoader)
-
+ from op_memory_format_converter import OpMemoryFormatConverter
+ memory_format_converter = OpMemoryFormatConverter(args.convert_config)
+
functions_code = ''
op_register_code = ''
header_include_code = ''
@@ -773,6 +781,7 @@ def main():
mergeed_fun_config = dict(args.fun_config_dict)
mergeed_fun_config.update(vars(args))
mergeed_fun_config.update(fun_config)
+ # skip ops whose 'device' list does not include the current device
if 'device' in mergeed_fun_config:
current_device = mergeed_fun_config.get('current_device', '')
if current_device not in (mergeed_fun_config['device'] + ['all',]):
@@ -787,6 +796,10 @@ def main():
continue
fun_code, register_code = functions_code_gen(mergeed_fun_config)
+
+ # memory_format_converter replaces the preferred-memory-format placeholders in the generated code with the device's preferred format, as configured in its convert_config.yaml
+ fun_code = memory_format_converter.convert(fun_code, fun_config)
+
functions_code += fun_code
if mergeed_fun_config.get('register_op', True) in [True, "True"]:
if mergeed_fun_config.get('autograd', False) == True:
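The `OpMemoryFormatConverter` hook wired in above rewrites `${NAME:-default}` placeholders in the generated kernel source before it is emitted. `op_memory_format_converter.py` itself is not part of this diff, so the following is only a rough sketch of that substitution, under the assumption that the converter maps placeholder names to per-device values taken from `convert_config.yaml`:

```python
import re

# hypothetical per-device table; real values come from the vendor's convert_config.yaml
preferred = {"PREFERRED_MEMORY_FORMAT_PLACEHOLDER": "c10::MemoryFormat::ChannelsLast"}

def substitute_placeholders(code: str, table: dict) -> str:
    """Replace ${NAME:-default} with table[NAME], falling back to the default value."""
    return re.sub(r"\$\{(\w+):-([^}]*)\}",
                  lambda m: table.get(m.group(1), m.group(2)), code)

snippet = "at::Tensor out = at::empty(output_size, input.options(), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-input.suggest_memory_format()});"
print(substitute_placeholders(snippet, preferred))
# -> at::Tensor out = at::empty(output_size, input.options(), c10::MemoryFormat::ChannelsLast);
```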
diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 8812397c5a..242798a09d 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1,10 +1,10 @@
- schema: "exampleop.overloadname(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)"
autocompare: disable
- register_op: False # Whether generate registe code for this op, default value is True
+ register_op: False # Whether to generate register code for this op; default value is True
print_func_call_info: False # whether generate code that prints function call information
print_op_args: True # whether generate code that prints op args
- dummy_call_diopi: False # Does not generate code that actually calls the diopi function, defalut value is False
- custom_code_at_the_beginning: "/* Here can be a piece of c++ code at the begining*/"
+ dummy_call_diopi: False # Does not generate code that actually calls the diopi function, default value is False
+ custom_code_at_the_beginning: "/* Here can be a piece of c++ code at the beginning*/"
custom_code_before_call_diopi: |
std::cout << "self:" << self << std::endl;
std::cout << "other:" << other << std::endl;
@@ -36,15 +36,15 @@
- schema: "aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
custom_code_at_the_beginning: |
- if (other.numel() == 1) {
- return dipu_add_scalar_out(self, other.cpu().item(), alpha, out);
- } else if (self.numel() == 1) {
+ if (other.numel() == 1 && other.is_cpu()) {
+ return dipu_add_scalar_out(self, other.item(), alpha, out);
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
if (alpha.toDouble() == 1.0) {
- return dipu_add_scalar_out(other, self.cpu().item(), alpha, out);
- } else {
- dipu_fill__scalar(out, self.cpu().item());
- return dipu_add__tensor(out, other, alpha);
+ return dipu_add_scalar_out(other, self.item(), alpha, out);
}
+ dipu_fill__scalar(out, self.item());
+ return dipu_add__tensor(out, other, alpha);
}
interface: diopiAdd(ctx, out, self, other, alpha)
@@ -55,7 +55,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_sub_scalar_out(self, other.item(), alpha, out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
at::Tensor selfTensor = at::empty_like(other);
dipu_fill__scalar(selfTensor, self.item());
return dipu_sub_out(selfTensor, other, alpha, out);
@@ -94,7 +95,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_div_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_div_scalar_out(other, self.item(), out);
}
interface: diopiDiv(ctx, out, self, other, RoundModeNone)
@@ -108,7 +110,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_div_scalar_mode_out(self, other.item(), rounding_mode, out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_div_scalar_mode_out(other, self.item(), rounding_mode, out);
}
const auto mode = toDiopiRoundMode(rounding_mode.has_value() ? rounding_mode.value().data():"none");
@@ -135,7 +138,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_mul_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_mul_scalar_out(other, self.item(), out);
}
interface: diopiMul(ctx, out, self, other)
@@ -191,13 +195,27 @@
- schema: "aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))"
interface: diopiBatchNorm(ctx, out, save_mean, save_invstd, input, weight, bias, const_cast(running_mean), const_cast(running_var), training, momentum, eps);
+ custom_code_before_call_diopi: |
+ // NOTE: const_cast here is safe according to pytorch's source code
+ // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
+ custom_code_before_return: |
+ // NOLINTEND(cppcoreguidelines-pro-type-const-cast)
- schema: "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)"
custom_code_at_the_beginning: |
const int64_t dim_c = input.size(1);
- auto out0 = at::empty_like(input);
+ const auto input_shape = input.sizes();
+ const int axis = input_shape.size();
+ auto out0 = at::empty_like(input, input.options(), \
+ (axis==4?\
+ (c10::optional<c10::MemoryFormat>(${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt})):\
+ (axis==5?\
+ (c10::optional<c10::MemoryFormat>(${PREFERRED_MEMORY_FORMAT_PLACEHOLDER_3D:-c10::nullopt})):\
+ c10::optional<c10::MemoryFormat>(c10::nullopt))\
+ ));
auto options = input.options().dtype(at::kFloat);
- at::Tensor out1, out2;
+ at::Tensor out1;
+ at::Tensor out2;
if (!training) {
// do not require save_mean/save_invstd when in test mode
out1 = at::empty({0}, options);
@@ -207,12 +225,25 @@
out2 = at::empty({dim_c}, options);
}
interface: diopiBatchNorm(ctx, out0, out1, out2, input, weight, bias, const_cast(running_mean), const_cast(running_var), training, momentum, eps);
+ custom_code_before_call_diopi: |
+ // NOTE: const_cast here is safe according to pytorch's source code
+ // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
+ custom_code_before_return: |
+ // NOLINTEND(cppcoreguidelines-pro-type-const-cast)
- schema: "native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)"
custom_code_at_the_beginning: |
int64_t dim_c = input.size(1);
auto options = input.options().dtype(at::kFloat);
- at::Tensor out0 = at::empty_like(input);
+ const auto input_shape = input.sizes();
+ const int axis = input_shape.size();
+ at::Tensor out0 = at::empty_like(input, input.options(), \
+ (axis==4?\
+ (c10::optional<c10::MemoryFormat>(${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt})):\
+ (axis==5?\
+ (c10::optional<c10::MemoryFormat>(${PREFERRED_MEMORY_FORMAT_PLACEHOLDER_3D:-c10::nullopt})):\
+ c10::optional<c10::MemoryFormat>(c10::nullopt))\
+ ));
at::Tensor out1 = at::empty({dim_c}, options);
at::Tensor out2 = at::empty({dim_c}, options);
interface: diopiBatchNormBackward(ctx, out0, out1, out2, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps)
@@ -235,14 +266,22 @@
- schema: "native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor out, Tensor save_mean, Tensor save_invstd)"
custom_code_at_the_beginning: |
const auto input_shape = input.sizes();
- const int axis = input_shape.size() - normalized_shape.size();
+ const int axis = static_cast<int>(input_shape.size()) - static_cast<int>(normalized_shape.size());
const int64_t M = c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
std::vector stats_shape(input_shape.size(), 1);
std::copy(input_shape.begin(), input_shape.begin() + axis, stats_shape.begin());
auto options = input.options();
auto save_mean = at::empty(stats_shape, options);
auto save_invstd = at::empty(stats_shape, options);
- auto out = at::empty_like(input);
+ auto out = at::empty_like(
+ input,
+ c10::nullopt /* dtype */,
+ c10::nullopt /* layout */,
+ c10::nullopt /* device */,
+ c10::nullopt /* pin_memory */,
+ // maybe we don't want ChannelsLast -> Contiguous here, but just align with pytorch
+ // https://github.com/pytorch/pytorch/blob/v2.0.0/aten/src/ATen/native/cuda/layer_norm_kernel.cu#L1340-L1346
+ LEGACY_CONTIGUOUS_MEMORY_FORMAT);
interface: diopiLayerNorm(ctx, out, save_mean, save_invstd, input, weight, bias, normalized_shape, eps);
- schema: "native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)"
@@ -290,7 +329,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_eq_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_eq_scalar_out(other, self.item(), out);
}
interface: diopiEq(ctx, out, self, other)
@@ -312,7 +352,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_lt_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_lt_scalar_out(other, self.item(), out);
}
interface: diopiLt(ctx, out, self, other)
@@ -334,7 +375,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_ne_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_ne_scalar_out(other, self.item(), out);
}
interface: diopiNe(ctx, out, self, other)
@@ -356,7 +398,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_ge_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_ge_scalar_out(other, self.item(), out);
}
interface: diopiGe(ctx, out, self, other)
@@ -378,7 +421,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_gt_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_gt_scalar_out(other, self.item(), out);
}
interface: diopiGt(ctx, out, self, other)
@@ -400,7 +444,8 @@
custom_code_at_the_beginning: |
if (other.numel() == 1 && other.is_cpu()) {
return dipu_le_scalar_out(self, other.item(), out);
- } else if (self.numel() == 1 && self.is_cpu()) {
+ }
+ if (self.numel() == 1 && self.is_cpu()) {
return dipu_le_scalar_out(other, self.item(), out);
}
interface: diopiLe(ctx, out, self, other)
@@ -444,6 +489,7 @@
interface: diopiSum(ctx, out, self_dtype_diopi, diopi_size)
- schema: "addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
+ custom_fallback: True
custom_code_at_the_beginning: |
interface: diopiAddmm(&context, out, self, mat1, mat2, beta, alpha)
@@ -494,7 +540,7 @@
int64_t out_height = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1;
int64_t out_width = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1;
c10::SmallVector output_size = {batch_size, out_channel, out_height, out_width};
- at::Tensor out = at::empty(output_size, input.options());
+ at::Tensor out = at::empty(output_size, input.options(),${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-input.suggest_memory_format()});
interface: diopiConvolution2d(&context, out, input, weight, bias, stride, padding, dilation, groups)
- schema: "convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)"
@@ -506,10 +552,10 @@
at::Tensor grad_bias;
std::vector bias_sizes;
if (output_mask[0]) {
- grad_input = at::empty(input.sizes(), input.options());
+ grad_input = at::empty(input.sizes(), input.options(), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
}
if (output_mask[1]) {
- grad_weight = at::empty(weight.sizes(), weight.options().dtype(at::kFloat));
+ grad_weight = at::empty(weight.sizes(), weight.options().dtype(at::kFloat).memory_format(weight.suggest_memory_format()));
}
if (output_mask[2]) {
bias_sizes.push_back(grad_output.size(1));
@@ -526,7 +572,7 @@
at::Tensor grad_input;
at::Tensor grad_weight;
at::Tensor grad_bias;
- grad_input = at::empty(input.sizes(), input.options());
+ grad_input = at::empty(input.sizes(), input.options(), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
grad_weight = at::empty(weight.sizes(), weight.options().dtype(at::kFloat));
if (output_mask[2]) {
grad_bias = at::empty({grad_output.size(1)}, grad_output.options());
@@ -548,10 +594,10 @@
const int64_t w_out = (w_in - 1) * stride[1] - 2 * padding[1] + (dilation[1] * (kernel_width - 1) + 1) + output_padding[1];
const int64_t c_out = weight.size(1) * groups;
auto output_shape = input.sizes().size() == 3 ? std::vector{c_out, h_out, w_out} : std::vector{n, c_out, h_out, w_out};
- auto out = at::empty(output_shape, input.options());
+ auto out = at::empty(output_shape, input.options(), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
interface: diopiConvTranspose2d(ctx, out, input, weight, bias, stride, padding, output_padding, groups, dilation)
forward_process_code: |
- bool bias_has_value = (bias.has_value() == true) ? bias.value().requires_grad() : false;
+ bool bias_has_value = (bias.has_value()) ? bias.value().requires_grad() : false;
saved_data:
[
stride,
@@ -577,10 +623,7 @@
if (bias_has_value) {
bias_sizes.push_back(grad_output.size(1));
}
- std::array<bool, 3> output_mask;
- output_mask[0] = input.requires_grad();
- output_mask[1] = weight.requires_grad();
- output_mask[2] = bias_has_value;
+ std::array<bool, 3> output_mask = {input.requires_grad(), weight.requires_grad(), bias_has_value};
backward_schema: "convolution_transpose_backward(Tensor grad_output, Tensor input, Tensor weight, int[] bias_sizes, int[] stride, int[] padding, int[] dilation, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)"
backward_return_code: |
std::vector outputs = {
@@ -662,7 +705,9 @@
- schema: "topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)"
custom_code_at_the_beginning: |
std::vector output_size(self.sizes().begin(), self.sizes().end());
- dim = dim < 0 ? (dim + output_size.size()) : dim;
+ if (dim < 0) {
+ dim = dim + static_cast<int64_t>(output_size.size());
+ }
output_size[dim] = k;
auto values = at::empty(output_size, self.options());
auto indices = at::empty(output_size, self.options().dtype(at::kLong));
@@ -693,7 +738,9 @@
device: [all, -cuda]
custom_fallback: True
custom_code_at_the_beginning: |
- at::Tensor grad_input, grad_weight, grad_bias;
+ at::Tensor grad_input;
+ at::Tensor grad_weight;
+ at::Tensor grad_bias;
if (output_mask[0]) {
grad_input = at::empty(input.sizes(), grad_output.options());
}
@@ -706,6 +753,7 @@
interface: diopiLinearBackward(ctx, grad_input, grad_weight, grad_bias, grad_output, input, weight)
- schema: "linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"
+ custom_fallback: True
device: [all, -cuda]
custom_code_at_the_beginning: |
std::vector output_size(input.sizes().begin(), input.sizes().end());
@@ -850,15 +898,17 @@
- schema: "stack(Tensor[] tensors, int dim=0) -> Tensor"
custom_code_at_the_beginning: |
- dim += dim < 0 ? tensors[0].sizes().size()+1 : 0;
- auto num_tensors = tensors.size();
+ if (dim < 0) {
+ dim += static_cast<int64_t>(tensors[0].sizes().size()) + 1;
+ }
+ auto num_tensors = static_cast<int64_t>(tensors.size());
auto shape = tensors[0].sizes();
std::vector tmp;
for (int i = 0; i < dim; i++) {
tmp.push_back(shape[i]);
}
tmp.push_back(num_tensors);
- for (int i = dim; i < shape.size(); i++) {
+ for (int i = static_cast(dim); i < shape.size(); i++) {
tmp.push_back(shape[i]);
}
const std::vector& const_tmp = tmp;
@@ -873,28 +923,45 @@
- schema: "stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)"
custom_code_at_the_beginning: |
- dim += dim < 0 ? tensors[0].sizes().size() : 0;
+ if (dim < 0) {
+ dim += static_cast<int64_t>(tensors[0].sizes().size());
+ }
std::vector diopiTensorHandles(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
diopiTensorHandles[i] = dipu::diopi_helper::toDiopiTensorHandle(tensors.at(i));
}
- interface: diopiStack(ctx, out, diopiTensorHandles.data(), tensors.size(), dim)
+ interface: diopiStack(ctx, out, diopiTensorHandles.data(), static_cast<int64_t>(tensors.size()), dim)
- schema: "sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"
custom_code_at_the_beginning: |
- auto dim_ = dim < 0 ? (dim + self.sizes().size()) : dim;
+ int64_t dim_ = 0;
+ if (dim < 0) {
+ dim_ = dim + static_cast<int64_t>(self.sizes().size());
+ } else {
+ dim_ = dim;
+ }
auto values = at::empty(self.sizes(), self.options());
auto indices = at::empty(self.sizes(), self.options().dtype(at::kLong));
interface: diopiSort(ctx, values, indices, self, dim_, descending, nullptr)
- schema: "sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)"
custom_code_at_the_beginning: |
- auto dim_ = dim < 0 ? (dim + self.sizes().size()) : dim;
+ int64_t dim_ = 0;
+ if (dim < 0) {
+ dim_ = dim + static_cast<int64_t>(self.sizes().size());
+ } else {
+ dim_ = dim;
+ }
interface: diopiSort(ctx, values, indices, self, dim_, descending, nullptr)
- schema: "sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)"
custom_code_at_the_beginning: |
- auto dim_ = dim < 0 ? (dim + self.sizes().size()) : dim;
+ int64_t dim_ = 0;
+ if (dim < 0) {
+ dim_ = dim + static_cast<int64_t>(self.sizes().size());
+ } else {
+ dim_ = dim;
+ }
bool stable_ = stable.has_value() ? stable.value() : false;
const bool *p = &stable_;
interface: diopiSort(ctx, values, indices, self, dim_, descending, p)
@@ -1027,7 +1094,8 @@
- schema: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
custom_code_at_the_beginning: |
- const auto self_dtype = at::native::to(self, dtype);
+ auto promoted_dtype = at::native::get_dtype_from_self(self, dtype, /*promote_integers=*/true);
+ const auto self_dtype = at::native::to(self, promoted_dtype);
auto out = at::empty({}, self_dtype.options());
::diopiConstTensorHandle_t self_dtype_diopi = dipu::diopi_helper::toDiopiTensorHandle(self_dtype);
interface: diopiProd(ctx, out, self_dtype_diopi, nullptr)
@@ -1047,7 +1115,7 @@
}
const auto& self_sizes = self.sizes();
- for (int i = self_sizes.size() - 1, j = output_size.size() - 1;i >= 0;i--, j--) {
+ for (int i = static_cast<int>(self_sizes.size()) - 1, j = static_cast<int>(output_size.size()) - 1; i >= 0; i--, j--) {
output_size[j] *= self_sizes.at(i);
}
@@ -1057,15 +1125,20 @@
- schema: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
custom_code_at_the_beginning: |
auto out = at::empty_like(self);
+ // NOLINTNEXTLINE(readability-suspicious-call-argument)
return dipu_sub_out(other, self, alpha, out);
interface: diopiSub(ctx, out, other, self, alpha)
- schema: "unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor out, Tensor indices, Tensor counts)"
custom_code_at_the_beginning: |
- at::Tensor out, counts, indices;
+ at::Tensor out;
+ at::Tensor counts;
+ at::Tensor indices;
if (return_inverse) {
const auto ndims = self.sizes().size();
- dim += (dim < 0 ? ndims : 0);
+ if (dim < 0) {
+ dim += static_cast<int64_t>(ndims);
+ }
indices = at::empty({self.sizes().at(dim)}, self.options().dtype(at::kLong));
}
diopiTensorHandle_t out_ptr = nullptr;
@@ -1080,7 +1153,9 @@
- schema: "_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor out, Tensor indices, Tensor counts)"
custom_code_at_the_beginning: |
- at::Tensor out, counts, indices;
+ at::Tensor out;
+ at::Tensor counts;
+ at::Tensor indices;
if (return_inverse) {
indices = at::empty(self.sizes(), self.options().dtype(at::kLong));
}
@@ -1100,7 +1175,7 @@
std::transform(tensors.begin(), tensors.end(), diopiTensorHandles.begin(), [](const at::Tensor& tensor){
return dipu::diopi_helper::toDiopiTensorHandle(tensor);
});
- interface: diopiCat(ctx, out, diopiTensorHandles.data(), tensors.size(), dim);
+ interface: diopiCat(ctx, out, diopiTensorHandles.data(), static_cast<int64_t>(tensors.size()), dim);
- schema: "masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"
custom_code_at_the_beginning: |
@@ -1125,7 +1200,7 @@
- schema: "min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) min, Tensor(b!) min_indices)"
custom_code_at_the_beginning: |
- dim += ((dim >= 0) ? 0 : self.sizes().size());
+ dim += ((dim >= 0) ? 0 : static_cast<int64_t>(self.sizes().size()));
interface: diopiMin(ctx, min, min_indices, self, dim)
- schema: "max(Tensor self) -> Tensor"
@@ -1134,12 +1209,16 @@
interface: diopiMaxAll(ctx, out, self)
- schema: "maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
- no_device_check_args: [other]
- interface: diopiMaximum(ctx, out, self, other)
+ no_device_check_args: [self, other]
+ ins: [selfTemp, otherTemp]
+ custom_code_at_the_beginning: |
+ auto selfTemp = (self.numel() == 1 && self.is_cpu()) ? self.to(other.device()) : self;
+ auto otherTemp = (other.numel() == 1 && other.is_cpu()) ? other.to(self.device()) : other;
+ interface: diopiMaximum(ctx, out, selfTemp, otherTemp)
- schema: "max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_indices) -> (Tensor(a!) max, Tensor(b!) max_indices)"
custom_code_at_the_beginning: |
- dim += ((dim >= 0) ? 0 : self.sizes().size());
+ dim += ((dim >= 0) ? 0 : static_cast<int64_t>(self.sizes().size()));
if (max_indices.numel() <= 0) {
auto output_size = self.sizes().vec();
if (keepdim) {
@@ -1261,12 +1340,28 @@
custom_code_at_the_beginning: |
std::vector<int64_t> size(2);
custom_code_before_call_diopi: |
- if (output_size.size() > 0) {
+ if (!output_size.empty()) {
std::copy(output_sizeVector.begin(), output_sizeVector.end(), size.begin());
+ } else {
+ size[0] = std::floor(static_cast<double>(self.size(-2)) * scales_h.value_or(1.0));
+ size[1] = std::floor(static_cast<double>(self.size(-1)) * scales_w.value_or(1.0));
+ }
+ interface: diopiUpsampleNearest(ctx, out, self, size);
+
+- schema: "upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor"
+ size_attr: [size]
+ custom_code_at_the_beginning: |
+ std::vector<int64_t> size(2);
+ if (output_size.size() > 0) {
+ std::vector<int64_t> tmpVector(output_size.size());
+ auto symIntToInt = [](const c10::SymInt& t)-> int64_t {return t.expect_int();};
+ std::transform(output_size.cbegin(), output_size.cend(), tmpVector.begin(), symIntToInt);
+ std::copy(tmpVector.begin(), tmpVector.end(), size.begin());
} else {
size[0] = std::floor(self.size(-2) * scales_h.value_or(1.0));
size[1] = std::floor(self.size(-1) * scales_w.value_or(1.0));
}
+ auto out = at::empty({self.size(0),self.size(1),size[0],size[1]},self.options(),${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
interface: diopiUpsampleNearest(ctx, out, self, size);
- schema: "upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)"
@@ -1274,12 +1369,29 @@
custom_code_at_the_beginning: |
std::vector<int64_t> size(2);
custom_code_before_call_diopi: |
- if (output_size.size() > 0) {
+ if (!output_size.empty()) {
std::copy(output_sizeVector.begin(), output_sizeVector.end(), size.begin());
+ } else {
+ size[0] = std::floor(static_cast<double>(self.size(-2)) * scales_h.value_or(1.0));
+ size[1] = std::floor(static_cast<double>(self.size(-1)) * scales_w.value_or(1.0));
+ }
+ const char* mode = "bilinear";
+ interface: diopiUpsampleLinear(ctx, out, self, size, align_corners, mode);
+
+- schema: "upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"
+ size_attr: [size]
+ custom_code_at_the_beginning: |
+ std::vector<int64_t> size(2);
+ if (output_size.size() > 0) {
+ std::vector<int64_t> tmpVector(output_size.size());
+ auto symIntToInt = [](const c10::SymInt& t)-> int64_t {return t.expect_int();};
+ std::transform(output_size.cbegin(), output_size.cend(), tmpVector.begin(), symIntToInt);
+ std::copy(tmpVector.begin(), tmpVector.end(), size.begin());
} else {
size[0] = std::floor(self.size(-2) * scales_h.value_or(1.0));
size[1] = std::floor(self.size(-1) * scales_w.value_or(1.0));
}
+ auto out = at::empty({self.size(0),self.size(1),size[0],size[1]},self.options(),${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
const char* mode = "bilinear";
interface: diopiUpsampleLinear(ctx, out, self, size, align_corners, mode);
@@ -1287,6 +1399,23 @@
size_attr: [size]
custom_code_at_the_beginning: |
std::vector<int64_t> size(2);
+ custom_code_before_call_diopi: |
+ if (!output_size.empty()) {
+ std::copy(output_sizeVector.begin(), output_sizeVector.end(), size.begin());
+ } else {
+ size[0] = std::floor(static_cast<double>(*(input_sizeVector.rbegin() + 1)) * scales_h.value_or(1.0));
+ size[1] = std::floor(static_cast<double>(*(input_sizeVector.rbegin())) * scales_w.value_or(1.0));
+ }
+ interface: diopiUpsampleNearestBackward(ctx, grad_input, grad_output, size, input_size)
+
+- schema: "upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor grad_input"
+ size_attr: [size]
+ custom_code_at_the_beginning: |
+ std::vector<int64_t> size(2);
+ auto symInt2Int = [](const c10::SymInt& t)-> int64_t {return t.expect_int();};
+ std::vector<int64_t> grad_input_shape(input_size.size());
+ std::transform(input_size.cbegin(), input_size.cend(), grad_input_shape.begin(), symInt2Int);
+ auto grad_input = at::empty(grad_input_shape,grad_output.options(),${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
custom_code_before_call_diopi: |
if (output_size.size() > 0) {
std::copy(output_sizeVector.begin(), output_sizeVector.end(), size.begin());
@@ -1300,6 +1429,24 @@
size_attr: [size]
custom_code_at_the_beginning: |
std::vector<int64_t> size(2);
+ custom_code_before_call_diopi: |
+ if (!output_size.empty()) {
+ std::copy(output_sizeVector.begin(), output_sizeVector.end(), size.begin());
+ } else {
+ size[0] = std::floor(static_cast<double>(*(input_sizeVector.rbegin() + 1)) * scales_h.value_or(1.0));
+ size[1] = std::floor(static_cast<double>(*(input_sizeVector.rbegin())) * scales_w.value_or(1.0));
+ }
+ const char* mode = "bilinear";
+ interface: diopiUpsampleLinearBackward(ctx, grad_input, grad_output, size, input_size, align_corners, mode)
+
+- schema: "upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor grad_input"
+ size_attr: [size]
+ custom_code_at_the_beginning: |
+ std::vector<int64_t> size(2);
+ auto symInt2Int = [](const c10::SymInt& t)-> int64_t {return t.expect_int();};
+ std::vector<int64_t> grad_input_shape(input_size.size());
+ std::transform(input_size.cbegin(), input_size.cend(), grad_input_shape.begin(), symInt2Int);
+ auto grad_input = at::empty(grad_input_shape,grad_output.options(),${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
custom_code_before_call_diopi: |
if (output_size.size() > 0) {
std::copy(output_sizeVector.begin(), output_sizeVector.end(), size.begin());
@@ -1333,6 +1480,7 @@
interface: diopiCosInp(ctx, self)
- schema: "bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"
+ custom_fallback: True
interface: diopiBmm(ctx, out, self, mat2)
- schema: "silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
@@ -1346,7 +1494,13 @@
autocompare: disable
interface: diopiNormalInp(ctx, self, mean, std, generator)
+- schema: "mm(Tensor self, Tensor mat2) -> Tensor"
+ custom_code_at_the_beginning: |
+ auto out = nodispatch::empty({self.sizes()[0], mat2.sizes()[1]}, self.options());
+ interface: diopiMm(ctx, out, self, mat2)
+
- schema: "mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"
+ custom_fallback: True
interface: diopiMm(ctx, out, self, mat2)
- schema: "matmul(Tensor self, Tensor other) -> Tensor"
@@ -1414,7 +1568,7 @@
custom_code_at_the_beginning: |
auto shape = self.sizes();
std::vector output_shape(shape.begin(), shape.end());
- dim += dim >= 0 ? 0 : shape.size();
+ dim += dim >= 0 ? 0 : static_cast<int64_t>(shape.size());
output_shape[dim] = index.numel();
auto out = at::empty({output_shape}, self.options());
interface: diopiIndexSelect(ctx, out, self, dim, index)
@@ -1523,7 +1677,35 @@
at::Tensor neg_log_likelihood = at::empty({batch_size}, options);
at::Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1}, options);
backward_return_code: |
- std::vector<at::Tensor> outputs(7);
+ /* Note: This kernel's output size will be checked by pytorch/torch/csrc/autograd/custom_function.h
+ *
+ * ''' custom_function.h
+ * auto num_outputs = static_cast(outputs.size());
+ * // Returning too many results is ok, but only as long as they're all
+ * // undefined. Truncate the result vector in that case.
+ * if (num_outputs > num_forward_inputs) {
+ * bool all_undef = true;
+ * for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
+ * all_undef &= (!outputs[i].defined());
+ * }
+ * if (all_undef) {
+ * outputs.resize(num_forward_inputs);
+ * num_outputs = num_forward_inputs;
+ * }
+ * }
+ *
+ * if (num_outputs != num_forward_inputs) {
+ * std::string msg("function ");
+ * msg += name() + " returned an incorrect number of gradients (expected ";
+ * msg += c10::to_string(num_forward_inputs) + ", got ";
+ * msg += c10::to_string(num_outputs) + ")";
+ * throw std::runtime_error(msg);
+ * }
+ * '''
+ */
+
+ constexpr int kSameAsInputSize = 7;
+ std::vector<at::Tensor> outputs(kSameAsInputSize);
outputs[0] = result;
return outputs;
@@ -1606,7 +1788,35 @@
at::Tensor neg_log_likelihood = at::empty({batch_size}, options);
at::Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1}, options);
backward_return_code: |
- std::vector<at::Tensor> outputs(7);
+ /* Note: This kernel's output size will be checked by pytorch/torch/csrc/autograd/custom_function.h
+ *
+ * ''' custom_function.h
+ * auto num_outputs = static_cast(outputs.size());
+ * // Returning too many results is ok, but only as long as they're all
+ * // undefined. Truncate the result vector in that case.
+ * if (num_outputs > num_forward_inputs) {
+ * bool all_undef = true;
+ * for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
+ * all_undef &= (!outputs[i].defined());
+ * }
+ * if (all_undef) {
+ * outputs.resize(num_forward_inputs);
+ * num_outputs = num_forward_inputs;
+ * }
+ * }
+ *
+ * if (num_outputs != num_forward_inputs) {
+ * std::string msg("function ");
+ * msg += name() + " returned an incorrect number of gradients (expected ";
+ * msg += c10::to_string(num_forward_inputs) + ", got ";
+ * msg += c10::to_string(num_outputs) + ")";
+ * throw std::runtime_error(msg);
+ * }
+ * '''
+ */
+
+ constexpr int kSameAsInputSize = 7;
+ std::vector<at::Tensor> outputs(kSameAsInputSize);
outputs[0] = result;
return outputs;
@@ -1679,7 +1889,12 @@
interface: diopiClampMaxInp(ctx, self, max)
- schema: "minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
- interface: diopiMinimum(ctx,out, self, other)
+ no_device_check_args: [self, other]
+ ins: [selfTemp, otherTemp]
+ custom_code_at_the_beginning: |
+ auto selfTemp = (self.numel() == 1 && self.is_cpu()) ? self.to(other.device()) : self;
+ auto otherTemp = (other.numel() == 1 && other.is_cpu()) ? other.to(self.device()) : other;
+ interface: diopiMinimum(ctx, out, selfTemp, otherTemp)
- schema: "scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)"
interface: diopiScatterScalar(ctx, out, self, dim, value, index, "")
@@ -1746,7 +1961,7 @@
indices_tensor_vec[i] = (indices[i].has_value() && indices[i].value().defined()) ? indices[i].value().to(self.device()) : at::Tensor();
indices_vec[i] = diopi_helper::toDiopiTensorHandle(indices_tensor_vec[i]);
}
- interface: diopiIndex(ctx, &out_ptr, self, indices_vec.data(), indices_vec.size())
+ interface: diopiIndex(ctx, &out_ptr, self, indices_vec.data(), static_cast<int64_t>(indices_vec.size()))
custom_code_before_return: |
dipu::getCurrentDIPUStream().synchronize();
out = *reinterpret_cast(out_ptr);
@@ -1760,7 +1975,7 @@
indices_tensor_vec[i] = (indices[i].has_value() && indices[i].value().defined()) ? indices[i].value().to(self.device()) : at::Tensor();
indices_vec[i] = diopi_helper::toDiopiTensorHandle(indices_tensor_vec[i]);
}
- interface: diopiIndexPut(ctx, self, self, values, indices_vec.data(), indices_vec.size(), accumulate)
+ interface: diopiIndexPut(ctx, self, self, values, indices_vec.data(), static_cast<int64_t>(indices_vec.size()), accumulate)
- schema: "_cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor"
custom_code_at_the_beginning: |
@@ -1823,15 +2038,15 @@
int num_blocks = 1;
for(int i = 0; i < 2; i++){
- num_blocks *= int((input_shape[i + 2] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) / stride[i]) + 1;
+ num_blocks *= static_cast<int>((input_shape[i + 2] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) / stride[i]) + 1;
}
- int channels = input_shape[1];
+ int channels = static_cast<int>(input_shape[1]);
for(int i = 0; i < 2; i++){
- channels *= kernel_size[i];
+ channels *= static_cast<int>(kernel_size[i]);
}
std::vector<int64_t> out_shape({channels, num_blocks});
- if(batched_input == true){
+ if(batched_input){
out_shape.insert(out_shape.begin(), input_shape[0]);
}
auto out = at::empty({out_shape}, self.options());
@@ -1847,13 +2062,13 @@
input_shape.insert(input_shape.begin(), 1);
}
- int channels = input_shape[1];
+ int channels = static_cast<int>(input_shape[1]);
for(int i = 0; i < 2; i++){
- channels = channels / kernel_size[i];
+ channels = channels / static_cast<int>(kernel_size[i]);
}
std::vector<int64_t> out_shape({channels, output_size.at(0).expect_int(), output_size.at(1).expect_int()});
- if(batched_input == true){
+ if(batched_input){
out_shape.insert(out_shape.begin(), input_shape[0]);
}
auto out = at::empty({out_shape}, self.options());
@@ -1898,7 +2113,12 @@
auto shape = input.size(1);
auto out0 = at::empty({shape}, input.options().dtype(at::kFloat));
auto out1 = at::empty({shape}, input.options().dtype(at::kFloat));
- interface: diopiBatchNormGatherStatsWithCounts(ctx, out0, out1, input, mean, invstd, const_cast(running_mean), const_cast(running_var), momentum, eps, counts)
+ interface: diopiBatchNormGatherStatsWithCounts(ctx, out0, out1, input, mean, invstd, const_cast(running_mean), const_cast(running_var), static_cast(momentum), static_cast(eps), counts)
+ custom_code_before_call_diopi: |
+ // NOTE: const_cast here is safe according to pytorch's source code
+ // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
+ custom_code_before_return: |
+ // NOLINTEND(cppcoreguidelines-pro-type-const-cast)
- schema: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)
custom_code_at_the_beginning: |
@@ -1908,8 +2128,8 @@
at::Tensor out2;
at::Tensor out3;
if(input_g){
- out0 = at::empty({shape}, input.options().dtype(at::kFloat));
- out1 = at::empty({shape}, input.options().dtype(at::kFloat));
+ out0 = at::empty({shape}, input.options().dtype(at::kFloat), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
+ out1 = at::empty({shape}, input.options().dtype(at::kFloat), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
}
if(weight_g){
out2 = at::empty({shape}, input.options().dtype(at::kFloat));
@@ -1921,13 +2141,13 @@
- schema: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor
custom_code_at_the_beginning: |
- auto out = at::empty_like(grad_out);
+ auto out = at::empty_like(grad_out, grad_out.options(), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
interface: diopiBatchNormBackwardElemt(ctx, out, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
- schema: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
custom_code_at_the_beginning: |
- auto out = at::empty_like(input);
- interface: diopiBatchNormElemt(ctx, out, input, weight, bias, mean, invstd, eps);
+ auto out = at::empty_like(input, input.options(), ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});
+ interface: diopiBatchNormElemt(ctx, out, input, weight, bias, mean, invstd, static_cast<float>(eps));
- schema: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
interface: diopiSmoothL1Loss(ctx, out, self, target, static_cast<diopiReduction_t>(reduction), static_cast<float>(beta));
@@ -2134,7 +2354,7 @@
auto selfVec = self.vec();
auto scalarsCpu = scalars.cpu();
for (size_t i = 0;i < self.size();i++) {
- dipu_addcmul_(selfVec[i], tensor1[i], tensor2[i], scalarsCpu[i].item());
+ dipu_addcmul_(selfVec[i], tensor1[i], tensor2[i], scalarsCpu[static_cast<int64_t>(i)].item());
}
return;
interface: diopiAddcmulInp(ctx, self, tensor1, tensor2, scalars)
@@ -2165,7 +2385,7 @@
auto selfVec = self.vec();
auto scalarsCpu = scalars.cpu();
for (size_t i = 0;i < self.size();i++) {
- dipu_addcdiv_(selfVec[i], tensor1[i], tensor2[i], scalarsCpu[i].item());
+ dipu_addcdiv_(selfVec[i], tensor1[i], tensor2[i], scalarsCpu[static_cast<int64_t>(i)].item());
}
return;
interface: diopiAddcdivInp(ctx, self, tensor1, tensor2, scalars)
@@ -2215,7 +2435,7 @@
return out;
interface: diopiNorm(ctx, out, self, p, dimDiopiSize);
-# wrap_diopi_cast_dtype has no corresponding aten op and not registed, it's just a diopi func wrapper.
+# wrap_diopi_cast_dtype has no corresponding aten op and not registered, it's just a diopi func wrapper.
# use this tricky method to support call multiple diopi-op in one aten-op
- schema: "wrap_diopi_cast_dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)"
register_op: False
@@ -2231,9 +2451,10 @@
# this copy_ aten op may use both diopiCastDtype and diopiCopyInp. it's a proxy/composite op
- schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+ autocompare: disable
dummy_call_diopi: True
custom_fallback: True
- device: [cuda, camb, ascend, droplet, supa]
+ device: [cuda, camb, ascend, droplet, supa, kunlunxin]
custom_code_at_the_beginning: |
dipu::getDipuCopyInstance()->run(self, src, non_blocking);
return self;
@@ -2242,6 +2463,7 @@
# vendor who has no fully implemented diopi and proper fallback DIPUCopy sub-class
- schema: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+ autocompare: disable
custom_fallback: True
dummy_call_diopi: True
custom_code_at_the_beginning: |
@@ -2250,15 +2472,20 @@
interface: diopiCopyInp(ctx, src, self)
- schema: _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, Tensor(b!) found_inf, Tensor inv_scale) -> void
+ autocompare: disable
custom_fallback: True
custom_code_at_the_beginning: |
std::vector<diopiTensorHandle_t> diopiTensorHandles(self.size(), nullptr);
+ // NOTE: const_cast here is safe according to pytorch's source code
+ // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
std::transform(self.begin(), self.end(), diopiTensorHandles.begin(), [](const at::Tensor& t){
return dipu::diopi_helper::toDiopiTensorHandle(const_cast<at::Tensor&>(t));
});
- interface: diopiAmpForeachNonFiniteCheckAndUnscaleInp(ctx, diopiTensorHandles.data(), self.size(), found_inf, inv_scale)
+ // NOLINTEND(cppcoreguidelines-pro-type-const-cast)
+ interface: diopiAmpForeachNonFiniteCheckAndUnscaleInp(ctx, diopiTensorHandles.data(), static_cast<int64_t>(self.size()), found_inf, inv_scale)
+ # TODO(someone): fix this issue when `autocompare` is on
autocompare: disable
- schema: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
custom_fallback: True
- interface: diopiAmpUpdateScaleInp(ctx, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval)
+ interface: diopiAmpUpdateScaleInp(ctx, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, static_cast(growth_interval))
diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
index 7eda79b15c..1f4536cdd9 100644
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
@@ -1,44 +1,92 @@
# Copyright (c) 2023, DeepLink.
diopi_wrapper_file_template_content = \
-"""
-// autogened file
-#include
-#include
+"""// autogened file
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
#include
-#include
-
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
#include
+#include
#include
-#include "csrc_dipu/aten/DIPUATenFunctions.h"
+
+#include
+#include
+
#include "csrc_dipu/aten/RegisterDIPU.hpp"
+#include "csrc_dipu/aten/ops/DIPUCopy.hpp"
+#include "csrc_dipu/aten/ops/NodispatchUtils.hpp"
+#include "csrc_dipu/aten/ops/OpUtils.hpp"
+#include "csrc_dipu/base/basedef.h"
#include "csrc_dipu/diopirt/diopirt_impl.h"
#include "csrc_dipu/profiler/profiler.h"
-#include
+#include "csrc_dipu/runtime/core/DIPUGeneratorImpl.h"
+#include "csrc_dipu/runtime/core/DIPUStream.h"
+
#include "CustomFallbackFunctions.hpp"
-#include "csrc_dipu/aten/ops/DIPUCopy.hpp"
$header_include_code
-namespace dipu::native {
+// NOTE: Some kernels (e.g. _foreach_add_.List) have custom codes at the
+// beginning ending with early return. This is a workaround intended to skip
+// some of the autogened codes (e.g. type cast, calling DIOPI, etc.).
+//
+// NOLINTBEGIN(readability-redundant-control-flow)
-using dipu::diopi_helper::toDiopiGeneratorHandle;
+namespace dipu {
+namespace native {
-using namespace dipu::diopi_helper;
+using dipu::diopi_helper::toDiopiGeneratorHandle;
+using dipu::diopi_helper::toDiopiSize;
+using dipu::diopi_helper::toDiopiRoundMode;
$functions_code
+} // namespace native
+} // namespace dipu
-} // namespace dipu::native
+// NOLINTEND(readability-redundant-control-flow)
namespace at {
DIPU_LIBRARY_IMPL(aten, DIPU_DEVICE_TYPE_MACRO, m) {
- $op_register_code
+ $op_register_code
}
DIPU_LIBRARY_IMPL(aten, DIPU_AUTOGRAD_DEVICE_TYPE_MACRO, m) {
- $autograd_op_register_code
+ $autograd_op_register_code
}
} // namespace at
@@ -49,34 +97,32 @@
"""
// $comment
$cppsignautre {
- dipu::profile::RecordBlockCreator _(__FUNCTION__);
- $custom_code_at_the_beginning
+ dipu::profile::RecordBlockCreator _(__FUNCTION__);
+ $custom_code_at_the_beginning
- ::diopiContext context(dipu::getCurrentDIPUStream().rawstream());
- auto ctx = &context;
+ ::diopiContext context(dipu::getCurrentDIPUStream().rawstream());
+ auto ctx = &context;
- $input_process_code
+ $input_process_code
- $output_process_code
+ $output_process_code
- $attrs_process_code
+ $attrs_process_code
- $device_check_code
+ $device_check_code
- $custom_code_before_call_diopi
+ $custom_code_before_call_diopi
- dipu::profile::RecordBlockCreator dipuRecorder(R"($diopi_fun_call_code)");
- ::diopiError_t ret = $diopi_fun_call_code
- dipuRecorder.end();
- if (checkDiopiReturnValue()) {
- TORCH_CHECK(ret == ::diopiSuccess, __FILE__, ":", __LINE__, R"($diopi_fun_call_code)", " error, error code is ", ret, "error message is ", diopiGetLastErrorString());
- }
+ dipu::profile::RecordBlockCreator dipuRecorder(R"($interface_name)");
+ ::diopiError_t ret = $diopi_fun_call_code
+ dipuRecorder.end();
+ TORCH_CHECK(ret == ::diopiSuccess, __FILE__, ":", __LINE__, R"($diopi_fun_call_code)", " error, error code is ", ret, "error message is ", diopiGetLastErrorString());
- $custom_code_before_return
+ $custom_code_before_return
- synchronizeIfEnable();
+ synchronizeIfEnable();
- $return_code
+ $return_code
}
"""
@@ -94,29 +140,29 @@
"""
class $autograd_function_name : public torch::autograd::Function<$autograd_function_name> {
public:
- static $return_code forward(torch::autograd::AutogradContext *ctx, $param_list) {
- $forward_process_code
+ static $return_code forward(torch::autograd::AutogradContext *ctx, $param_list) {
+ $forward_process_code
- $save_for_backward_code
+ $save_for_backward_code
- at::AutoDispatchBelowADInplaceOrView g;
- return $call_forward_impl_code;
- }
+ at::AutoDispatchBelowADInplaceOrView g;
+ return $call_forward_impl_code;
+ }
static std::vector<at::Tensor> backward(torch::autograd::AutogradContext *ctx, std::vector<at::Tensor> grad_outputs) {
- $load_saved_data_code
+ $load_saved_data_code
- $cal_grad_code
+ $cal_grad_code
- $call_backward_impl_code
+ $call_backward_impl_code
- $backward_return_code
+ $backward_return_code
}
};
$cppsignautre {
- auto result = $autograd_function_name::apply($arg_name_list);
- $wrappter_custom_return
+ auto result = $autograd_function_name::apply($arg_name_list);
+ $wrappter_custom_return
}
"""
@@ -125,15 +171,15 @@ class $autograd_function_name : public torch::autograd::Function<$autograd_funct
"""
// $comment
$cppsignautre {
- std::cout << std::endl << __FUNCTION__ << std::endl;
- $transform_input_to_cpu_code
+ std::cout << std::endl << __FUNCTION__ << std::endl;
+ $transform_input_to_cpu_code
- $execute_op_on_cpu_code
+ $execute_op_on_cpu_code
- $execute_op_on_device_code
+ $execute_op_on_device_code
- $transform_result_to_cpu_code
+ $transform_result_to_cpu_code
- $result_compare_code
+ $result_compare_code
}
"""
diff --git a/dipu/scripts/autogen_diopi_wrapper/op_memory_format_converter.py b/dipu/scripts/autogen_diopi_wrapper/op_memory_format_converter.py
new file mode 100644
index 0000000000..80a8fccb4d
--- /dev/null
+++ b/dipu/scripts/autogen_diopi_wrapper/op_memory_format_converter.py
@@ -0,0 +1,115 @@
+import os
+import re
+import yaml
+
+accepted_interface = "ALL"
+
+class OpMemoryFormatConverter(object):
+ # The converter class: rewrites memory-format placeholders in generated code based on the loaded convert_config.yaml.
+ def __init__(self, convert_config):
+ assert(isinstance(convert_config, str))
+ if convert_config and len(convert_config):
+ with open(convert_config) as convert_config_yaml_file:
+ file_data = convert_config_yaml_file.read()
+ self.convert_config_yaml = yaml.load(file_data, Loader=yaml.FullLoader)
+ self.convert_config = ConvertConfig(self.convert_config_yaml)
+ else:
+ self.convert_config_yaml = list()
+ self.convert_config = ConvertConfig(self.convert_config_yaml)
+
+ def convert(self,custom_code,fun_config):
+ if "interface" in fun_config and (accepted_interface == "ALL" or (fun_config['interface'] in accepted_interface)):
+ return self.do_convert(custom_code,fun_config)
+ else:
+ return custom_code
+
+ def do_convert(self,custom_code,fun_config):
+ # Do the conversion
+ def choose_default(matched):
+ value = str(matched.group("default"))
+ return value
+
+ def choose_channelsLast3d(matched):
+ return "at::MemoryFormat::ChannelsLast3d"
+
+ def choose_channelsLast(matched):
+ return "at::MemoryFormat::ChannelsLast"
+
+ def choose_contiguous(matched):
+ return "at::MemoryFormat::Contiguous"
+
+ def choose_preserve(matched):
+ return "at::MemoryFormat::Preserve"
+
+ interface = fun_config["interface"]
+ custom_code = custom_code.split("\n")
+ memory_format = self.convert_config.interface2memoryformat(interface)
+ custom_code_new = list()
+ # match strings like "${PREFERRED_MEMORY_FORMAT_PLACEHOLDER_3D:-}"
+ placeholder_3d_pattern = "\$\{PREFERRED_MEMORY_FORMAT_PLACEHOLDER_3D:-(?P<default>.*)\}"
+ # match strings like "${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-}"
+ placeholder_pattern = "\$\{PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-(?P<default>.*)\}"
+ for line in custom_code:
+ if memory_format == "channellast":
+ line = re.sub(placeholder_3d_pattern, choose_channelsLast3d, line)
+ line = re.sub(placeholder_pattern, choose_channelsLast, line)
+ elif memory_format == "contiguous":
+ line = re.sub(placeholder_3d_pattern, choose_contiguous, line)
+ line = re.sub(placeholder_pattern, choose_contiguous, line)
+ elif memory_format == "preserve":
+ line = re.sub(placeholder_3d_pattern, choose_preserve, line)
+ line = re.sub(placeholder_pattern, choose_preserve, line)
+ elif memory_format == "empty":
+ line = re.sub(placeholder_3d_pattern, choose_default, line)
+ line = re.sub(placeholder_pattern, choose_default, line)
+ else:
+ print("UNABLE TO RECOGNIZE MEMORY FORMAT!!!")
+ custom_code_new.append(line)
+ custom_code = "\n".join(custom_code_new)
+ return custom_code
+
+class ConvertConfig(object):
+ #This class is used to load and parse the convert_config.yaml
+ def __init__(self, config_yaml):
+ self.convert_dict = dict()
+ self.convert_config_yaml = config_yaml
+ self.default_layout = "empty"
+ assert(isinstance(config_yaml, list))
+ for config in config_yaml:
+ assert(isinstance(config,dict))
+ for interface in config.keys():
+ if interface == "common_config":
+ detail = config[interface]
+ assert(isinstance(detail, dict))
+ if "layout" in detail:
+ self.default_layout = self.layout2memoryformat(detail["layout"])
+ pass
+ # may add common behavior
+ for interface in config.keys():
+ if interface != "common_config":
+ self.convert_dict.setdefault(interface,dict())
+ detail = config[interface]
+ assert(isinstance(detail, dict))
+ if "layout" in detail:
+ self.convert_dict[interface]["layout"] = self.layout2memoryformat(detail["layout"])
+
+ def layout2memoryformat(self, layout):
+ # used when parsing convert_config.yaml; returns the memory format based on NCHW/NHWC and other layouts.
+ assert(isinstance(layout, str))
+ if "NCHW" in layout:
+ return "contiguous"
+ if "NLC" in layout:
+ return "channellast"
+ if "NHWC" in layout:
+ return "channellast"
+ if "NDHWC" in layout:
+ return "channellast"
+ return "preserve"
+
+ def interface2memoryformat(self, interface):
+ # return the preferred memory format based on the DIOPI interface.
+ interface_stripped = interface.strip().split("(")[0]
+ if (interface_stripped not in self.convert_dict) or ("layout" not in self.convert_dict[interface_stripped]):
+ return self.default_layout
+ else:
+ return self.convert_dict[interface_stripped]["layout"]
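For illustration (not part of the patch): a minimal sketch of how this converter is expected to be used. The import path and the convert_config.yaml contents below are assumptions inferred from the parser above, not files shipped in this change.

```python
# A minimal sketch of OpMemoryFormatConverter rewriting the
# ${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-...} markers used in diopi_functions.yaml.
import tempfile
import textwrap

from op_memory_format_converter import OpMemoryFormatConverter  # assumed import path

# Hypothetical convert_config.yaml: default to NCHW, prefer NHWC for conv2d.
config_text = textwrap.dedent(
    """\
    - common_config:
        layout: NCHW
    - diopiConvolution2d:
        layout: NHWC
    """
)
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(config_text)
    config_path = f.name

converter = OpMemoryFormatConverter(config_path)
custom_code = (
    "auto out = at::empty(out_shape, self.options(), "
    "${PREFERRED_MEMORY_FORMAT_PLACEHOLDER:-c10::nullopt});"
)
fun_config = {"interface": "diopiConvolution2d(ctx, out, self, weight, bias, ...)"}
print(converter.convert(custom_code, fun_config))
# NHWC maps to "channellast", so the placeholder becomes
# at::MemoryFormat::ChannelsLast; interfaces without an entry fall back to the
# common_config layout (NCHW -> at::MemoryFormat::Contiguous).
```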
diff --git a/dipu/scripts/ci/ascend/ci_ascend_env.sh b/dipu/scripts/ci/ascend/ci_ascend_env.sh
index d7e4d17d53..381d6eb4bc 100644
--- a/dipu/scripts/ci/ascend/ci_ascend_env.sh
+++ b/dipu/scripts/ci/ascend/ci_ascend_env.sh
@@ -14,6 +14,9 @@ export DIPU_PATH=${DIPU_ROOT}
export PYTORCH_DIR=${ASCEND_TORCH_DIR}
export PYTHONPATH=${PYTORCH_DIR}:${PYTHONPATH}
+export MKL_NUM_THREADS=1
+export OMP_NUM_THREADS=1
+
source /usr/local/Ascend/ascend-toolkit/set_env.sh
ARCH=$(uname -m)
diff --git a/dipu/scripts/ci/camb/ci_camb_env.sh b/dipu/scripts/ci/camb/ci_camb_env.sh
index 7527809648..6b0de04a6a 100644
--- a/dipu/scripts/ci/camb/ci_camb_env.sh
+++ b/dipu/scripts/ci/camb/ci_camb_env.sh
@@ -1,9 +1,9 @@
PLATFORM=/mnt/lustre/share/platform
-ENV_NAME=dipu_poc
+ENV_NAME=pt2.0_diopi
export PATH=`python ${PLATFORM}/env/clear_path.py PATH`
export LD_LIBRARY_PATH=`python ${PLATFORM}/env/clear_path.py LD_LIBRARY_PATH`
-GCC_ROOT=/mnt/lustre/share/platform/dep/gcc-7.5
-CONDA_ROOT=${PLATFORM}/env/miniconda3.8
+GCC_ROOT=/mnt/lustre/share/platform/dep/gcc-10.2
+CONDA_ROOT=${PLATFORM}/env/miniconda3.10
export NEUWARE_HOME=/usr/local/neuware
export CC=${GCC_ROOT}/bin/gcc
@@ -13,8 +13,8 @@ export CXX=${GCC_ROOT}/bin/g++
export DIOPI_ROOT=$(pwd)/third_party/DIOPI/impl/lib/
export DIPU_ROOT=$(pwd)/torch_dipu
export LD_LIBRARY_PATH=$DIPU_ROOT:$LD_LIBRARY_PATH
-export PYTHONPATH=${PYTORCH_DIR}/install_path/lib/python3.8/site-packages:${PYTHONPATH}
-export PATH=${GCC_ROOT}/bin:${PYTORCH_DIR}/install_path/bin:${CONDA_ROOT}/envs/dipu_poc/bin:${CONDA_ROOT}/bin:${PATH}
+export PYTHONPATH=${PLATFORM}/dep/DIOPI_pytorch/pytorch2.0:${PYTHONPATH}
+export PATH=${GCC_ROOT}/bin:${CONDA_ROOT}/envs/dipu_poc/bin:${CONDA_ROOT}/bin:${PATH}
export LD_PRELOAD=${GCC_ROOT}/lib64/libstdc++.so.6
@@ -33,6 +33,9 @@ export DIPU_HOST_MEMCACHING_ALGORITHM=BS
#export DIPU_RAW_ALLOCATOR_MIN_ALLOCATE_SIZE=512
export DIPU_CHECK_TENSOR_DEVICE=1
+export MKL_NUM_THREADS=1
+export OMP_NUM_THREADS=1
+
source activate $ENV_NAME
echo "python path : ${PYTHONPATH}"
diff --git a/dipu/scripts/ci/droplet/ci_droplet_env.sh b/dipu/scripts/ci/droplet/ci_droplet_env.sh
index 5140be7c41..1bf7defe90 100644
--- a/dipu/scripts/ci/droplet/ci_droplet_env.sh
+++ b/dipu/scripts/ci/droplet/ci_droplet_env.sh
@@ -16,5 +16,8 @@ export DIPU_PATH=${DIPU_ROOT}
export LIBRARY_PATH=$DIPU_ROOT:$DIOPI_ROOT:$LIBRARY_PATH
export LD_LIBRARY_PATH=$DIPU_ROOT:$DIOPI_ROOT:$LD_LIBRARY_PATH
+export MKL_NUM_THREADS=1
+export OMP_NUM_THREADS=1
+
echo $ENV_PATH
source activate $ENV_PATH
diff --git a/dipu/scripts/ci/nv/ci_nv_env.sh b/dipu/scripts/ci/nv/ci_nv_env.sh
index d885dc983e..453a1da092 100644
--- a/dipu/scripts/ci/nv/ci_nv_env.sh
+++ b/dipu/scripts/ci/nv/ci_nv_env.sh
@@ -2,14 +2,14 @@ PLATFORM=/mnt/cache/share/platform
ENV_NAME=pt2.0_diopi
export PATH=`python ${PLATFORM}/env/clear_path.py PATH`
export LD_LIBRARY_PATH=`python ${PLATFORM}/env/clear_path.py LD_LIBRARY_PATH`
-GCC_ROOT=${PLATFORM}/dep/gcc-7.5
-CONDA_ROOT=${PLATFORM}/env/miniconda3.8
+GCC_ROOT=${PLATFORM}/dep/gcc-10.2
+CONDA_ROOT=${PLATFORM}/env/miniconda3.10
export CC=${GCC_ROOT}/bin/gcc
export CXX=${GCC_ROOT}/bin/g++
-export CUDA_PATH=${PLATFORM}/dep/cuda11.7-cudnn8.5
-export MPI_ROOT=${PLATFORM}/dep/openmpi-4.0.5-cuda11.7
-export NCCL_ROOT=${PLATFORM}/dep/nccl-2.13.4-cuda11.7
+export CUDA_PATH=${PLATFORM}/dep/cuda11.8-cudnn8.9
+export MPI_ROOT=${PLATFORM}/dep/openmpi-4.0.5-cuda11.8
+export NCCL_ROOT=${PLATFORM}/dep/nccl-2.15.5-cuda11.8
export GTEST_ROOT=${PLATFORM}/dep/googletest-gcc5.4
@@ -24,11 +24,10 @@ export DIOPI_ROOT=$(pwd)/third_party/DIOPI/impl/lib/
export DIPU_ROOT=$(pwd)/torch_dipu
export DIOPI_PATH=$(pwd)/third_party/DIOPI/proto
export DIPU_PATH=${DIPU_ROOT}
-export PYTORCH_DIR=${PLATFORM}/env/miniconda3.8/envs/pt2.0_diopi/lib/python3.8/site-packages
+export PYTORCH_DIR=${PLATFORM}/dep/DIOPI_pytorch/pytorch2.0_cu118
export LD_LIBRARY_PATH=$DIPU_ROOT:$LD_LIBRARY_PATH
export PYTHONPATH=${PYTORCH_DIR}:${PYTHONPATH}
export PATH=${GCC_ROOT}/bin:${CONDA_ROOT}/envs/dipu_poc/bin:${CONDA_ROOT}/bin:${PLATFORM}/dep/binutils-2.27/bin:${PATH}
-export LD_PRELOAD=${GCC_ROOT}/lib64/libstdc++.so.6
export PYTORCH_TEST_DIR=${PLATFORM}/env/miniconda3.8/envs/pt2.0_diopi/pytorch2.0
export CUBLAS_WORKSPACE_CONFIG=:4096:8
@@ -45,4 +44,10 @@ export DIPU_HOST_MEMCACHING_ALGORITHM=BF
export DIPU_PATCH_CUDA_CACHED_ALLOCATOR=0
export DIPU_CHECK_TENSOR_DEVICE=1
+# Setting OMP_NUM_THREADS environment variable for each process in default,
+# to avoid your system being overloaded, please further tune the variable
+# for optimal performance in your application as needed.
+export MKL_NUM_THREADS=1
+export OMP_NUM_THREADS=1
+
source activate $ENV_NAME
diff --git a/dipu/scripts/ci/topsrider/ci_topsrider_env.sh b/dipu/scripts/ci/topsrider/ci_topsrider_env.sh
index 250ba8284d..58d8b3787d 100644
--- a/dipu/scripts/ci/topsrider/ci_topsrider_env.sh
+++ b/dipu/scripts/ci/topsrider/ci_topsrider_env.sh
@@ -16,4 +16,7 @@ export VENDOR_INCLUDE_DIRS=/usr/include/tops
export DIOPI_PATH=${DIPU_LOCAL_DIR}/third_party/DIOPI/proto
export DIPU_PATH=${DIPU_ROOT}
+export MKL_NUM_THREADS=1
+export OMP_NUM_THREADS=1
+
# source activate $ENV_NAME
diff --git a/dipu/tests/python/README.md b/dipu/tests/python/README.md
index 0c68dc8cfd..31dbb7ef64 100644
--- a/dipu/tests/python/README.md
+++ b/dipu/tests/python/README.md
@@ -28,12 +28,12 @@
- For ops with randomness, consider checking the characteristics of their distribution (see multinomial, random, etc.).
- Consider not using assertions and only checking for errors rather than failures (add a comment explaining this).
- `torch.allclose` does **not** check shape, dtype, etc.; use it with care.
- If you need to check output produced inside the C++ library, you can capture it with `test.python.utils.stdout_redirector.stdout_redirector`.
+ - If you need to check output produced inside the C++ library, you can capture it with `utils.stdout_redirector.stdout_redirector`.
- If you need output to help with debugging, consider passing the [`msg` argument](https://docs.python.org/3/library/unittest.html#unittest.TestCase.assertEqual) to unittest's assertion functions.
- **Do not** do anything that affects the global namespace, for example:
- modifying what imported libraries contain;
- defining other functions and variables in the global namespace (consider moving them into the class);
- - modifying environment variables (use `test.python.utils.local_eviron.local_eviron`);
+ - modifying environment variables (use `utils.local_eviron.local_eviron`; see the sketch after this list);
- Test a wide range of usage scenarios according to the torch documentation.
- Use setUp(), class variables, etc. to simplify code instead of copying large blocks of code, so that later maintenance stays easy.
- For test cases that are expected to fail, you can use the `onlyOn` and `skipOn` decorators to skip them on certain devices (see cdist).
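As a small illustration of the environment-variable point above, here is a minimal sketch of a unit test that scopes an environment variable with `utils.local_eviron.local_eviron`. The test body and the environment variable are examples only, and the import paths assume the tests are run the same way as the rest of this suite.

```python
# Hypothetical unit test following the conventions above.
import torch
import torch_dipu
from torch_dipu.testing._internal.common_utils import TestCase, run_tests
from utils.local_eviron import local_eviron


class TestWithScopedEnv(TestCase):
    def test_add_with_op_arg_dump(self):
        # local_eviron restores the environment when the block exits,
        # so the global environment is left untouched.
        with local_eviron({"DIPU_DUMP_OP_ARGS": "1"}):
            x = torch.randn(3, 4).cuda()
            y = x + x
        self.assertEqual(y.cpu(), (x + x).cpu())


if __name__ == "__main__":
    run_tests()
```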
@@ -46,17 +46,21 @@
A standalone test case should be a Python script that can run on its own. These scripts are automatically turned into unit tests; a return value of 0 means the test passed, anything else means it failed.
-If you need to automatically check output produced inside the C++ library, you can capture it with `test.python.utils.stdout_redirector.stdout_redirector`.
+If you need to automatically check output produced inside the C++ library, you can capture it with `utils.stdout_redirector.stdout_redirector`.
A standalone test case may contain print statements. However, in the automatically generated unit tests, its output is suppressed when the test passes.
+You can use `utils.test_in_subprocess.run_individual_test_cases` to write several standalone test cases in the same file, as sketched below.
+
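A minimal sketch of such a file (the test functions below are made up for illustration; the calling patterns mirror how `run_individual_test_cases` is invoked in `test_allocator.py` and `test_dipu_fallback.py` in this patch):

```python
# my_individual_test.py -- hypothetical standalone test script
import itertools

from utils.test_in_subprocess import run_individual_test_cases


def _test_feature_a() -> None:
    import torch
    import torch_dipu

    assert torch.ones(2).cuda().sum().item() == 2


def _test_feature_b(size: int) -> None:
    import torch
    import torch_dipu

    assert torch.zeros(size).cuda().numel() == size


if __name__ == "__main__":
    # Plain callables run without arguments ...
    run_individual_test_cases([_test_feature_a], in_parallel=True)
    # ... or pair each callable with an {"args": ...} dict, as test_allocator.py does.
    run_individual_test_cases(
        itertools.product((_test_feature_b,), ({"args": (4,)}, {"args": (8,)})),
        in_parallel=False,
    )
```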
#### Collecting coverage from subprocesses
Subprocesses created with `multiprocessing.Process` are not counted when collecting coverage on CI, so standalone test cases that rely on them (e.g. `test_allocator.py`) need some special handling.
#### C++ `gcov`
-Before calling `multiprocessing.Process`, you **must** call `multiprocessing.set_start_method('spawn', force=True)` to change multiprocessing's default way of spawning processes.
+~~Before calling `multiprocessing.Process`, you **must** call `multiprocessing.set_start_method("spawn", force=True)` to change multiprocessing's default way of spawning processes.~~
+
+Please use `utils.test_in_subprocess.run_individual_test_cases` to create subprocesses instead.
##### Python `coverage`
diff --git a/dipu/tests/python/individual_scripts/generate_unittest_for_individual_scripts.py b/dipu/tests/python/individual_scripts/generate_unittest_for_individual_scripts.py
index c5e17db78b..39ab90d229 100644
--- a/dipu/tests/python/individual_scripts/generate_unittest_for_individual_scripts.py
+++ b/dipu/tests/python/individual_scripts/generate_unittest_for_individual_scripts.py
@@ -9,7 +9,7 @@ def generate_unittest_for_individual_scripts():
import io
import os
import unittest
-from stdout_redirector import stdout_redirector
+from utils.stdout_redirector import stdout_redirector
class TestIndividualScripts(unittest.TestCase):
diff --git a/dipu/tests/python/individual_scripts/local_eviron.py b/dipu/tests/python/individual_scripts/local_eviron.py
deleted file mode 120000
index 7570555029..0000000000
--- a/dipu/tests/python/individual_scripts/local_eviron.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../torch_dipu/testing/_internal/local_eviron.py
\ No newline at end of file
diff --git a/dipu/tests/python/individual_scripts/stdout_redirector.py b/dipu/tests/python/individual_scripts/stdout_redirector.py
deleted file mode 120000
index fe5e70337c..0000000000
--- a/dipu/tests/python/individual_scripts/stdout_redirector.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../torch_dipu/testing/_internal/stdout_redirector.py
\ No newline at end of file
diff --git a/dipu/tests/python/individual_scripts/test_allocator.py b/dipu/tests/python/individual_scripts/test_allocator.py
index 9ebe2563f3..281c4d25fa 100644
--- a/dipu/tests/python/individual_scripts/test_allocator.py
+++ b/dipu/tests/python/individual_scripts/test_allocator.py
@@ -1,8 +1,15 @@
+import itertools
import os
-from multiprocessing import Process, set_start_method
+from utils.test_in_subprocess import run_individual_test_cases
-def test_allocator(max_allocate, step, algorithm, log_mask, test_pin_memory=True):
+def test_allocator(
+ max_allocate: int,
+ step: int,
+ algorithm: str,
+ log_mask: int,
+ test_pin_memory: bool = True,
+):
os.environ["DIPU_DEVICE_MEMCACHING_ALGORITHM"] = algorithm
os.environ["DIPU_DEBUG_ALLOCATOR"] = str(log_mask)
os.environ["DIPU_MEM_CHECK"] = "1"
@@ -67,35 +74,16 @@ def test_allocator(max_allocate, step, algorithm, log_mask, test_pin_memory=True
if __name__ == "__main__":
- set_start_method('spawn', force=True)
- max_allocate = 1 << 15
- p1 = Process(
- target=test_allocator,
- args=(max_allocate, 1, "BF", 0),
+ MAX_ALLOCATE = 1 << 15
+ run_individual_test_cases(
+ itertools.product(
+ (test_allocator,),
+ (
+ {"args": (MAX_ALLOCATE, 1, "BF", 0)},
+ {"args": (MAX_ALLOCATE, 1, "BS", 0)},
+ {"args": (MAX_ALLOCATE, 1, "RAW", 0)},
+ {"args": (MAX_ALLOCATE, 17919, "BF", 3, False)},
+ ),
+ ),
+ in_parallel=False,
)
- p1.start()
- p1.join()
-
- p2 = Process(
- target=test_allocator,
- args=(max_allocate, 1, "BS", 0),
- )
- p2.start()
- p2.join()
-
- p3 = Process(target=test_allocator, args=(max_allocate, 1, "RAW", 0))
- p3.start()
- p3.join()
-
- max_allocate = 1 << 30
- p4 = Process(
- target=test_allocator,
- args=(max_allocate, 17919, "BF", 3, False),
- )
- p4.start()
- p4.join()
-
- assert p1.exitcode == 0
- assert p2.exitcode == 0
- assert p3.exitcode == 0
- assert p4.exitcode == 0
diff --git a/dipu/tests/python/individual_scripts/test_dipu_fallback.py b/dipu/tests/python/individual_scripts/test_dipu_fallback.py
index 8c4f65235e..f2dbf25027 100644
--- a/dipu/tests/python/individual_scripts/test_dipu_fallback.py
+++ b/dipu/tests/python/individual_scripts/test_dipu_fallback.py
@@ -1,32 +1,172 @@
# Copyright (c) 2023, DeepLink.
import io
-from stdout_redirector import stdout_redirector
-from local_eviron import local_eviron
+from typing import Callable, List
+import torch
+from utils.stdout_redirector import stdout_redirector
+from utils.local_eviron import local_eviron
+from utils.test_in_subprocess import run_individual_test_cases
-def _test_dipu_fallback():
+def test_fallback(
+ op_names: List[str],
+ diopi_protos: List[str],
+ test_fn: Callable[[], None],
+ extra_check_str_in_output: List[str] = [],
+) -> None:
captured = io.BytesIO()
with stdout_redirector(captured):
with local_eviron(
{
- "DIPU_FORCE_FALLBACK_OPS_LIST": "add.out,sub.out",
+ "DIPU_FORCE_FALLBACK_OPS_LIST": ",".join(op_names),
"DIPU_DUMP_OP_ARGS": "1",
+ "DIPU_LOG_FALLBACK_INFO": "1",
}
):
- import torch
import torch_dipu
- x = torch.randn(3, 4).cuda()
- _ = x + x
- _ = x - x
-
+ test_fn()
output = captured.getvalue().decode()
- assert "force fallback has been set, add.out will be fallback to cpu" in output
- assert "force fallback has been set, sub.out will be fallback to cpu" in output
- assert "dipu_fallback" in output
- assert "diopiAdd" not in output
- assert "diopiSub" not in output
+ print(output, end="")
+ assert all(
+ f"force fallback has been set, {name} will be fallback to cpu" in output
+ for name in op_names
+ )
+ assert all(item not in output for item in diopi_protos)
+ if extra_check_str_in_output is not None:
+ assert all(item in output for item in extra_check_str_in_output)
+
+
+def _test_dipu_fallback():
+ def fn():
+ x = torch.randn(3, 4).cuda()
+ _ = x + x
+ _ = x - x
+
+ test_fallback(
+ ["add.out", "sub.out"], ["diopiAdd", "diopiSub"], fn, ["dipu_fallback"]
+ )
+
+
+def _test_cpu_fallback():
+ def fn():
+ device = "cuda"
+ m = torch.nn.BatchNorm2d(100, affine=False).to(device)
+ input = torch.randn(20, 100, 35, 45).to(device)
+ m(input)
+
+ test_fallback(
+ ["native_batch_norm"],
+ ["diopiBatchNorm"],
+ fn,
+ ["cpu_fallback:\taten::native_batch_norm", "dipu_fallback"],
+ )
+
+
+def _test_dipu_index_put_impl_fallback():
+ def fn():
+ dipu_tensor = torch.tensor([1, 2, 3, 4, 5]).cuda()
+ indices = torch.tensor([1, 3]).cuda()
+ values = torch.tensor([10, 40]).cuda()
+ torch._index_put_impl_(dipu_tensor, (indices,), values, accumulate=False)
+
+ tensor = dipu_tensor.cpu()
+ indices = indices.cpu()
+ values = values.cpu()
+ torch._index_put_impl_(tensor, (indices,), values, accumulate=False)
+
+ assert torch.allclose(tensor, dipu_tensor.cpu())
+
+ test_fallback(
+ ["_index_put_impl_"],
+ ["diopiIndexPut"],
+ fn,
+ ["custom fallback to cpu, name=_index_put_impl_"],
+ )
+
+
+def _test_dipu_copy_fallback_():
+ def fn():
+ source_tensor = torch.tensor([1.0, 2.0, 3.0]).cuda()
+ target_dipu = torch.zeros_like(source_tensor).cuda()
+ target_dipu.copy_(source_tensor)
+
+ source_tensor = source_tensor.cpu()
+ target_tensor = torch.zeros_like(source_tensor)
+ target_tensor.copy_(source_tensor)
+
+ assert torch.allclose(target_tensor, target_dipu.cpu())
+
+ test_fallback(
+ ["copy_"],
+ ["diopiCopyInp"],
+ fn,
+ ["custom fallback to dipu copy, name=copy_"],
+ )
+
+
+def _test_dipu_convolution_backward_overrideable_fallback():
+ def fn():
+ torch.manual_seed(42)
+ device = torch.device("dipu")
+ m = torch.nn.Conv2d(2, 3, 3, stride=2).to(device)
+ m.weight = torch.nn.Parameter(torch.ones_like(m.weight))
+ m.bias = torch.nn.Parameter(torch.ones_like(m.bias))
+ input_dipu = torch.randn(2, 2, 5, 5).to(device).requires_grad_(True)
+ output_dipu = m(input_dipu)
+ output_dipu.backward(torch.ones_like(output_dipu))
+
+ torch.manual_seed(42)
+ m = torch.nn.Conv2d(2, 3, 3, stride=2)
+ m.weight = torch.nn.Parameter(torch.ones_like(m.weight))
+ m.bias = torch.nn.Parameter(torch.ones_like(m.bias))
+ input_cpu = torch.randn(2, 2, 5, 5, requires_grad=True)
+ output_cpu = m(input_cpu)
+ output_cpu.backward(torch.ones_like(output_cpu))
+
+ assert torch.allclose(output_dipu.cpu(), output_cpu)
+ assert torch.allclose(input_dipu.grad.cpu(), input_cpu.grad)
+
+ test_fallback(
+ ["convolution_backward_overrideable"],
+ ["diopiConvolution2dBackward"],
+ fn,
+ ["custom fallback to cpu, name=convolution_backward_overrideable"],
+ )
+
+
+def _test_dipu_convolution_overrideable_fallback():
+ def fn():
+ m = torch.nn.Conv2d(2, 3, 3, stride=2).cuda()
+ m.weight = torch.nn.Parameter(torch.ones_like(m.weight))
+ m.bias = torch.nn.Parameter(torch.ones_like(m.bias))
+ input_dipu = torch.randn(2, 2, 5, 5).cuda()
+ output_dipu = m(input_dipu)
+
+ m = m.cpu()
+ m.weight = torch.nn.Parameter(torch.ones_like(m.weight))
+ m.bias = torch.nn.Parameter(torch.ones_like(m.bias))
+ input_cpu = input_dipu.cpu()
+ output_cpu = m(input_cpu)
+
+ assert torch.allclose(output_dipu.cpu(), output_cpu)
+
+ test_fallback(
+ ["convolution_overrideable"],
+ ["diopiConvolution2d"],
+ fn,
+ ["custom fallback to cpu, name=convolution_overrideable"],
+ )
if __name__ == "__main__":
- _test_dipu_fallback()
+ run_individual_test_cases(
+ [
+ _test_dipu_fallback,
+ _test_cpu_fallback,
+ _test_dipu_index_put_impl_fallback,
+ _test_dipu_copy_fallback_,
+ _test_dipu_convolution_backward_overrideable_fallback,
+ _test_dipu_convolution_overrideable_fallback,
+ ],
+ in_parallel=True,
+ )
diff --git a/dipu/tests/python/individual_scripts/test_dipu_op_register.py b/dipu/tests/python/individual_scripts/test_dipu_op_register.py
index 770c41cfd7..dd0f580e72 100644
--- a/dipu/tests/python/individual_scripts/test_dipu_op_register.py
+++ b/dipu/tests/python/individual_scripts/test_dipu_op_register.py
@@ -1,9 +1,11 @@
# Copyright (c) 2023, DeepLink.
-from multiprocessing import Process, set_start_method
-from local_eviron import local_eviron
+import itertools
+from typing import Union
+from utils.local_eviron import local_eviron
+from utils.test_in_subprocess import run_individual_test_cases
-def _test_op_register(mode):
+def _test_op_register(mode: Union[int, str]) -> None:
with local_eviron(
{"DIPU_IMMEDIATE_REGISTER_OP": str(mode), "DIPU_DUMP_OP_ARGS": "1"}
):
@@ -15,28 +17,14 @@ def _test_op_register(mode):
if __name__ == "__main__":
- set_start_method('spawn', force=True)
- p1 = Process(
- target=_test_op_register,
- args=(0,),
+ run_individual_test_cases(
+ itertools.product(
+ (_test_op_register,),
+ (
+ {"args": (0,)},
+ {"args": (1,)},
+ {"args": ("",)},
+ ),
+ ),
+ in_parallel=True,
)
- p1.start()
- p1.join()
-
- p2 = Process(
- target=_test_op_register,
- args=(1,),
- )
- p2.start()
- p2.join()
-
- p3 = Process(
- target=_test_op_register,
- args=("",),
- )
- p3.start()
- p3.join()
-
- assert p1.exitcode == 0
- assert p2.exitcode == 0
- assert p3.exitcode == 0
diff --git a/dipu/tests/python/individual_scripts/test_dipu_profiler.py b/dipu/tests/python/individual_scripts/test_dipu_profiler.py
new file mode 100644
index 0000000000..95dfbd8042
--- /dev/null
+++ b/dipu/tests/python/individual_scripts/test_dipu_profiler.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2023, DeepLink.
+import os
+os.environ["FORCE_USE_DIPU_PROFILER"] = "True"
+
+import tempfile
+import torch
+import torch_dipu
+import torchvision.models as models
+from torch.profiler import profile, ProfilerActivity
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn
+from utils.local_eviron import local_eviron
+
+
+class TestProfiler(TestCase):
+ def test_profiler(self):
+ model = models.resnet18().cuda()
+ inputs = torch.randn(5, 3, 224, 224).cuda()
+
+ with local_eviron({"KINETO_LOG_LEVEL": "999"}): # suppress profiler logs
+ with profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ profile_memory=True,
+ record_shapes=True,
+ with_modules=True,
+ with_stack=True,
+ experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)
+ ) as prof:
+ output = model(inputs)
+ output.sum().backward()
+
+ profile_output = prof.key_averages(group_by_input_shape=True).table(
+ sort_by="self_cuda_time_total", row_limit=1000
+ )
+ self.assertIn("diopiConvolution2dBackward", profile_output)
+ self.assertIn("dipu_convolution_", profile_output)
+ self.assertIn("LaunchKernel_dipu", profile_output)
+ self.assertIn("LaunchKernel_diopi", profile_output)
+ self.assertIn("Self CPU time total", profile_output)
+ self.assertIn("Self CUDA time total", profile_output)
+ self.assertIn("5, 3, 224, 224", profile_output)
+
+ profile_stack_output = prof.key_averages(group_by_stack_n=15).table(
+ sort_by="cuda_time_total", row_limit=1000)
+ self.assertIn("Source Location", profile_stack_output)
+ self.assertIn("resnet.py", profile_stack_output)
+
+ profile_memory_output = prof.key_averages().table(
+ sort_by="self_cuda_memory_usage", row_limit=1000)
+ self.assertIn("Self CPU Mem", profile_memory_output)
+ self.assertIn("Self CUDA Mem", profile_memory_output)
+ self.assertIn("Mb", profile_memory_output)
+ self.assertIn("Kb", profile_memory_output)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ prof.export_chrome_trace(f"{tmpdir}/dipu_resnet18_profiler.json")
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/dipu/tests/python/individual_scripts/test_dumparg.py b/dipu/tests/python/individual_scripts/test_dumparg.py
new file mode 100644
index 0000000000..a2e3829ddf
--- /dev/null
+++ b/dipu/tests/python/individual_scripts/test_dumparg.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2023, DeepLink.
+import io
+from utils.stdout_redirector import stdout_redirector
+from utils.local_eviron import local_eviron
+
+
+def _test_copy_dumparg():
+ captured = io.BytesIO()
+ with stdout_redirector(captured):
+ with local_eviron(
+ {
+ "DIPU_DUMP_OP_ARGS": "2",
+ }
+ ):
+ import torch
+ import torch_dipu
+
+ source_tensor = torch.tensor([1.0, 2.0, 3.0]).cuda()
+ target_tensor = torch.zeros_like(source_tensor).cuda()
+ target_tensor.copy_(source_tensor)
+
+ output = captured.getvalue().decode()
+ print(output)
+ assert "DIPUCopyInplace.run" in output
+ assert "numel: 3, sizes: [3], stride: [1], is_view: 0, dtype: float" in output
+
+
+if __name__ == "__main__":
+ _test_copy_dumparg()
diff --git a/dipu/tests/python/individual_scripts/test_memory_stats.py b/dipu/tests/python/individual_scripts/test_memory_stats.py
index 34b044a1f2..3b50b5a377 100644
--- a/dipu/tests/python/individual_scripts/test_memory_stats.py
+++ b/dipu/tests/python/individual_scripts/test_memory_stats.py
@@ -1,8 +1,9 @@
+import itertools
import os
-from multiprocessing import Process, set_start_method
+from utils.test_in_subprocess import run_individual_test_cases
-def test_mem_stats(algorithm, log_mask):
+def test_mem_stats(algorithm: str, log_mask: int):
os.environ["DIPU_DEVICE_MEMCACHING_ALGORITHM"] = algorithm
os.environ["DIPU_DEBUG_ALLOCATOR"] = str(log_mask)
print("allocator algorithm:", algorithm)
@@ -13,7 +14,7 @@ def test_mem_stats(algorithm, log_mask):
ins = []
pin_ins = []
real_allocated = 0
- for i in range(100):
+ for _ in range(100):
numel = random.randint(0, 1 << 20)
x = torch.randn(numel).to(torch.device("cuda:0"))
y = torch.randn(numel).pin_memory()
@@ -37,7 +38,7 @@ def test_mem_stats(algorithm, log_mask):
real_max_allocate = real_allocated
- for i in range(len(ins)):
+ for _ in range(len(ins)):
numel = ins[0].numel()
real_allocated -= ((numel * 4 - 1) | 511) + 1
ins.pop(0)
@@ -61,25 +62,14 @@ def test_mem_stats(algorithm, log_mask):
if __name__ == "__main__":
- set_start_method('spawn', force=True)
- p1 = Process(
- target=test_mem_stats,
- args=("BF", 0),
+ run_individual_test_cases(
+ itertools.product(
+ (test_mem_stats,),
+ (
+ {"args": ("BF", 0)},
+ {"args": ("BS", 0)},
+ {"args": ("RAW", 0)},
+ ),
+ ),
+ in_parallel=False,
)
- p1.start()
- p1.join()
-
- p2 = Process(
- target=test_mem_stats,
- args=("BS", 0),
- )
- p2.start()
- p2.join()
-
- p3 = Process(target=test_mem_stats, args=("RAW", 0))
- p3.start()
- p3.join()
-
- assert p1.exitcode == 0
- assert p2.exitcode == 0
- assert p3.exitcode == 0
diff --git a/dipu/tests/python/individual_scripts/test_profiler_communication.py b/dipu/tests/python/individual_scripts/test_profiler_communication.py
index dfc279b0a3..f3cce135f4 100644
--- a/dipu/tests/python/individual_scripts/test_profiler_communication.py
+++ b/dipu/tests/python/individual_scripts/test_profiler_communication.py
@@ -1,5 +1,8 @@
import os
+os.environ["FORCE_USE_DIPU_PROFILER"] = "True"
+
import random
+import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -56,7 +59,8 @@ def demo_basic_ddp(rank, world_size, port):
)
assert("c10d::allreduce_" in profile_output)
assert("LaunchKernel_DiclAllreduce" in profile_output)
- prof.export_chrome_trace(f"./dipu_resnet18_profiler_{rank}.json")
+ with tempfile.TemporaryDirectory() as tmpdir:
+ prof.export_chrome_trace(f"{tmpdir}/dipu_resnet18_profiler_{rank}.json")
cleanup()
def test_profiler_communication():
diff --git a/dipu/tests/python/individual_scripts/utils b/dipu/tests/python/individual_scripts/utils
new file mode 120000
index 0000000000..468ba705ba
--- /dev/null
+++ b/dipu/tests/python/individual_scripts/utils
@@ -0,0 +1 @@
+../utils
\ No newline at end of file
diff --git a/dipu/tests/python/unittests/stdout_redirector.py b/dipu/tests/python/unittests/stdout_redirector.py
deleted file mode 120000
index fe5e70337c..0000000000
--- a/dipu/tests/python/unittests/stdout_redirector.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../torch_dipu/testing/_internal/stdout_redirector.py
\ No newline at end of file
diff --git a/dipu/tests/python/unittests/test_conv2d.py b/dipu/tests/python/unittests/test_conv2d.py
index e93181c670..b33677aef3 100644
--- a/dipu/tests/python/unittests/test_conv2d.py
+++ b/dipu/tests/python/unittests/test_conv2d.py
@@ -39,6 +39,23 @@ def test_conv_2d(self):
)
# print("conv2d output compare successfully")
+ def test_conv2d_nhwc(self):
+ device = torch.device("dipu")
+
+ m = nn.Conv2d(2, 3, 3).to(device=device, memory_format=torch.channels_last)
+ self.assertTrue(m.weight.is_contiguous(memory_format=torch.channels_last))
+
+ x = torch.rand(2, 2, 5, 5).to(device=device, memory_format=torch.channels_last)
+ x.requires_grad_()
+ self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
+
+ y = m(x)
+ self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
+
+ y.backward(torch.rand_like(y))
+ self.assertTrue(x.grad.is_contiguous(memory_format=torch.channels_last))
+ self.assertTrue(m.weight.grad.is_contiguous(memory_format=torch.channels_last))
+
if __name__ == "__main__":
run_tests()
diff --git a/dipu/tests/python/unittests/test_layer_norm.py b/dipu/tests/python/unittests/test_layer_norm.py
index aec7d0aa97..bb6424a811 100644
--- a/dipu/tests/python/unittests/test_layer_norm.py
+++ b/dipu/tests/python/unittests/test_layer_norm.py
@@ -76,6 +76,21 @@ def test_layer_norm_no_affine(self):
)
self._run_layer_norm()
+ # maybe we don't want ChannelsLast -> Contiguous here, but just align with pytorch
+ # https://github.com/pytorch/pytorch/blob/v2.0.0/aten/src/ATen/native/cuda/layer_norm_kernel.cu#L1340-L1346
+ def test_layer_norm_out_format(self):
+ l = torch.nn.LayerNorm(4).cuda()
+ xs = [
+ torch.rand(2, 3, 5, 4, device='cuda').to(memory_format=torch.channels_last),
+ torch.rand(2, 4, 3, device='cuda').permute([0, 2, 1]),
+ torch.rand(2, 6, device='cuda')[:, 1:5],
+ ]
+ for x in xs:
+ y = l(x)
+ # it seems we can't get LEGACY_CONTIGUOUS_MEMORY_FORMAT from Python,
+ # so just assume it's MemoryFormat::Contiguous
+ self.assertTrue(y.is_contiguous())
+
if __name__ == "__main__":
run_tests()
diff --git a/dipu/tests/python/unittests/test_minimum_maximum.py b/dipu/tests/python/unittests/test_minimum_maximum.py
index eecc57bc18..a6b00383d4 100644
--- a/dipu/tests/python/unittests/test_minimum_maximum.py
+++ b/dipu/tests/python/unittests/test_minimum_maximum.py
@@ -15,6 +15,26 @@ def test_minimum(self):
r_cpu = torch.minimum(a.to(self.cpu), b.to(self.cpu))
self.assertEqual(r_dipu.to(self.cpu), r_cpu)
+ def test_minimum_scalar(self):
+ # special test cases from the inference of internlm
+ a = torch.randn((3, 4))
+ b = torch.tensor(torch.finfo(a.dtype).max)
+ # scalar on cpu
+ r_dipu1 = torch.minimum(a.to(self.dipu), b)
+ # scalar on device
+ r_dipu2 = torch.minimum(a.to(self.dipu), b.to(self.dipu))
+ r_cpu = torch.minimum(a, b)
+ self.assertEqual(r_dipu1.to(self.cpu), r_cpu)
+ self.assertEqual(r_dipu2.to(self.cpu), r_cpu)
+
+ def test_minimum_different_devices(self):
+ a = torch.tensor([1, -2, 3])
+ b = torch.tensor([4, 0, 2]).to(self.dipu)
+ with self.assertRaises(RuntimeError) as context:
+ torch.minimum(a, b)
+ self.assertIn(
+ 'Expected all tensors to be on the same device', str(context.exception))
+
def test_maximum(self):
a = torch.tensor((1, 2, -1))
b = torch.tensor((3, 0, 4))
@@ -22,6 +42,26 @@ def test_maximum(self):
r_cpu = torch.maximum(a.to(self.cpu), b.to(self.cpu))
self.assertEqual(r_dipu.to(self.cpu), r_cpu)
+ def test_maximum_scalar(self):
+ # special test cases taken from InternLM inference
+ a = torch.randn((3, 4))
+ b = torch.tensor(torch.finfo(a.dtype).min)
+ # scalar on cpu
+ r_dipu1 = torch.maximum(a.to(self.dipu), b)
+ # scalar on device
+ r_dipu2 = torch.maximum(a.to(self.dipu), b.to(self.dipu))
+ r_cpu = torch.maximum(a, b)
+ self.assertEqual(r_dipu1.to(self.cpu), r_cpu)
+ self.assertEqual(r_dipu2.to(self.cpu), r_cpu)
+
+ def test_maximum_different_devices(self):
+ a = torch.tensor([1, -2, 3])
+ b = torch.tensor([4, 0, 2]).to(self.dipu)
+ with self.assertRaises(RuntimeError) as context:
+ torch.maximum(a, b)
+ self.assertIn(
+ 'Expected all tensors to be on the same device', str(context.exception))
+
if __name__ == "__main__":
run_tests()
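
Two behaviours are pinned down by the new minimum/maximum cases: a 0-dim tensor acts as a broadcastable scalar (and may stay on the CPU even when the other operand is on the device), while two full-size tensors on different devices must raise. A CPU-only sketch of the scalar half (the cross-device error can only be reproduced on a device build):

```python
import torch

a = torch.randn(3, 4)
b = torch.tensor(torch.finfo(a.dtype).max)   # 0-dim ("scalar") tensor

# Broadcasts against any shape; clamping by dtype-max leaves `a` unchanged.
r = torch.minimum(a, b)
assert r.shape == a.shape
assert torch.equal(r, a)
```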
diff --git a/dipu/tests/python/unittests/test_mm.py b/dipu/tests/python/unittests/test_mm.py
index f3c8a7eb10..992ece4f82 100644
--- a/dipu/tests/python/unittests/test_mm.py
+++ b/dipu/tests/python/unittests/test_mm.py
@@ -9,7 +9,7 @@ def test_mm(self):
dipu = torch.device("dipu")
cpu = torch.device("cpu")
mat1 = torch.randn(2, 3)
- mat2 = torch.randn(3, 3)
+ mat2 = torch.randn(3, 4)
r1 = torch.mm(mat1.to(dipu), mat2.to(dipu))
r2 = torch.mm(mat1.to(cpu), mat2.to(cpu))
self.assertEqual(r1.to(cpu), r2)
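
The switch from a square to a rectangular `mat2` is deliberate: with a 3x3 `mat2`, an accidental transpose of either operand would still produce a result of the right shape, so the comparison could pass for the wrong reason. A quick illustration:

```python
import torch

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)      # rectangular on purpose
out = torch.mm(mat1, mat2)
print(out.shape)              # torch.Size([2, 4]); a transposed operand would error out
```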
diff --git a/dipu/tests/python/unittests/test_prod.py b/dipu/tests/python/unittests/test_prod.py
index 5f0f4fa3fa..24964673a9 100644
--- a/dipu/tests/python/unittests/test_prod.py
+++ b/dipu/tests/python/unittests/test_prod.py
@@ -25,12 +25,12 @@ def test_prod_bool(self):
input_arrays = [[True, True], [True, False], [False, False]]
for input_array in input_arrays:
input_tensor = torch.tensor(input_array)
- out = torch.prod(input_tensor).item()
- out_cuda = torch.prod(input_tensor.cuda()).item()
- self.assertEqual(out, out_cuda)
+ out = torch.prod(input_tensor)
+ out_cuda = torch.prod(input_tensor.cuda())
+ self.assertEqual(out, out_cuda, exact_dtype=True)
def test_prod_dtype(self):
- test_dtypes = [torch.float16, torch.float32]
+ test_dtypes = [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64]
for input_dtype in test_dtypes:
input_tensor = torch.tensor(
[[1, 2, 3], [4, 5, 6]], dtype=input_dtype, device="dipu"
@@ -46,6 +46,20 @@ def test_prod_dtype(self):
out = torch.prod(input_tensor, 1, dtype=output_dtype)
self.assertEqual(out, expected_output, exact_dtype=True)
+ def test_prod_integer_promotion(self):
+ test_dtypes = [torch.int8, torch.int16, torch.int32]
+ for input_dtype in test_dtypes:
+ input_tensor = torch.tensor(
+ [[1, 2, 3], [4, 5, 6]], dtype=input_dtype, device="dipu"
+ )
+ expected_output = torch.tensor(720, dtype=torch.int64, device="dipu")
+ out = torch.prod(input_tensor)
+ self.assertEqual(out, expected_output, exact_dtype=True)
+
+ expected_output = torch.tensor([6, 120], dtype=torch.int64, device="dipu")
+ out = torch.prod(input_tensor, 1)
+ self.assertEqual(out, expected_output, exact_dtype=True)
+
if __name__ == "__main__":
run_tests()
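
`test_prod_integer_promotion` relies on PyTorch's default reduction dtype: integral inputs are accumulated and returned as `int64` unless an explicit `dtype` is given, mirroring `torch.sum`. A CPU sketch:

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int8)

print(torch.prod(x))                    # tensor(720), dtype=torch.int64
print(torch.prod(x, 1))                 # tensor([  6, 120]), dtype=torch.int64

# An explicit dtype suppresses the promotion.
print(torch.prod(x, 1, dtype=torch.int8).dtype)   # torch.int8
```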
diff --git a/dipu/tests/python/unittests/test_profiler.py b/dipu/tests/python/unittests/test_profiler.py
index 5343ce1712..fbe75e3e68 100644
--- a/dipu/tests/python/unittests/test_profiler.py
+++ b/dipu/tests/python/unittests/test_profiler.py
@@ -1,10 +1,7 @@
# Copyright (c) 2023, DeepLink.
import torch
import torch_dipu
-import torchvision.models as models
-from torch.profiler import profile, ProfilerActivity
from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn
-from torch_dipu.testing._internal.local_eviron import local_eviron
import torch._dynamo as dynamo
import subprocess
@@ -17,50 +14,7 @@ def check_string_in_directory(directory, search_string):
return False
-
class TestProfiler(TestCase):
- def test_profiler(self):
- model = models.resnet18().cuda()
- inputs = torch.randn(5, 3, 224, 224).cuda()
-
- with local_eviron({"KINETO_LOG_LEVEL": "999"}): # suppress profiler logs
- with profile(
- activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
- profile_memory=True,
- record_shapes=True,
- with_modules=True,
- with_stack=True,
- experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)
- ) as prof:
- output = model(inputs)
- output.sum().backward()
-
- profile_output = prof.key_averages(group_by_input_shape=True).table(
- sort_by="self_cuda_time_total", row_limit=1000
- )
- self.assertIn("diopiConvolution2dBackward", profile_output)
- self.assertIn("dipu_convolution_", profile_output)
- self.assertIn("LaunchKernel_dipu", profile_output)
- self.assertIn("LaunchKernel_diopi", profile_output)
- self.assertIn("Self CPU time total", profile_output)
- self.assertIn("Self CUDA time total", profile_output)
- self.assertIn("5, 3, 224, 224", profile_output)
-
- profile_stack_output = prof.key_averages(group_by_stack_n=15).table(
- sort_by="cuda_time_total", row_limit=1000)
- self.assertIn("Source Location", profile_stack_output)
- self.assertIn("resnet.py", profile_stack_output)
- self.assertIn("test_profiler.py", profile_stack_output)
-
- profile_memory_output = prof.key_averages().table(
- sort_by="self_cuda_memory_usage", row_limit=1000)
- self.assertIn("Self CPU Mem", profile_memory_output)
- self.assertIn("Self CUDA Mem", profile_memory_output)
- self.assertIn("Mb", profile_memory_output)
- self.assertIn("Kb", profile_memory_output)
-
- prof.export_chrome_trace("./dipu_resnet18_profiler.json")
-
@onlyOn("NPU")
def test_aot_profiler(self):
x = torch.randn(3, 4).cuda()
diff --git a/dipu/tests/python/unittests/test_profiler_cuda.py b/dipu/tests/python/unittests/test_profiler_cuda.py
new file mode 100644
index 0000000000..3937cf4e7b
--- /dev/null
+++ b/dipu/tests/python/unittests/test_profiler_cuda.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2023, DeepLink.
+import tempfile
+import torch
+import torch_dipu
+import torchvision.models as models
+from torch.profiler import profile, ProfilerActivity
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn
+from utils.local_eviron import local_eviron
+
+
+class TestProfiler(TestCase):
+ @onlyOn("CUDA")
+ def test_profiler(self):
+ model = models.resnet18().cuda()
+ inputs = torch.randn(5, 3, 224, 224).cuda()
+
+ with local_eviron({"KINETO_LOG_LEVEL": "999"}): # suppress profiler logs
+ with profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ profile_memory=True,
+ record_shapes=True,
+ with_modules=True,
+ with_stack=True,
+ experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)
+ ) as prof:
+ output = model(inputs)
+ output.sum().backward()
+
+ profile_output = prof.key_averages(group_by_input_shape=True).table(
+ sort_by="self_cuda_time_total", row_limit=1000
+ )
+ self.assertNotIn("diopiConvolution2dBackward", profile_output)
+ self.assertNotIn("dipu_convolution_", profile_output)
+ self.assertNotIn("LaunchKernel_dipu", profile_output)
+ self.assertNotIn("LaunchKernel_diopi", profile_output)
+ self.assertIn("aten::cudnn_convolution", profile_output)
+ self.assertIn("aten::add", profile_output)
+ self.assertIn("vectorized_elementwise_kernel", profile_output)
+ self.assertIn("Self CPU time total", profile_output)
+ self.assertIn("Self CUDA time total", profile_output)
+ self.assertIn("5, 3, 224, 224", profile_output)
+
+ profile_stack_output = prof.key_averages(group_by_stack_n=15).table(
+ sort_by="cuda_time_total", row_limit=1000)
+ self.assertIn("Source Location", profile_stack_output)
+ self.assertIn("resnet.py", profile_stack_output)
+
+ profile_memory_output = prof.key_averages().table(
+ sort_by="self_cuda_memory_usage", row_limit=1000)
+ self.assertIn("Self CPU Mem", profile_memory_output)
+ self.assertIn("Self CUDA Mem", profile_memory_output)
+ self.assertIn("Mb", profile_memory_output)
+ self.assertIn("Kb", profile_memory_output)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ prof.export_chrome_trace(f"{tmpdir}/resnet18_profiler.json")
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/dipu/tests/python/unittests/utils b/dipu/tests/python/unittests/utils
new file mode 120000
index 0000000000..468ba705ba
--- /dev/null
+++ b/dipu/tests/python/unittests/utils
@@ -0,0 +1 @@
+../utils
\ No newline at end of file
diff --git a/dipu/tests/python/utils/__init__.py b/dipu/tests/python/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dipu/torch_dipu/testing/_internal/local_eviron.py b/dipu/tests/python/utils/local_eviron.py
similarity index 100%
rename from dipu/torch_dipu/testing/_internal/local_eviron.py
rename to dipu/tests/python/utils/local_eviron.py
diff --git a/dipu/torch_dipu/testing/_internal/stdout_redirector.py b/dipu/tests/python/utils/stdout_redirector.py
similarity index 100%
rename from dipu/torch_dipu/testing/_internal/stdout_redirector.py
rename to dipu/tests/python/utils/stdout_redirector.py
index 903f023c51..64669caae9 100644
--- a/dipu/torch_dipu/testing/_internal/stdout_redirector.py
+++ b/dipu/tests/python/utils/stdout_redirector.py
@@ -48,12 +48,12 @@ def _redirect_stdout(to_fd):
_redirect_stdout(tfile.fileno())
# Yield to caller, then redirect stdout back to the saved fd
yield
+ finally:
_redirect_stdout(saved_stdout_fd)
# Copy contents of temporary file to the given stream
tfile.flush()
tfile.seek(0, io.SEEK_SET)
stream.write(tfile.read())
- finally:
tfile.close()
os.close(saved_stdout_fd)
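
The hunk above moves the restore-and-drain logic into the `finally` block, so file descriptor 1 is put back even if the code under the redirector raises. A self-contained sketch of the same pattern (a generic helper, not the project's `stdout_redirector`):

```python
import contextlib
import io
import os
import sys
import tempfile

@contextlib.contextmanager
def fd_stdout_captured(stream: io.BytesIO):
    saved_fd = os.dup(1)                       # keep the original stdout fd
    tfile = tempfile.TemporaryFile(mode="w+b")
    try:
        sys.stdout.flush()
        os.dup2(tfile.fileno(), 1)             # redirect fd 1 into the temp file
        yield
    finally:
        sys.stdout.flush()
        os.dup2(saved_fd, 1)                   # always restore, even on error
        tfile.flush()
        tfile.seek(0, io.SEEK_SET)
        stream.write(tfile.read())             # drain captured bytes to the caller
        tfile.close()
        os.close(saved_fd)
```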
diff --git a/dipu/tests/python/utils/test_in_subprocess.py b/dipu/tests/python/utils/test_in_subprocess.py
new file mode 100644
index 0000000000..6268ea6997
--- /dev/null
+++ b/dipu/tests/python/utils/test_in_subprocess.py
@@ -0,0 +1,97 @@
+import io
+import os
+import pathlib
+import queue
+import sys
+from multiprocessing import Process, Queue, set_start_method
+from tempfile import TemporaryDirectory
+from typing import Callable, Iterable, List, Tuple, TypedDict, Union
+from .stdout_redirector import stdout_redirector
+
+
+class Args(TypedDict, total=False):
+ args: tuple
+ kwargs: dict
+
+
+def _run_individual_test_cases_sequential(
+ entry_points: Iterable[Tuple[Callable, Args]]
+) -> None:
+ all_tests_pass = True
+ for entry_point, args in entry_points:
+ p = Process(
+ target=entry_point, args=args.get("args", ()), kwargs=args.get("kwargs", {})
+ )
+ p.start()
+ p.join()
+ all_tests_pass = all_tests_pass and p.exitcode == 0
+ assert all_tests_pass
+
+
+def _entry_point_wrapper(
+ entry_point: Callable, future_output: Queue, log_dir: str, *args, **kwargs
+) -> None:
+ sys.stderr = open(f"{log_dir}/stderr_{os.getpid()}", "w")
+ captured = io.BytesIO()
+ try:
+ with stdout_redirector(captured):
+ entry_point(*args, **kwargs)
+ finally:
+ future_output.put(captured.getvalue().decode("utf-8"))
+
+
+def _run_individual_test_cases_parallel(
+ entry_points: Iterable[Tuple[Callable, Args]]
+) -> None:
+ with TemporaryDirectory() as tmpdir:
+ future_outputs: List[Queue] = []
+ ps: List[Process] = []
+ for entry_point, args in entry_points:
+ future_output = Queue()
+ p = Process(
+ target=_entry_point_wrapper,
+ args=(entry_point, future_output, tmpdir) + args.get("args", ()),
+ kwargs=args.get("kwargs", {}),
+ )
+ p.start()
+ future_outputs.append(future_output)
+ ps.append(p)
+
+ all_tests_pass = True
+ for p, future_output in zip(ps, future_outputs):
+ p.join()
+ try:
+ print(future_output.get_nowait(), end="")
+ except queue.Empty:
+ all_tests_pass = False
+ print(
+ pathlib.Path(f"{tmpdir}/stderr_{p.pid}").read_text(),
+ end="",
+ file=sys.stderr,
+ )
+ all_tests_pass = all_tests_pass and p.exitcode == 0
+ assert all_tests_pass
+
+
+def run_individual_test_cases(
+ entry_points: Iterable[Union[Callable, Tuple[Callable, Args]]],
+ in_parallel: bool = False,
+) -> None:
+ """
+ Run test cases in individual processes, either in parallel or sequentially.
+ WARN: This function must be called within an `if __name__ == "__main__"` region.
+ ---
+ Args:
+ `entry_points`: A sequence of test cases. Each test case is either a function
+ or a tuple of a function and its arguments
+ `(func, {"args": [...], "kwargs": {...}})`.
+ `in_parallel`: Whether to run test cases in parallel.
+ """
+ set_start_method("spawn", force=True) # this is required for gcov to work
+ uniform_entry_points: Iterable[Tuple[Callable, Args]] = map(
+ lambda x: x if isinstance(x, tuple) else (x, {}), entry_points
+ )
+ if in_parallel:
+ _run_individual_test_cases_parallel(uniform_entry_points)
+ else:
+ _run_individual_test_cases_sequential(uniform_entry_points)
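
A hypothetical caller of `run_individual_test_cases`, assuming the `utils` symlink added above is importable from the test directory (the names `_case_a`/`_case_b` are made up for illustration):

```python
from utils.test_in_subprocess import run_individual_test_cases

def _case_a():
    assert 1 + 1 == 2

def _case_b(n, repeat=1):
    assert n * repeat >= n

if __name__ == "__main__":   # required: the helper spawns subprocesses
    run_individual_test_cases(
        [
            _case_a,                                              # plain function
            (_case_b, {"args": (3,), "kwargs": {"repeat": 2}}),   # function + arguments
        ],
        in_parallel=True,
    )
```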
diff --git a/dipu/third_party/DIOPI b/dipu/third_party/DIOPI
index 9b9589b226..385ce67f65 160000
--- a/dipu/third_party/DIOPI
+++ b/dipu/third_party/DIOPI
@@ -1 +1 @@
-Subproject commit 9b9589b226d3a18482582037d9707574fe39fd48
+Subproject commit 385ce67f65c1c785c9a3713465c6489025da7bf1
diff --git a/dipu/third_party/kineto b/dipu/third_party/kineto
index c1bed2f2dc..2923b3002a 160000
--- a/dipu/third_party/kineto
+++ b/dipu/third_party/kineto
@@ -1 +1 @@
-Subproject commit c1bed2f2dc3779dec2a63025ea1b72a957f4badf
+Subproject commit 2923b3002a179d6dfe202e6d032567bb2816eae7
diff --git a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt
index 764c36c910..f12feb8558 100644
--- a/dipu/torch_dipu/csrc_dipu/CMakeLists.txt
+++ b/dipu/torch_dipu/csrc_dipu/CMakeLists.txt
@@ -29,12 +29,15 @@ add_custom_command(
COMMAND
python "${DIPU_AUTOGEN_DIOPI_WRAPPER_SCRIPT}" --config
"${DIPU_AUTOGEN_DIOPI_WRAPPER_CONFIG}" --out "${DIPU_AUTOGENED_KERNELS_CPP}"
+ "$<$>:--convert_config=${CMAKE_SOURCE_DIR}/third_party/DIOPI/impl/${UsedVendor}/convert_config.yaml>"
--use_diopi_adapter "False" --autocompare "False" --print_func_call_info "True"
--print_op_arg "True" --fun_config_dict
'{\"current_device\": \"${UsedVendor}\"}'
DEPENDS ${DIPU_AUTOGEN_DIOPI_WRAPPER_SCRIPT}
${DIPU_AUTOGEN_DIOPI_WRAPPER_CONFIG}
- ${DIPU_AUTOGEN_DIOPI_WRAPPER_TEMPLATE})
+ ${DIPU_AUTOGEN_DIOPI_WRAPPER_TEMPLATE}
+ "$<$>:${CMAKE_SOURCE_DIR}/third_party/DIOPI/impl/${UsedVendor}/convert_config.yaml>"
+)
add_custom_target(autogen_diopi_kernels_cpp
DEPENDS ${DIPU_AUTOGENED_KERNELS_CPP})
add_dependencies(${DIPU_AUTOGENED_KERNELS} autogen_diopi_kernels_cpp)
diff --git a/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h b/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h
index 010c07836c..36bc802fa3 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h
+++ b/dipu/torch_dipu/csrc_dipu/aten/DIPUATenFunctions.h
@@ -1,62 +1,65 @@
// Copyright (c) 2023, DeepLink.
#pragma once
-#include
-#include
-#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
namespace dipu {
namespace native {
+namespace dipu_aten {
+// dipu native func
+at::Tensor empty(at::IntArrayRef size, c10::optional dtype_opt,
+ c10::optional layout_opt,
+ c10::optional device_opt,
+ c10::optional pin_memory_opt,
+ c10::optional memory_format_opt);
+at::Tensor empty_cpu(at::IntArrayRef size,
+ c10::optional dtype_opt,
+ c10::optional layout_opt,
+ c10::optional device_opt,
+ c10::optional pin_memory_opt,
+ c10::optional memory_format_opt);
+
+at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride,
+ c10::optional dtype_opt,
+ c10::optional layout_opt,
+ c10::optional device_opt,
+ c10::optional pin_memory_opt);
+at::Tensor empty_strided_cpu(at::IntArrayRef size, at::IntArrayRef stride,
+ c10::optional dtype_opt,
+ c10::optional layout_opt,
+ c10::optional device_opt,
+ c10::optional pin_memory_opt);
+
+const at::Tensor& resize_(const at::Tensor& self, at::IntArrayRef size,
+ c10::optional memory_format);
+
+at::Scalar _local_scalar_dense_dipu(const at::Tensor& self);
+
+at::Tensor& set_storage_dipu_(at::Tensor& result, c10::Storage storage,
+ int64_t storage_offset, at::IntArrayRef size,
+ at::IntArrayRef stride);
+at::Tensor& set_dipu_(at::Tensor& self);
+
+void resize_bytes_dipu(c10::StorageImpl* storage, size_t newsize_bytes);
+
+bool is_pinned(const at::Tensor& self, c10::optional device);
+at::Tensor _pin_memory(const at::Tensor& self,
+ c10::optional device);
-struct DIPUATenFunctions {
- // dipu native func
- static at::Tensor empty(at::IntArrayRef size,
- c10::optional dtype_opt,
- c10::optional layout_opt,
- c10::optional device_opt,
- c10::optional pin_memory_opt,
- c10::optional memory_format_opt);
- static at::Tensor empty_cpu(
- at::IntArrayRef size, c10::optional dtype_opt,
- c10::optional layout_opt,
- c10::optional device_opt, c10::optional pin_memory_opt,
- c10::optional memory_format_opt);
-
- static at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride,
- c10::optional dtype_opt,
- c10::optional layout_opt,
- c10::optional device_opt,
- c10::optional pin_memory_opt);
- static at::Tensor empty_strided_cpu(at::IntArrayRef size,
- at::IntArrayRef stride,
- c10::optional dtype_opt,
- c10::optional layout_opt,
- c10::optional device_opt,
- c10::optional pin_memory_opt);
-
- static const at::Tensor& resize_(
- const at::Tensor& self, at::IntArrayRef size,
- c10::optional memory_format);
-
- static at::Scalar _local_scalar_dense_dipu(const at::Tensor& self);
-
- static at::Tensor& set_storage_dipu_(at::Tensor& result, c10::Storage storage,
- int64_t storage_offset,
- at::IntArrayRef size,
- at::IntArrayRef stride);
- static at::Tensor& set_dipu_(at::Tensor& self);
-
- static void resize_bytes_dipu(c10::StorageImpl* storage,
- size_t newsize_bytes);
-
- static bool is_pinned(const at::Tensor& self,
- c10::optional device);
- static at::Tensor _pin_memory(const at::Tensor& self,
- c10::optional device);
-
- // todo:: use same format as autogen
- // diopi function defined in AutoGenedKernels.cpp,
-};
+// TODO: use the same format as autogen
+// diopi functions are defined in AutoGenedKernels.cpp
+}; // namespace dipu_aten
} // namespace native
} // namespace dipu
diff --git a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp
index 6898e83b0c..e03796b938 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp
@@ -14,42 +14,57 @@
#include
#include
-using dnative = dipu::native::DIPUATenFunctions;
-
-static std::string force_fallback_operators_list = []() -> std::string {
- std::ifstream stream(".dipu_force_fallback_op_list.config",
- std::ios_base::in | std::ios::binary);
- std::string content;
- const char* env = std::getenv("DIPU_FORCE_FALLBACK_OPS_LIST");
- if (env != nullptr) {
- content += env;
- }
- if (stream.is_open()) {
- while (!stream.eof()) {
- std::string line;
- stream >> line;
- content += "," + line;
+namespace dnative = dipu::native::dipu_aten;
+
+namespace dipu {
+namespace {
+
+void read_comma_separated_list(std::istream& input,
+ std::vector& output) {
+ auto line = std::string();
+ while (std::getline(input, line)) {
+ auto buffer = std::stringstream(line);
+ auto value = std::string();
+ while (std::getline(buffer, value, ',')) {
+ output.push_back(std::move(value));
}
}
- return content;
-}();
+}
+
+std::vector getFallbackList() {
+ auto fallback_list = std::vector();
+ if (auto env = std::getenv("DIPU_FORCE_FALLBACK_OPS_LIST")) {
+ auto iss = std::stringstream(env);
+ read_comma_separated_list(iss, fallback_list);
+ }
+ auto file = std::ifstream(".dipu_force_fallback_op_list.config",
+ std::ios_base::in | std::ios::binary);
+ read_comma_separated_list(file, fallback_list);
+
+ return fallback_list;
+}
+
+const std::vector force_fallback_operators_list =
+ getFallbackList();
+
+} // end of namespace
-namespace dipu {
bool get_force_fallback(const char* opname) {
- if (force_fallback_operators_list.size() <= 0 || opname == nullptr) {
+ if (force_fallback_operators_list.empty() || opname == nullptr) {
return false;
- } else {
- std::stringstream strstream(force_fallback_operators_list);
- std::string force_fallback_pattern;
- while (std::getline(strstream, force_fallback_pattern, ',')) {
- if (force_fallback_pattern.size() <= 0) {
- continue;
- }
+ }
+ for (auto& force_fallback_pattern : force_fallback_operators_list) {
+ if (force_fallback_pattern.empty()) {
+ continue;
+ }
+ try {
bool force_fallback =
std::regex_match(opname, std::regex(force_fallback_pattern));
if (force_fallback) {
return true;
}
+ } catch (const std::regex_error& e) {
+ TORCH_CHECK(false, e.what());
}
}
return false;
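
The rewritten fallback-list handling merges two sources, the `DIPU_FORCE_FALLBACK_OPS_LIST` environment variable and the `.dipu_force_fallback_op_list.config` file, accepts entries separated by commas and/or newlines, skips empty entries, and matches each remaining entry against the op name as a full regex (`std::regex_match`). A Python restatement of those semantics (a readability sketch, not code from the repository):

```python
import os
import re
from typing import List

def read_comma_separated_list(text: str) -> List[str]:
    # entries may be separated by newlines, commas, or both
    return [v for line in text.splitlines() for v in line.split(",")]

def get_fallback_list(config=".dipu_force_fallback_op_list.config") -> List[str]:
    patterns = read_comma_separated_list(os.environ.get("DIPU_FORCE_FALLBACK_OPS_LIST", ""))
    if os.path.exists(config):
        with open(config) as f:
            patterns += read_comma_separated_list(f.read())
    return patterns

def get_force_fallback(opname: str, patterns: List[str]) -> bool:
    # empty entries are skipped; each remaining entry must match the whole op name
    return any(p and re.fullmatch(p, opname) for p in patterns)
```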
@@ -76,7 +91,7 @@ void dump_fallback_op_args(const c10::OperatorHandle& op,
const auto num_arguments = schema_args.size();
auto arguments = torch::jit::last(stack, num_arguments);
- auto dumpTensor = [&](const at::Tensor tensor) {
+ auto dumpTensor = [&](const at::Tensor& tensor) {
if (tensor.defined()) {
std::cout << "numel: " << tensor.numel() << ", sizes: " << tensor.sizes()
<< ", stride: " << tensor.strides()
@@ -97,7 +112,6 @@ void dump_fallback_op_args(const c10::OperatorHandle& op,
}
};
- const auto arguments_begin = stack->size() - num_arguments;
for (const auto idx : c10::irange(arguments.size())) {
std::cout << "\t" << name << ": \t" << schema_args[idx].name() << ": ";
const auto& ivalue = arguments[idx];
@@ -108,9 +122,9 @@ void dump_fallback_op_args(const c10::OperatorHandle& op,
} else if (ivalue.isTensorList()) {
const auto& tensorlist = ivalue.toTensorList();
std::cout << std::endl;
- for (size_t i = 0; i < tensorlist.size(); i++) {
+ for (const auto& tensor : tensorlist) {
std::cout << "\t";
- dumpTensor(tensorlist[i]);
+ dumpTensor(tensor);
std::cout << std::endl;
}
} else {
@@ -149,16 +163,17 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
}
}
+// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
std::deque>
DIPUOpRegister::dipuOpRegisterList;
std::mutex DIPUOpRegister::mutex_;
+// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
void DIPUOpRegister::register_op() {
std::lock_guard guard(mutex_);
- for (auto iter = dipuOpRegisterList.begin(); iter != dipuOpRegisterList.end();
- ++iter) {
- torch::Library* lib = std::get<0>(*iter);
- DIPUOpRegister::OpRegFunPtr fun_ptr = std::get<1>(*iter);
+ for (auto& iter : dipuOpRegisterList) {
+ torch::Library* lib = std::get<0>(iter);
+ DIPUOpRegister::OpRegFunPtr fun_ptr = std::get<1>(iter);
fun_ptr(*lib);
}
dipuOpRegisterList.clear();
@@ -288,6 +303,7 @@ at::Scalar wrapper_DIPU___local_scalar_dense(const at::Tensor& self) {
return dnative::_local_scalar_dense_dipu(self);
}
+// NOLINTBEGIN(performance-unnecessary-value-param)
at::Tensor& wrapper_DIPU_source_Storage_set_(at::Tensor& self,
at::Storage source) {
// No device check
@@ -302,10 +318,11 @@ at::Tensor& wrapper_DIPU_source_Storage_offset_set_(
c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
// No device check
// DeviceGuard omitted
- return dnative::set_storage_dipu_(self, source, storage_offset.expect_int(),
- C10_AS_INTARRAYREF_SLOW(size),
- C10_AS_INTARRAYREF_SLOW(stride));
+ return dnative::set_storage_dipu_(
+ self, std::move(source), storage_offset.expect_int(),
+ C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride));
}
+// NOLINTEND(performance-unnecessary-value-param)
at::Tensor& wrapper_DIPU_source_Tensor_set_(at::Tensor& self,
const at::Tensor& source) {
@@ -413,7 +430,7 @@ DIPU_LIBRARY_IMPL(aten, DIPU_DEVICE_TYPE_MACRO, m) {
class IgnoreWarningHandler : public c10::WarningHandler {
public:
- void process(const c10::Warning& warning) {
+ void process(const c10::Warning& warning) override {
// do nothing
}
};
diff --git a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp
index 9cef60995c..aa5acbb20b 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp
@@ -18,6 +18,7 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
torch::jit::Stack* stack);
// Print the warning message only once for one process.
+// NOLINTBEGIN(bugprone-macro-parentheses): x cannot be in parentheses
#define DIPU_LOG_WARNING_ONCE(x) \
do { \
static bool should_print = true; \
@@ -26,6 +27,7 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
should_print = false; \
} \
} while (0)
+// NOLINTEND(bugprone-macro-parentheses)
// Check the environment variable and call the DIPU_LOG_WARNING_ONCE
#define DIPU_OP_LOG_WARNING_ONCE(...) \
@@ -53,8 +55,8 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
} else { \
DIPU_OP_LOG_WARNING_ONCE("force fallback has been set, "); \
} \
- DIPU_OP_LOG_WARNING_ONCE(opname << " will be fallback to cpu" \
- << std::endl); \
+ DIPU_OP_LOG_WARNING_ONCE((opname) << " will be fallback to cpu" \
+ << "\n"); \
} \
} while (false);
@@ -62,7 +64,7 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
wapper_func, custom_fallback_func) \
do { \
if ((reinterpret_cast(diopi_func) != nullptr) && \
- !(force_fallback || dipu::get_force_fallback(opname))) { \
+ !((force_fallback) || dipu::get_force_fallback(opname))) { \
m.impl(opname, TORCH_FN(wapper_func)); \
} else { \
if ((reinterpret_cast(diopi_func) == nullptr)) { \
@@ -70,22 +72,24 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
} else { \
DIPU_OP_LOG_WARNING_ONCE("force fallback has been set, "); \
} \
- DIPU_OP_LOG_WARNING_ONCE(opname << " will be fallback to cpu" \
- << std::endl); \
+ DIPU_OP_LOG_WARNING_ONCE((opname) << " will be fallback to cpu" \
+ << "\n"); \
m.impl(opname, TORCH_FN(custom_fallback_func)); \
} \
} while (false);
class DIPUOpRegister {
public:
- typedef void (*OpRegFunPtr)(torch::Library&);
+ using OpRegFunPtr = void (*)(torch::Library&);
private:
OpRegFunPtr fun_ptr_;
torch::Library lib_;
+ // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
static std::deque>
dipuOpRegisterList;
static std::mutex mutex_;
+ // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
public:
DIPUOpRegister(OpRegFunPtr fun_ptr, const char* ns,
@@ -97,7 +101,7 @@ class DIPUOpRegister {
fun_ptr_(lib_);
} else {
std::lock_guard guard(mutex_);
- dipuOpRegisterList.push_back(std::make_tuple(&lib_, fun_ptr_));
+ dipuOpRegisterList.emplace_back(&lib_, fun_ptr_);
}
}
@@ -106,8 +110,6 @@ class DIPUOpRegister {
} // namespace at
-namespace {
-
#define DIPU_LIBRARY_IMPL(ns, k, m) _DIPU_LIBRARY_IMPL(ns, k, m, C10_UID)
#define _DIPU_LIBRARY_IMPL(ns, k, m, uid) \
@@ -124,6 +126,4 @@ namespace {
[]() { return [](torch::Library&) -> void {}; }), \
#ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__); \
void C10_CONCATENATE(DIPU_LIBRARY_IMPL_init_##ns##_##k##_, \
- uid)(torch::Library & m)
-
-} // namespace
+ uid)(torch::Library & (m))
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp
index 7de896f582..955ef7a092 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctions.hpp
@@ -17,22 +17,23 @@ static c10::optional dipu_to_cpu(
return cpu_tensor;
}
-static at::Tensor to_cpu_no_half(const at::Tensor& devtensor) {
+static at::Tensor to_cpu_with_half_to_float(const at::Tensor& devtensor) {
auto cpu_tensor = devtensor.cpu();
auto intype = devtensor.options().dtype_opt()->toScalarType();
if (intype == at::ScalarType::Half) {
return cpu_tensor.to(at::ScalarType::Float);
- } else {
- return cpu_tensor;
}
+ return cpu_tensor;
}
static at::Tensor& custom_fallback_dipu_silu_out(const at::Tensor& self,
at::Tensor& out) {
DIPU_OP_LOG_WARNING_ONCE("custom fallback to cpu, name=silu_out"
<< std::endl);
- auto self_cpu = to_cpu_no_half(self);
- auto out_cpu = to_cpu_no_half(self);
+ auto self_cpu = to_cpu_with_half_to_float(self);
+ auto out_cpu = to_cpu_with_half_to_float(self);
+
+ // NOLINTNEXTLINE(readability-suspicious-call-argument): It's the correct order
out_cpu = at::silu_out(self_cpu, out_cpu);
out.copy_(out_cpu);
return out;
@@ -153,7 +154,9 @@ custom_fallback_dipu_convolution_backward_overrideable(
grad_output_cpu, input_cpu, weight_cpu, c10::nullopt, stride, padding,
dilation, transposed, output_padding, groups, output_mask_temp);
- at::Tensor grad_input, grad_weight, grad_bias;
+ at::Tensor grad_input;
+ at::Tensor grad_weight;
+ at::Tensor grad_bias;
if (output_mask[0]) {
grad_input = std::get<0>(result).to(device);
@@ -226,8 +229,15 @@ custom_fallback_dipu_linear_backward(const at::Tensor& input,
auto grad_output_cpu = grad_output.cpu();
auto weight_cpu = weight.cpu();
- at::Tensor grad_input_cpu, grad_weight_cpu, grad_bias_cpu;
- at::Tensor grad_input, grad_weight, grad_bias;
+ at::Tensor grad_input;
+ at::Tensor grad_input_cpu;
+
+ at::Tensor grad_weight;
+ at::Tensor grad_weight_cpu;
+
+ at::Tensor grad_bias;
+ at::Tensor grad_bias_cpu;
+
int64_t dims = input.dim();
const auto device = input.device();
@@ -330,5 +340,64 @@ at::Tensor& custom_fallback_dipu__amp_update_scale_(at::Tensor& current_scale,
double backoff_factor,
int64_t growth_interval);
+static at::Tensor& custom_fallback_dipu_addmm_out(
+ const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2,
+ const at::Scalar& beta, const at::Scalar& alpha, at::Tensor& out) {
+ auto self_cpu = to_cpu_with_half_to_float(self);
+ auto mat1_cpu = to_cpu_with_half_to_float(mat1);
+ auto mat2_cpu = to_cpu_with_half_to_float(mat2);
+ auto out_cpu = to_cpu_with_half_to_float(out);
+ out_cpu = at::addmm_out(out_cpu, self_cpu, mat1_cpu, mat2_cpu, beta, alpha);
+ out.copy_(out_cpu);
+ return out;
+}
+
+static at::Tensor& custom_fallback_dipu_bmm_out(const at::Tensor& self,
+ const at::Tensor& mat2,
+ at::Tensor& out) {
+ auto self_cpu = to_cpu_with_half_to_float(self);
+ auto mat2_cpu = to_cpu_with_half_to_float(mat2);
+ auto out_cpu = to_cpu_with_half_to_float(out);
+ out_cpu = at::bmm_out(out_cpu, self_cpu, mat2_cpu);
+ out.copy_(out_cpu);
+ return out;
+}
+
+static at::Tensor& custom_fallback_dipu_mm_out(const at::Tensor& self,
+ const at::Tensor& mat2,
+ at::Tensor& out) {
+ auto self_cpu = to_cpu_with_half_to_float(self);
+ auto mat2_cpu = to_cpu_with_half_to_float(mat2);
+ auto out_cpu = to_cpu_with_half_to_float(out);
+ out_cpu = at::mm_out(out_cpu, self_cpu, mat2_cpu);
+ out.copy_(out_cpu);
+ return out;
+}
+
+static at::Tensor custom_fallback_dipu_linear(
+ const at::Tensor& input, const at::Tensor& weight,
+ const c10::optional& bias) {
+ auto input_cpu = to_cpu_with_half_to_float(input);
+ auto weight_cpu = to_cpu_with_half_to_float(weight);
+ c10::optional bias_cpu = c10::nullopt;
+
+ at::Tensor out;
+ at::Tensor out_cpu;
+
+ if (bias.has_value() && bias.value().defined()) {
+ if (bias.value().options().dtype_opt()->toScalarType() ==
+ at::ScalarType::Half) {
+ bias_cpu = bias.value().to(at::ScalarType::Float).cpu();
+ } else {
+ bias_cpu = bias.value().cpu();
+ }
+ }
+
+ out_cpu = at::linear(input_cpu, weight_cpu, bias_cpu);
+ out = out_cpu.to(input.device())
+ .to(input.options().dtype_opt()->toScalarType());
+ return out;
+}
+
} // namespace native
} // namespace dipu
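
All of the new `custom_fallback_dipu_*` helpers follow the same shape: move the operands to the CPU, upcast `float16` to `float32` (many CPU kernels have no half implementation), run the stock ATen op, then copy the result back into the device-side output tensor, which also restores the original dtype. A Python sketch of that pattern (illustrative only):

```python
import torch

def to_cpu_with_half_to_float(t: torch.Tensor) -> torch.Tensor:
    cpu = t.cpu()
    return cpu.float() if t.dtype == torch.half else cpu

def fallback_mm_out(self: torch.Tensor, mat2: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    out_cpu = torch.mm(to_cpu_with_half_to_float(self), to_cpu_with_half_to_float(mat2))
    out.copy_(out_cpu)   # copy_ casts back to out's dtype and moves to out's device
    return out
```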
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp
index 2514e1e163..03a8fb2334 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp
@@ -18,7 +18,7 @@ void _amp_non_finite_check_and_unscale_(at::Tensor& scaled_grad,
const at::Tensor& inv_scale) {
scaled_grad *= inv_scale.item();
if (!scaled_grad.isfinite().all().item()) {
- found_inf[0] = 1.f;
+ found_inf[0] = 1.F;
}
}
@@ -46,8 +46,7 @@ void custom_fallback_dipu__amp_foreach_non_finite_check_and_unscale_(
TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
for (const at::Tensor& t : scaled_grads) {
- // NOLINTNEXTLINE: const_cast here is safe according to pytorch's source
- // code
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast): const_cast here is safe according to pytorch's source code
_amp_non_finite_check_and_unscale_(const_cast(t), found_inf,
inv_scale);
}
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUAmp.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUAmp.cpp
index b7c5c347d3..7181d9892e 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUAmp.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUAmp.cpp
@@ -237,8 +237,10 @@ DIPU_DEFINE_CAST_POLICY_CONVERSION(kPromote, promote);
// This function will throw an error message when
// torch.nn.functional.binary_cross_entropy is called within an autocast block
-Tensor DipuBinaryCrossEntropyBanned(const Tensor&, const Tensor&,
- const c10::optional&, int64_t) {
+Tensor DipuBinaryCrossEntropyBanned(const Tensor& /*unused*/,
+ const Tensor& /*unused*/,
+ const c10::optional& /*unused*/,
+ int64_t /*unused*/) {
AT_ERROR(
R"(torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
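
The banned overload mirrors upstream PyTorch: `binary_cross_entropy` is numerically unsafe under autocast, and the usual fix is to drop the explicit sigmoid and call `binary_cross_entropy_with_logits` instead. For example:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 1, requires_grad=True)
target = torch.rand(8, 1)

# Unsafe inside an autocast region:
#   loss = F.binary_cross_entropy(torch.sigmoid(logits), target)

# Safe replacement: feed raw logits to the fused, autocast-friendly op.
loss = F.binary_cross_entropy_with_logits(logits, target)
loss.backward()
```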
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp
index eb75a7b8cb..523533cddf 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.cpp
@@ -35,7 +35,8 @@ void setDipuCopyInstance(DIPUCopyBase* op) { dipu_copy_op() = op; }
namespace dipu {
namespace native {
-at::Scalar DIPUATenFunctions::_local_scalar_dense_dipu(const at::Tensor& self) {
+namespace dipu_aten {
+at::Scalar _local_scalar_dense_dipu(const at::Tensor& self) {
at::Scalar r;
AT_DISPATCH_ALL_TYPES_AND2(
at::kHalf, at::kBool, self.scalar_type(), "_local_scalar_dense_dipu",
@@ -50,5 +51,6 @@ at::Scalar DIPUATenFunctions::_local_scalar_dense_dipu(const at::Tensor& self) {
});
return r;
}
+} // namespace dipu_aten
} // namespace native
} // namespace dipu
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp
index 47f519984e..bd79a4abde 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/DIPUCopy.hpp
@@ -4,6 +4,7 @@
#include
#include
#include
+#include
#include
#include
@@ -15,12 +16,12 @@ namespace dipu {
namespace native {
// NOTICE: these 2 func defined in AutoGenedKernels.cpp
// if dipu autogen support header file gen, remove this
-at::Tensor dipu_wrap_diopi_cast_dtype(const at::Tensor& src,
+at::Tensor dipu_wrap_diopi_cast_dtype(const at::Tensor& self,
at::ScalarType dtype);
// if dipu autogen support proxy one torch op to multiple diopi op, remove
// this.
-at::Tensor& dipu_wrap_diopi_copy_inp(at::Tensor& dst, const at::Tensor& src,
+at::Tensor& dipu_wrap_diopi_copy_inp(at::Tensor& self, const at::Tensor& src,
bool non_blocking);
} // namespace native
@@ -46,7 +47,7 @@ inline void tryRecordStream(const at::Tensor& tensor, DIPUStream& curStream,
bool is_default_stream) {
if ((tensor.is_cpu() && tensor.options().pinned_memory()) ||
!is_default_stream) {
- tensor.record_stream(curStream);
+ tensor.record_stream(curStream.unwrap());
}
}
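
`tryRecordStream` now hands the underlying c10 stream (`curStream.unwrap()`) to `Tensor::record_stream`. The Python-level analogue of why this call matters: a tensor allocated on one stream but consumed on another must be recorded on the consumer stream so the caching allocator does not recycle its memory too early. A sketch (CUDA-only, illustrative):

```python
import torch

if torch.cuda.is_available():
    x = torch.empty(1024, device="cuda")     # allocated on the default stream
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        y = x * 2                            # consumed on `side`
    # Without this, the allocator may hand x's memory to a new allocation as soon
    # as the default stream is done with it, even if `side` is still reading it.
    x.record_stream(side)
```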
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/EmptyOpsKernel.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/EmptyOpsKernel.cpp
index 2837967029..0467b4a76d 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/EmptyOpsKernel.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/EmptyOpsKernel.cpp
@@ -1,21 +1,28 @@
// Copyright (c) 2023, DeepLink.
#include
+#include
+#include
#include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
#include
+#include
+#include
#include
-#include
-#include
-#include
+#include "csrc_dipu/aten/DIPUATenFunctions.h"
+#include "csrc_dipu/base/basedef.h"
+#include "csrc_dipu/profiler/profiler.h"
+#include "csrc_dipu/runtime/core/allocator/DIPUCachingAllocator.h"
+#include "csrc_dipu/runtime/rthelper.h"
-using at::Layout;
-using c10::device_or_default;
-using c10::layout_or_default;
-using c10::StorageImpl;
-using c10::TensorImpl;
-
-namespace dipu::native {
+namespace dipu {
+namespace native {
static c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
if (pin_memory) {
@@ -24,11 +31,12 @@ static c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
return c10::GetCPUAllocator();
}
-at::Tensor DIPUATenFunctions::empty(
- at::IntArrayRef size, c10::optional dtype_opt,
- c10::optional layout_opt, c10::optional