Merge branch 'main' into cjt/resnet50_op_infer_torch211
KevinfromTJ authored Jan 24, 2024
2 parents 89577ae + 20f81ab commit ce925e4
Showing 4 changed files with 46 additions and 11 deletions.
19 changes: 16 additions & 3 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -77,7 +77,6 @@ def placeholder(self, name, target, args, kwargs):
         self.input_args.append(self.cur_node)
 
         fake_tensor = self.cur_node.meta['val']
-
         format = "NCHW"
         index = -1

@@ -110,8 +109,8 @@ def placeholder(self, name, target, args, kwargs):
             dims = list(fake_tensor.shape)
             data_type = get_ascend_dtype(fake_tensor.dtype).upper()
 
-            if 'format' in self.cur_node.meta:
-                format = self.cur_node.meta['format']
+            if 'native_memory_format' in self.cur_node.meta:
+                format = self.cur_node.meta['native_memory_format']
             # gen data_nodes
             self.data_nodes.append({
                 "op_name": self.args_dict[name],
@@ -1531,6 +1530,20 @@ def GatherElements(name, x, index, dim):
         op.set_input("index", index)
         op.set_attr_int("dim", dim)
         return op.to_node()
+
+    @staticmethod
+    def AdaptiveAvgPool2D(name, x, output_size):
+        op = OP(name, "AdaptiveAvgPool2d")
+        op.set_input("x", x)
+        op.set_attr_list_int("output_size", output_size)
+        return op.to_node()
+
+    @staticmethod
+    def AdaptiveAvgPool2DGrad(name, input_grad, orig_input_shape):
+        op = OP(name, "AdaptiveAvgPool2dGrad")
+        op.set_input("input_grad", input_grad)
+        op.set_attr_list_int("orig_input_shape", orig_input_shape)
+        return op.to_node()
 
     @staticmethod
     def AdaptiveAvgPool2D(name, x, output_size):
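For context, below is a minimal, hypothetical mock of the OP builder interface these two new helpers rely on (set_input, set_attr_list_int, to_node). It only illustrates the kind of record to_node() might return; the real OP class is defined in dicp's Ascend codegen and emits actual graph IR, so every detail of this sketch is an assumption.

# Hypothetical stand-in for dicp's OP builder, illustration only.
class OP:
    def __init__(self, name, op_type):
        # Assumed: the builder accumulates fields into a plain dict.
        self.node = {"op_name": name, "op_type": op_type,
                     "inputs": {}, "attrs": {}}

    def set_input(self, key, value):
        self.node["inputs"][key] = value

    def set_attr_list_int(self, key, values):
        self.node["attrs"][key] = [int(v) for v in values]

    def to_node(self):
        return self.node

# Same call pattern as AdaptiveAvgPool2D above.
op = OP("pool0", "AdaptiveAvgPool2d")
op.set_input("x", "conv_out")
op.set_attr_list_int("output_size", [7, 7])
print(op.to_node())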
17 changes: 17 additions & 0 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -1268,6 +1268,23 @@ def NativeDropoutBackward(self, grad_output, mask, scale):
         p = 1. - scale
         prob_op = self.get_const_proxy(float(p), dtype)
         return self.get_proxy(ascend_op.DropOutDoMaskV3, (grad_output, mask, prob_op))
+
+    @register_conversion([torch.ops.aten._adaptive_avg_pool2d.default])
+    def adaptiveavgpool2d(self, x, output_size):
+        assert isinstance(output_size, int) or (len(output_size) in range(1, 3) and any(output_size))
+        if not isinstance(output_size, list):
+            if isinstance(output_size, tuple):
+                output_size = list(output_size)
+            elif isinstance(output_size, int):
+                output_size = [output_size, output_size]
+            else:
+                raise RuntimeError("not supported output type!")
+        return self.get_proxy(ascend_op.AdaptiveAvgPool2D, (x, output_size))
+
+    @register_conversion([torch.ops.aten._adaptive_avg_pool2d_backward.default])
+    def adaptiveavgpool2dBackward(self, grad, input):
+        input_shape = list(input.node.meta['val'].shape)
+        return self.get_proxy(ascend_op.AdaptiveAvgPool2DGrad, (grad, input_shape))
 
     @register_conversion([torch.ops.aten._adaptive_avg_pool2d.default])
     def adaptiveavgpool2d(self, x, output_size):
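As a quick sanity check of the aten ops being converted here, the following plain-PyTorch snippet (no dicp required) shows the forward reducing to a fixed spatial size, and the backward producing a gradient with the original input's shape, which is exactly the shape adaptiveavgpool2dBackward recovers from input.node.meta['val']:

import torch

x = torch.randn(1, 64, 14, 14)
y = torch.ops.aten._adaptive_avg_pool2d.default(x, [7, 7])
print(y.shape)  # torch.Size([1, 64, 7, 7])

# The backward takes grad_output and the original input tensor; its
# result always matches the input shape.
g = torch.ops.aten._adaptive_avg_pool2d_backward.default(torch.ones_like(y), x)
print(g.shape)  # torch.Size([1, 64, 14, 14])

Note that the forward conversion normalizes output_size before building the Ascend node: an int n becomes [n, n], a tuple becomes a list, and any other type raises.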
17 changes: 10 additions & 7 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -1,4 +1,5 @@
 import torch
+import torch_dipu
 from dicp.dynamo_bridge.compile_fx import is_torch_210
 from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp
 from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
@@ -12,16 +13,17 @@
 )
 
 
-# This pass must run after FuseTransposeMatmul
 class ArgsTransDataPass:
     def transform(self, gm: torch.fx.GraphModule):
         for n in gm.graph.nodes:
-            if n.op != 'call_function':
-                continue
-            if type(n.target) in [MatMul]:
-                for arg in n.args:
-                    if arg.op == 'placeholder':
-                        arg.meta['format'] = 'FRACTAL_NZ'
+            if hasattr(n, 'op') and n.op == 'placeholder':
+                fake_tensor = n.meta['val']
+                memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
+                for key in memo:
+                    if id(memo[key].fake_device) == id(fake_tensor.fake_device):
+                        memory_format = torch_dipu.get_native_memory_format(key())
+                        n.meta['native_memory_format'] = str(memory_format.name)
+                        break
         return gm
 
 
@@ -88,5 +90,6 @@ def ascendgraph_opset_convert(
     gm = BackendPatternMatcherTransformer(
         ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm)
     gm = OutputMarkPass().transform(gm)
+    # uncomment this after DIOPI supports pytorch 2.1.1
     # gm = ArgsTransDataPass().transform(gm)
     return gm
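The reworked ArgsTransDataPass tags placeholder nodes with node.meta['native_memory_format'], which placeholder() in ascend.py reads back, defaulting to "NCHW" when the key is absent. The toy snippet below shows only that tag-then-read flow; torch_dipu.get_native_memory_format and the fake-tensor memo lookup are device-specific, so a stubbed constant stands in for them here:

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return x.relu()

gm = fx.symbolic_trace(Toy())

# Pass side: tag every placeholder (stub instead of the DIPU lookup).
for n in gm.graph.nodes:
    if n.op == 'placeholder':
        n.meta['native_memory_format'] = 'NCHW'  # stubbed value

# Codegen side: read the tag back with an "NCHW" fallback, mirroring
# placeholder() in ascend.py.
for n in gm.graph.nodes:
    if n.op == 'placeholder':
        print(n.name, n.meta.get('native_memory_format', 'NCHW'))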
4 changes: 3 additions & 1 deletion dicp/test/ascend_scripts/models/static.ini
@@ -1,4 +1,6 @@
@@ -1,4 +1,6 @@
 [pytest]
 testpaths = ../../model
-python_files = test_llama.py
+python_files =
+    test_llama.py
+    test_stable_diffusion.py
+    test_resnet50.py
