From 52ac6450c7f6857bc92edf59e7aa10233c4fb067 Mon Sep 17 00:00:00 2001
From: zhaochaoxing
Date: Tue, 9 Jan 2024 02:45:15 +0000
Subject: [PATCH] dicp support dipu native memory format

---
 .../dicp/vendor/AscendGraph/codegen/ascend.py |  7 +------
 dicp/dicp/vendor/AscendGraph/opset_convert.py | 21 ++++++++++++-------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index c32b8e7974..35fe6778e8 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -77,14 +77,11 @@ def placeholder(self, name, target, args, kwargs):
         self.input_args.append(self.cur_node)
 
         fake_tensor = self.cur_node.meta['val']
-
-        format = "NCHW"
         index = -1
 
         if isinstance(fake_tensor, torch.SymInt):
             dims = [1]
             data_type = "INT32"
-            format = "ND"
             self.sym_to_inputs[fake_tensor.node.str()] = name
         elif symint_in_shape(fake_tensor.shape):
             # mention symint position in args
@@ -110,14 +107,12 @@ def placeholder(self, name, target, args, kwargs):
             dims = list(fake_tensor.shape)
             data_type = get_ascend_dtype(fake_tensor.dtype).upper()
 
-        if 'format' in self.cur_node.meta:
-            format = self.cur_node.meta['format']
         # gen data_nodes
         self.data_nodes.append({
             "op_name": self.args_dict[name],
             "op_type": "Data",
             "dims": dims,
-            "format": format,
+            "format": self.cur_node.meta['native_memory_format'],
             "data_type": data_type,
             "cpp_data_type": data_type,
             "index": index
diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py
index 9913055e70..6c3b983f0e 100644
--- a/dicp/dicp/vendor/AscendGraph/opset_convert.py
+++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -1,4 +1,5 @@
 import torch
+import torch_dipu
 from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
 from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp
 from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
@@ -9,16 +10,20 @@
 )
 
 
-# 该pass需要在FuseTransposeMatmul之后
 class ArgsTransDataPass:
     def transform(self, gm: torch.fx.graph_module):
         for n in gm.graph.nodes:
-            if n.op != 'call_function':
-                continue
-            if type(n.target) in [MatMul]:
-                for arg in n.args:
-                    if arg.op == 'placeholder':
-                        arg.meta['format'] = 'FRACTAL_NZ'
+            if n.op == 'placeholder':
+                fake_tensor = n.meta['val']
+                if not hasattr(torch_dipu, 'get_native_memory_format'):
+                    n.meta['native_memory_format'] = 'ND'
+                    continue
+                memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
+                for key in memo:
+                    if id(memo[key].fake_device) == id(fake_tensor.fake_device):
+                        memory_format = torch_dipu.get_native_memory_format(key())
+                        n.meta['native_memory_format'] = str(memory_format.value)
+                        break
         return gm
 
 
@@ -84,5 +89,5 @@ def ascendgraph_opset_convert(
     gm = BackendPatternMatcherTransformer(
         ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm)
     gm = OutputMarkPass().transform(gm)
-    # gm = ArgsTransDataPass().transform(gm)
+    gm = ArgsTransDataPass().transform(gm)
     return gm
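
Note (illustrative, not part of the patch): the reworked ArgsTransDataPass relies on
FakeTensor internals that are easy to misread, so below is a minimal sketch of the
lookup it performs, under the same assumptions the patch itself makes: torch_dipu
exposes get_native_memory_format(), the converter's tensor_memo maps weak references
of the real input tensors to their FakeTensor counterparts (so calling key() recovers
the real DIPU tensor), and the returned format object carries a .value field. The
helper name native_format_of is made up for illustration.

    import torch_dipu

    def native_format_of(fake_tensor, default='ND'):
        # Fall back to the generic 'ND' layout when torch_dipu cannot report
        # native memory formats, mirroring the patch's hasattr guard.
        if not hasattr(torch_dipu, 'get_native_memory_format'):
            return default
        memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
        for key in memo:
            real = key()  # keys are weak references to the original (DIPU) tensors
            if real is not None and id(memo[key].fake_device) == id(fake_tensor.fake_device):
                # Query DIPU for the NPU-side layout of the real tensor and
                # return it as a string, as the pass does when filling n.meta.
                return str(torch_dipu.get_native_memory_format(real).value)
        return default

Each placeholder's n.meta['native_memory_format'] set this way is what the codegen in
ascend.py now writes into the "format" field of the generated Data node.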