[dicp]add dicp ci #648

Merged: 2 commits, Jan 19, 2024
11 changes: 11 additions & 0 deletions .github/workflows/main.yml
@@ -441,6 +441,17 @@ jobs:
          cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/Build-Ascend-910b/dipu
          source scripts/ci/ascend/ci_ascend_env.sh
          bash tests/run_ascend_tests.sh

      - name: Run dicp op test
        run: |
          set -ex
          cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/Build-Ascend-910b/dipu
          source scripts/ci/ascend/ci_ascend_env.sh
          cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/Build-Ascend-910b/dicp
          pip uninstall dicp -y
          python setup.py clean && python setup.py install --user
          export TEST_DIR=$(pwd)/test
          bash ${TEST_DIR}/ascend_scripts/ops/run_test_ops.sh false

  Test-One-Iter-Ascend-910b:
    name: Test-one-iter-ascend-910b
8 changes: 6 additions & 2 deletions dicp/dicp/vendor/AscendGraph/compile_job.py
@@ -1,9 +1,11 @@
import os
import subprocess
import time
import base64
import hashlib

from dicp.dynamo_bridge.compile import DeviceCompileJob
from torch._inductor.codecache import pick_vec_isa, cpp_compile_command, write, get_hash
from torch._inductor.codecache import pick_vec_isa, cpp_compile_command, write
from torch._inductor import exc


@@ -18,13 +20,15 @@ def __init__(self, source_code) -> None:
        for file in [source_path, source_include]:
            with open(file, 'r') as f:
                compile_file_code += f.read()
        code_sha256 = hashlib.sha256(compile_file_code.encode("utf-8")).digest()
        code_hash = base64.b32encode(code_sha256)[:51].decode("utf-8").lower()
        picked_vec_isa = pick_vec_isa()
        self._local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self._key, self._input_path = write(
            source_code.strip(),
            "json",
            extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa) +
            'local_rank' + str(self._local_rank) + get_hash(compile_file_code, 'code')
            'local_rank' + str(self._local_rank) + code_hash
        )
        self._output_graph_path = self._input_path[:-5] + '/graph'
        print('output_path: ', self._output_graph_path)
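For reference, a self-contained sketch of the hashing scheme introduced above: instead of torch._inductor's get_hash helper (dropped from the import list), the compile job now derives its own cache-key suffix as a SHA-256 digest of the concatenated source files, base32-encoded, truncated to 51 characters and lowercased. The sample input string below is a placeholder, not taken from the repository.

import base64
import hashlib


def code_cache_key(compile_file_code: str) -> str:
    # sha256 digest -> base32 -> first 51 characters, lowercased, as in the diff above
    code_sha256 = hashlib.sha256(compile_file_code.encode("utf-8")).digest()
    return base64.b32encode(code_sha256)[:51].decode("utf-8").lower()


# Placeholder content; in compile_job.py this is the concatenated source and include files.
print(code_cache_key("example ascend graph source"))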
22 changes: 13 additions & 9 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -1,12 +1,15 @@
import torch
from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
from dicp.dynamo_bridge.compile_fx import is_torch_210
from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp
from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
from dicp.vendor.AscendGraph.pattern_replacement import (
    ascend_pattern_matcher,
    aten_patterns_cls_list,
    ascend_patterns_cls_list
)

if is_torch_210:
    from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
    from dicp.vendor.AscendGraph.pattern_replacement import (
        ascend_pattern_matcher,
        aten_patterns_cls_list,
        ascend_patterns_cls_list
    )


# This pass must run after FuseTransposeMatmul
@@ -74,13 +77,14 @@ def symint_in_inputs(nodes):
def ascendgraph_opset_convert(
    gm: torch.fx.GraphModule,
):
    gm = BackendPatternMatcherTransformer(
        ascend_pattern_matcher, aten_patterns_cls_list).transform(gm)
    if is_torch_210:
        gm = BackendPatternMatcherTransformer(
            ascend_pattern_matcher, aten_patterns_cls_list).transform(gm)
    gm = AtenToAscendTransformer(gm).transform()

    # For bug in pytorch
    # Avoid for dynamic shape
    if not symint_in_inputs(list(gm.graph.nodes)):
    if is_torch_210 and not symint_in_inputs(list(gm.graph.nodes)):
        gm = BackendPatternMatcherTransformer(
            ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm)
    gm = OutputMarkPass().transform(gm)
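As background, is_torch_210 is imported from dicp.dynamo_bridge.compile_fx and gates the pattern-matcher passes on the installed torch version; its definition is not part of this diff. A minimal sketch of how such a flag is commonly derived, offered purely as an assumption about its semantics:

# Hypothetical sketch only; the real flag lives in dicp.dynamo_bridge.compile_fx.
from packaging import version

import torch

# Assumed semantics: True when the installed torch is at least 2.1.0.
is_torch_210 = version.parse(torch.__version__.split("+")[0]) >= version.parse("2.1.0")
print(is_torch_210)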
10 changes: 4 additions & 6 deletions dicp/test/common/utils.py
@@ -5,6 +5,7 @@
import torch._dynamo as dynamo
import torch_dipu
from dicp.dynamo_bridge import pt_patch # noqa F401
from dicp.dynamo_bridge.compile_fx import is_torch_210
torch.manual_seed(1)
random.seed(1)

@@ -22,12 +23,9 @@ def __init__(self, static_size, dynamic_size):


def update_dynamo_config(dynamic=False):
    if dynamic:
        dynamo.config.dynamic_shapes = True
        dynamo.config.assume_static_by_default = False
    else:
        dynamo.config.dynamic_shapes = False
        dynamo.config.assume_static_by_default = True
    dynamo.config.dynamic_shapes = dynamic
    if is_torch_210:
        dynamo.config.assume_static_by_default = not dynamic


def get_device():
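To illustrate the simplified helper above, a short usage sketch: one boolean now drives both dynamo settings, and assume_static_by_default is only touched on torch 2.1 in the real code. The standalone copy of the helper below is for demonstration only, assumes torch 2.1 where both config flags exist, and drops the torch_dipu and dicp imports.

import torch._dynamo as dynamo


def update_dynamo_config(dynamic=False):
    # Same shape as the refactored helper: one flag drives both settings.
    dynamo.config.dynamic_shapes = dynamic
    # Gated by is_torch_210 in the real code; assumed available here.
    dynamo.config.assume_static_by_default = not dynamic


update_dynamo_config(dynamic=True)
print(dynamo.config.dynamic_shapes, dynamo.config.assume_static_by_default)  # True False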