[dicp] add dicp ci (#648)
* add dicp ci

* fix cache dir not exist error
zhaochaoxing authored Jan 19, 2024
1 parent b17aaee commit 9a862ce
Showing 5 changed files with 32 additions and 17 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/main.yml
@@ -441,6 +441,17 @@ jobs:
           cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/Build-Ascend-910b/dipu
           source scripts/ci/ascend/ci_ascend_env.sh
           bash tests/run_ascend_tests.sh
+      - name: Run dicp op test
+        run: |
+          set -ex
+          cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/Build-Ascend-910b/dipu
+          source scripts/ci/ascend/ci_ascend_env.sh
+          cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/Build-Ascend-910b/dicp
+          pip uninstall dicp -y
+          python setup.py clean && python setup.py install --user
+          export TEST_DIR=$(pwd)/test
+          bash ${TEST_DIR}/ascend_scripts/ops/run_test_ops.sh false
   Test-One-Iter-Ascend-910b:
     name: Test-one-iter-ascend-910b
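The new "Run dicp op test" step reinstalls dicp from source (pip uninstall dicp -y, then python setup.py clean && python setup.py install --user) inside the Ascend 910b build directory and drives the op test suite through ${TEST_DIR}/ascend_scripts/ops/run_test_ops.sh. Roughly, each op test compiles a small function with the dicp backend and checks it against eager execution; the sketch below is illustrative only: the test name, the "ascendgraph" backend string, the tolerance, and the device handling are assumptions rather than code from this commit.

```python
# Illustrative sketch of an op test of the kind run_test_ops.sh drives.
# Assumptions, not code from this commit: the test name, the "ascendgraph"
# backend string, the tolerance, and the device handling.
import torch
import torch_dipu  # noqa: F401  # registers the DIPU device integration


def test_add_matches_eager():
    def fn(a, b):
        return torch.add(a, b)

    a = torch.randn(4, 4)
    b = torch.randn(4, 4)
    # Real tests move tensors to the DIPU/Ascend device; kept on CPU here for brevity.
    compiled = torch.compile(fn, backend="ascendgraph")  # assumed backend name
    assert torch.allclose(compiled(a, b), fn(a, b), atol=1e-4)
```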
2 changes: 2 additions & 0 deletions dicp/dicp/dynamo_bridge/utils.py
@@ -1,4 +1,5 @@
 import copy
+from pathlib import Path
 from typing import Any, Dict, Tuple
 
 import torch.fx
@@ -7,6 +8,7 @@
 
 
 def save_cpu_gm(gm: torch.fx.GraphModule, folder: str):
+    Path(folder).mkdir(exist_ok=True)
     cpu_gm = copy_gm_to_cpu(gm)
     grap_code = cpu_gm.code
     graph_key = code_hash(grap_code)
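This is the "fix cache dir not exist error" part of the commit: save_cpu_gm saves a CPU copy of the graph module under folder, which may not exist yet on a fresh CI runner, so mkdir(exist_ok=True) creates it idempotently before anything is written. A minimal sketch of the failure mode and the fix, using a hypothetical cache path:

```python
# Minimal sketch of the failure mode and the fix; the cache path is hypothetical.
from pathlib import Path

folder = "/tmp/dicp_cache"  # hypothetical cache dir, absent on a fresh CI runner

# Before the fix: writing into a missing directory raises FileNotFoundError.
# open(f"{folder}/graph.py", "w")

# The fix: create the directory first; exist_ok=True makes repeated calls no-ops.
Path(folder).mkdir(exist_ok=True)
with open(f"{folder}/graph.py", "w") as f:
    f.write("# saved CPU GraphModule code\n")
```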
4 changes: 2 additions & 2 deletions dicp/dicp/vendor/AscendGraph/compile_job.py
@@ -3,7 +3,7 @@
 import time
 
 from dicp.dynamo_bridge.compile import DeviceCompileJob
-from torch._inductor.codecache import pick_vec_isa, cpp_compile_command, write, get_hash
+from torch._inductor.codecache import pick_vec_isa, cpp_compile_command, write, code_hash
 from torch._inductor import exc
 
 
@@ -24,7 +24,7 @@ def __init__(self, source_code) -> None:
             source_code.strip(),
             "json",
             extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa) +
-            'local_rank' + str(self._local_rank) + get_hash(compile_file_code, 'code')
+            'local_rank' + str(self._local_rank) + code_hash(compile_file_code)
         )
         self._output_graph_path = self._input_path[:-5] + '/graph'
         print('output_path: ', self._output_graph_path)
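The import change follows a torch._inductor.codecache API difference: newer torch releases expose code_hash where older ones provided get_hash, and the call site is updated to match. If both had to be supported at once, a small shim along the following lines could bridge the gap (a sketch, not part of this commit; the fallback simply mirrors the pre-change call get_hash(compile_file_code, 'code')):

```python
# Hedged compatibility sketch, not part of this commit: prefer code_hash and
# fall back to the older get_hash if that is all the installed torch provides.
try:
    from torch._inductor.codecache import code_hash
except ImportError:
    from torch._inductor.codecache import get_hash

    def code_hash(code: str) -> str:
        # Mirrors the pre-change call site: get_hash(compile_file_code, 'code')
        return get_hash(code, 'code')
```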
22 changes: 13 additions & 9 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -1,12 +1,15 @@
 import torch
-from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
+from dicp.dynamo_bridge.compile_fx import is_torch_210
 from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp
 from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
-from dicp.vendor.AscendGraph.pattern_replacement import (
-    ascend_pattern_matcher,
-    aten_patterns_cls_list,
-    ascend_patterns_cls_list
-)
+
+if is_torch_210:
+    from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
+    from dicp.vendor.AscendGraph.pattern_replacement import (
+        ascend_pattern_matcher,
+        aten_patterns_cls_list,
+        ascend_patterns_cls_list
+    )
 
 
 # This pass must run after FuseTransposeMatmul
@@ -74,13 +77,14 @@ def symint_in_inputs(nodes):
 def ascendgraph_opset_convert(
     gm: torch.fx.GraphModule,
 ):
-    gm = BackendPatternMatcherTransformer(
-        ascend_pattern_matcher, aten_patterns_cls_list).transform(gm)
+    if is_torch_210:
+        gm = BackendPatternMatcherTransformer(
+            ascend_pattern_matcher, aten_patterns_cls_list).transform(gm)
     gm = AtenToAscendTransformer(gm).transform()
 
     # For bug in pytorch
     # Avoid for dynamic shape
-    if not symint_in_inputs(list(gm.graph.nodes)):
+    if is_torch_210 and not symint_in_inputs(list(gm.graph.nodes)):
         gm = BackendPatternMatcherTransformer(
             ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm)
     gm = OutputMarkPass().transform(gm)
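The BackendPatternMatcherTransformer passes and the pattern_replacement imports are now gated on is_torch_210, so the Ascend conversion path still loads on torch builds that predate the required pattern-matching infrastructure. The flag itself comes from dicp.dynamo_bridge.compile_fx, which is not shown in this diff; a plausible version gate would look roughly like the sketch below (an assumption, not the repo's actual implementation):

```python
# Hedged sketch of a torch >= 2.1 gate; the real is_torch_210 lives in
# dicp.dynamo_bridge.compile_fx and may be implemented differently.
import torch

_major, _minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
is_torch_210 = (_major, _minor) >= (2, 1)
```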
10 changes: 4 additions & 6 deletions dicp/test/common/utils.py
@@ -5,6 +5,7 @@
 import torch._dynamo as dynamo
 import torch_dipu
 from dicp.dynamo_bridge import pt_patch # noqa F401
+from dicp.dynamo_bridge.compile_fx import is_torch_210
 torch.manual_seed(1)
 random.seed(1)
 
@@ -22,12 +23,9 @@ def __init__(self, static_size, dynamic_size):
 
 
 def update_dynamo_config(dynamic=False):
-    if dynamic:
-        dynamo.config.dynamic_shapes = True
-        dynamo.config.assume_static_by_default = False
-    else:
-        dynamo.config.dynamic_shapes = False
-        dynamo.config.assume_static_by_default = True
+    dynamo.config.dynamic_shapes = dynamic
+    if is_torch_210:
+        dynamo.config.assume_static_by_default = not dynamic
 
 
 def get_device():
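update_dynamo_config now sets dynamo.config.dynamic_shapes straight from the dynamic flag and only adjusts assume_static_by_default when is_torch_210 is true, collapsing the old if/else. A hedged usage sketch follows (the import path and the toy model are placeholders, not code from the repo):

```python
# Hedged usage sketch; the import path and the toy model are placeholders.
import torch
from common.utils import update_dynamo_config  # assumed test-helper import path


def run(dynamic: bool):
    update_dynamo_config(dynamic=dynamic)      # toggle dynamo shape handling first
    fn = torch.compile(torch.nn.Linear(8, 8))  # default backend; the dicp tests use their own
    return fn(torch.randn(2, 8))


run(dynamic=False)  # static-shape configuration
run(dynamic=True)   # dynamic-shape configuration
```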
