Skip to content

Commit

Permalink
fix: default placeholder native_memory_format to "ND" when meta is missing, and store the memory-format enum's name instead of its value
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaochaoxing committed Jan 23, 2024
1 parent 4db9a20 commit e8c2f65
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
5 changes: 4 additions & 1 deletion dicp/dicp/vendor/AscendGraph/codegen/ascend.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,15 @@ def placeholder(self, name, target, args, kwargs):
dims = list(fake_tensor.shape)
data_type = get_ascend_dtype(fake_tensor.dtype).upper()

native_memory_format = "ND"
if 'native_memory_format' in self.cur_node.meta:
native_memory_format = self.cur_node.meta['native_memory_format']
# gen data_nodes
self.data_nodes.append({
"op_name": self.args_dict[name],
"op_type": "Data",
"dims": dims,
"format": self.cur_node.meta['native_memory_format'],
"format": native_memory_format,
"data_type": data_type,
"cpp_data_type": data_type,
"index": index
Expand Down
7 changes: 2 additions & 5 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
class ArgsTransDataPass:
    """Graph pass that tags each placeholder node with the native memory
    format of its corresponding device tensor, so downstream codegen can
    emit the correct data layout."""

    def transform(self, gm: torch.fx.graph_module):
        """Annotate placeholder nodes of *gm* in place and return *gm*.

        For every placeholder node, sets ``n.meta['native_memory_format']``
        to the name of the device-reported memory format, or to ``'ND'``
        when the DIPU runtime cannot report one.
        """
        for n in gm.graph.nodes:
            # Guard with hasattr: some graph entries may not expose `op`.
            if hasattr(n, 'op') and n.op == 'placeholder':
                fake_tensor = n.meta['val']
                # Older torch_dipu builds lack this query; fall back to the
                # generic "ND" (no particular dimension order) layout.
                if not hasattr(torch_dipu, 'get_native_memory_format'):
                    n.meta['native_memory_format'] = 'ND'
                    continue
                memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
                for key in memo:
                    # Find the memoized real tensor that lives on the same
                    # fake device as this placeholder's fake tensor.
                    # NOTE(review): identity comparison of fake_device objects
                    # assumes one shared device object per device — confirm.
                    if id(memo[key].fake_device) == id(fake_tensor.fake_device):
                        # `key` is presumably a weakref; `key()` dereferences it.
                        memory_format = torch_dipu.get_native_memory_format(key())
                        # Store the enum member's *name* (e.g. "NCHW"), not its
                        # numeric value — codegen expects the symbolic name.
                        n.meta['native_memory_format'] = str(memory_format.name)
                        break
        return gm

Expand Down

0 comments on commit e8c2f65

Please sign in to comment.