
Commit

dicp support dipu native memory format
zhaochaoxing committed Jan 9, 2024
1 parent 51d1e3d commit 52ac645
Showing 2 changed files with 14 additions and 14 deletions.
7 changes: 1 addition & 6 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -77,14 +77,11 @@ def placeholder(self, name, target, args, kwargs):
         self.input_args.append(self.cur_node)
 
         fake_tensor = self.cur_node.meta['val']
-
-        format = "NCHW"
         index = -1
 
         if isinstance(fake_tensor, torch.SymInt):
             dims = [1]
             data_type = "INT32"
-            format = "ND"
             self.sym_to_inputs[fake_tensor.node.str()] = name
         elif symint_in_shape(fake_tensor.shape):
             # mention symint position in args
@@ -110,14 +107,12 @@ def placeholder(self, name, target, args, kwargs):
             dims = list(fake_tensor.shape)
             data_type = get_ascend_dtype(fake_tensor.dtype).upper()
 
-        if 'format' in self.cur_node.meta:
-            format = self.cur_node.meta['format']
         # gen data_nodes
         self.data_nodes.append({
             "op_name": self.args_dict[name],
             "op_type": "Data",
             "dims": dims,
-            "format": format,
+            "format": self.cur_node.meta['native_memory_format'],
             "data_type": data_type,
             "cpp_data_type": data_type,
             "index": index
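With this change the Ascend codegen no longer hard-codes "NCHW"/"ND" or reads an optional 'format' hint; it takes the layout directly from the placeholder's native_memory_format metadata, which ArgsTransDataPass (in the next file) is expected to populate. As a rough illustration only, with made-up name, dims and dtype that are not taken from the commit, a generated data node entry now looks like:

    {
        "op_name": "arg0_1",        # hypothetical placeholder name
        "op_type": "Data",
        "dims": [2, 3, 16, 16],     # hypothetical input shape
        "format": "ND",             # value of cur_node.meta['native_memory_format']
        "data_type": "FLOAT",
        "cpp_data_type": "FLOAT",
        "index": 0
    }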
21 changes: 13 additions & 8 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -1,4 +1,5 @@
 import torch
+import torch_dipu
 from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
 from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp
 from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
@@ -9,16 +10,20 @@
 )
 
 
-# This pass must run after FuseTransposeMatmul
 class ArgsTransDataPass:
     def transform(self, gm: torch.fx.graph_module):
         for n in gm.graph.nodes:
-            if n.op != 'call_function':
-                continue
-            if type(n.target) in [MatMul]:
-                for arg in n.args:
-                    if arg.op == 'placeholder':
-                        arg.meta['format'] = 'FRACTAL_NZ'
+            if n.op == 'placeholder':
+                fake_tensor = n.meta['val']
+                if not hasattr(torch_dipu, 'get_native_memory_format'):
+                    n.meta['native_memory_format'] = 'ND'
+                    continue
+                memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
+                for key in memo:
+                    if id(memo[key].fake_device) == id(fake_tensor.fake_device):
+                        memory_format = torch_dipu.get_native_memory_format(key())
+                        n.meta['native_memory_format'] = str(memory_format.value)
+                        break
         return gm


@@ -84,5 +89,5 @@ def ascendgraph_opset_convert(
     gm = BackendPatternMatcherTransformer(
         ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm)
     gm = OutputMarkPass().transform(gm)
-    # gm = ArgsTransDataPass().transform(gm)
+    gm = ArgsTransDataPass().transform(gm)
     return gm
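For readers unfamiliar with the DIPU side, the lookup the re-enabled pass performs boils down to the following. This is a minimal standalone sketch, assuming a torch_dipu build that exposes get_native_memory_format; the helper name native_format_or_default is hypothetical and not part of this commit:

    import torch
    import torch_dipu  # requires a DIPU-enabled PyTorch build

    def native_format_or_default(t: torch.Tensor, default: str = "ND") -> str:
        # Older torch_dipu versions do not expose the query; keep the "ND"
        # fallback, exactly as ArgsTransDataPass does.
        if not hasattr(torch_dipu, "get_native_memory_format"):
            return default
        # get_native_memory_format returns an enum-like value; the codegen
        # stores its .value as a string in node.meta['native_memory_format'].
        return str(torch_dipu.get_native_memory_format(t).value)

The pass itself only has FakeTensors at hand, so it additionally walks fake_mode.fake_tensor_converter.tensor_memo to recover a real DIPU tensor corresponding to the placeholder before querying its native memory format.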
