diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index c466ff1e4..5bd76775f 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -107,12 +107,15 @@ def placeholder(self, name, target, args, kwargs):
         dims = list(fake_tensor.shape)
         data_type = get_ascend_dtype(fake_tensor.dtype).upper()

+        native_memory_format = "ND"
+        if 'native_memory_format' in self.cur_node.meta:
+            native_memory_format = self.cur_node.meta['native_memory_format']
         # gen data_nodes
         self.data_nodes.append({
             "op_name": self.args_dict[name],
             "op_type": "Data",
             "dims": dims,
-            "format": self.cur_node.meta['native_memory_format'],
+            "format": native_memory_format,
             "data_type": data_type,
             "cpp_data_type": data_type,
             "index": index
diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py
index ad4864e99..39c2ef172 100644
--- a/dicp/dicp/vendor/AscendGraph/opset_convert.py
+++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -16,16 +16,13 @@ class ArgsTransDataPass:
     def transform(self, gm: torch.fx.graph_module):
         for n in gm.graph.nodes:
-            if n.op == 'placeholder':
+            if hasattr(n, 'op') and n.op == 'placeholder':
                 fake_tensor = n.meta['val']
-                if not hasattr(torch_dipu, 'get_native_memory_format'):
-                    n.meta['native_memory_format'] = 'ND'
-                    continue
                 memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
                 for key in memo:
                     if id(memo[key].fake_device) == id(fake_tensor.fake_device):
                         memory_format = torch_dipu.get_native_memory_format(key())
-                        n.meta['native_memory_format'] = str(memory_format.value)
+                        n.meta['native_memory_format'] = str(memory_format.name)
                         break
         return gm