[dicp][ascend] infer op resinfo and run single op 240124 (#668)
KevinfromTJ authored Jan 25, 2024
1 parent b906e2f commit 8516682
Showing 9 changed files with 363 additions and 239 deletions.
12 changes: 7 additions & 5 deletions dicp/dicp/dynamo_bridge/graph.py
@@ -2,7 +2,7 @@
 import os
 import torch
 import torch.fx
-from typing import Optional
+from typing import List, Optional, Tuple
 from torch._dynamo.utils import dynamo_timed
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch._inductor.codecache import cache_dir
@@ -61,10 +61,12 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]:
             else:
                 continue
             if 'val' in n.meta and test_infer:
-                assert n.meta['val'].size() == fake_value.size(), "check infer size failed"
-                assert n.meta['val'].dtype == fake_value.dtype, "check infer dtype failed"
-                assert n.meta['val'].stride() == fake_value.stride(), "check infer stride failed"
-                assert n.meta['val'].storage_offset() == fake_value.storage_offset(), "check infer storage offset failed"
+                (n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
+                for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
+                    assert meta_i.size() == fv_i.size(), "check infer size failed"
+                    assert meta_i.dtype == fv_i.dtype, "check infer dtype failed"
+                    assert meta_i.stride() == fv_i.stride(), "check infer stride failed"
+                    assert meta_i.storage_offset() == fv_i.storage_offset(), "check infer storage offset failed"
             if 'val' not in n.meta:
                 n.meta['val'] = fake_value
             n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])
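Note (illustration only, not part of the diff): the new check above wraps `n.meta['val']` and `fake_value` in 1-tuples when they are not already tuples/lists, so single- and multi-output ops share the same element-wise comparison. A minimal standalone sketch of that pattern, with illustrative helper names `as_tuple` and `check_infer`:

```python
import torch

def as_tuple(v):
    # Wrap a single tensor in a 1-tuple; pass tuples/lists through unchanged,
    # so single- and multi-output nodes use the same comparison loop.
    return v if isinstance(v, (tuple, list)) else (v,)

def check_infer(meta_val, fake_value):
    for meta_i, fv_i in zip(as_tuple(meta_val), as_tuple(fake_value)):
        assert meta_i.size() == fv_i.size(), "check infer size failed"
        assert meta_i.dtype == fv_i.dtype, "check infer dtype failed"
        assert meta_i.stride() == fv_i.stride(), "check infer stride failed"
        assert meta_i.storage_offset() == fv_i.storage_offset(), "check infer storage offset failed"

check_infer(torch.empty(2, 3), torch.empty(2, 3))  # single-output op
check_infer((torch.empty(4), torch.empty(4)), (torch.empty(4), torch.empty(4)))  # multi-output op
```
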
82 changes: 68 additions & 14 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -42,9 +42,9 @@ def __init__(self):
         super().__init__("BroadcastTo")
 
     def infer_result(self, x, shape):
-        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+        x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
         if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where'
-            shape, shape_shape, shape_dim, shape_dtype = get_fake_tensor_meta_val(shape)
+            shape, shape_shape, _, _ = get_fake_tensor_meta_val(shape)
             shape = shape_shape
         elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt'
             shape, _, _, _ =get_op_const_arg_kwarg(shape)
@@ -564,6 +564,8 @@ def __init__(self):
 
     def infer_result(self, x, idx=None):
         x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
+        if isinstance(x, (List, Tuple)):
+            return x[idx]
         out_dtype = x_dtype
         if x_dtype == torch.complex64: # for complex64
             out_shape = list(x_shape)
@@ -594,8 +596,8 @@ class IdentityN(Operator):
     def __init__(self):
         super().__init__("IdentityN")
 
-    def infer_result(self, x):
-        return common_unary_op_infer(x)
+    def infer_result(self, *args, **kwargs):
+        return remove_nested_parentheses(args)
 
 
 class Empty(Operator):
@@ -712,16 +714,6 @@ def __init__(self):
         super().__init__("NLLLossGrad")
 
 
-class BNTrainingReduce(Operator):
-    def __init__(self):
-        super().__init__("BNTrainingReduce")
-
-
-class BNTrainingUpdate(Operator):
-    def __init__(self):
-        super().__init__("BNTrainingUpdate")
-
-
 class BNTrainingUpdateGrad(Operator):
     def __init__(self):
         super().__init__("BNTrainingUpdateGrad")
@@ -932,11 +924,25 @@ class AdaptiveAvgPool2D(Operator):
     def __init__(self):
         super().__init__("AdaptiveAvgPool2D")
 
+    def infer_result(self, x, output_size):
+        _, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
+        batch_channel_size = list(x_shape)[:-2]
+        return torch.empty(
+            batch_channel_size + output_size,
+            dtype=x_dtype,
+            memory_format=get_memory_format(x),
+        )
+
 
 class AdaptiveAvgPool2DGrad(Operator):
     def __init__(self):
         super().__init__("AdaptiveAvgPool2DGrad")
 
+    def infer_result(self, input_grad, orig_input_shape):
+        return common_unary_op_infer(
+            input_grad, spec_format=torch.contiguous_format, spec_shape=orig_input_shape
+        )
+
 
 class MaxPoolGrad(Operator):
     def __init__(self):
@@ -972,6 +978,54 @@ def infer_result(self, x, multiples):
         return torch.ops.aten.repeat.default(x, multiples)
 
 
+class BNTrainingReduce(Operator):
+    def __init__(self):
+        super().__init__("BNTrainingReduce")
+
+    def infer_result(self, x, x_shape, format, dtype):
+        # the output should be two 1D tensors(reduce_sum and reduce_square_sum) of same type,
+        # so it may not matter to return only a single tensor here
+        return reduce_op_infer(x, None, False) # TODO: return a list of two tensors
+
+
+class BNTrainingUpdate(Operator):
+    def __init__(self):
+        super().__init__("BNTrainingUpdate")
+
+    def infer_result(
+        self,
+        x,
+        sum,
+        sum_idx,
+        square_sum,
+        square_idx,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        eps,
+        momentum,
+    ):
+        _, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
+        channel_size = x_shape[1]
+        output_y = torch.empty(
+            x_shape, dtype=x_dtype, memory_format=get_memory_format(x)
+        )
+        output_mean = torch.empty(
+            [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format
+        )
+        output_var = torch.empty(
+            [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format
+        )
+        output_batch_mean = torch.empty(
+            [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format
+        )
+        output_batch_var = torch.empty(
+            [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format
+        )
+        return [output_y,output_mean,output_var,output_batch_mean,output_batch_var]
+
+
 class TileWithAxis(Operator):
     def __init__(self):
         super().__init__("TileWithAxis")
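Note (illustration only, not part of the diff; assumes NCHW inputs): the new infer_result methods above derive output shapes as follows — AdaptiveAvgPool2D keeps the leading batch/channel dims and replaces the trailing spatial dims with output_size, while BNTrainingUpdate returns y with the input shape plus four 1-D statistics tensors of length C.

```python
import torch

# AdaptiveAvgPool2D.infer_result: [N, C, H, W] with output_size [7, 7] -> [N, C, 7, 7]
x = torch.empty(2, 64, 56, 56)
output_size = [7, 7]
out = torch.empty(list(x.shape)[:-2] + output_size, dtype=x.dtype)
assert list(out.shape) == [2, 64, 7, 7]

# BNTrainingUpdate.infer_result: y matches x; mean/var/batch_mean/batch_var are 1-D of length C
channel_size = x.shape[1]
mean = torch.empty([channel_size], dtype=torch.float32)
assert mean.shape == (64,)
```
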
14 changes: 14 additions & 0 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -1545,6 +1545,20 @@ def AdaptiveAvgPool2DGrad(name, input_grad, orig_input_shape):
         op.set_attr_list_int("orig_input_shape", orig_input_shape)
         return op.to_node()
 
+    @staticmethod
+    def AdaptiveAvgPool2D(name, x, output_size):
+        op = OP(name, "AdaptiveAvgPool2d")
+        op.set_input("x", x)
+        op.set_attr_list_int("output_size", output_size)
+        return op.to_node()
+
+    @staticmethod
+    def AdaptiveAvgPool2DGrad(name, input_grad, orig_input_shape):
+        op = OP(name, "AdaptiveAvgPool2dGrad")
+        op.set_input("input_grad", input_grad)
+        op.set_attr_list_int("orig_input_shape", orig_input_shape)
+        return op.to_node()
+
     @staticmethod
     def Tril(name, x, diagonal=0):
         op = OP(name, "Tril")
17 changes: 17 additions & 0 deletions dicp/dicp/vendor/AscendGraph/conversion.py
Expand Up @@ -1411,6 +1411,23 @@ def adaptiveavgpool2dBackward(self, grad, input):
input_shape = list(input.node.meta['val'].shape)
return self.get_proxy(ascend_op.AdaptiveAvgPool2DGrad, (grad, input_shape))

@register_conversion([torch.ops.aten._adaptive_avg_pool2d.default])
def adaptiveavgpool2d(self, x, output_size):
assert isinstance(output_size, int) or ( len(output_size) in range(1,3) and any(output_size) )
if not isinstance(output_size, list):
if isinstance(output_size, tuple):
output_size = list(output_size)
elif isinstance(output_size, int):
output_size = [output_size, output_size]
else:
raise RuntimeError("not supported output size!")
return self.get_proxy(ascend_op.AdaptiveAvgPool2D, (x, output_size))

@register_conversion([torch.ops.aten._adaptive_avg_pool2d_backward.default])
def adaptiveavgpool2dBackward(self, grad, input):
input_shape = list(input.node.meta['val'].shape)
return self.get_proxy(ascend_op.AdaptiveAvgPool2DGrad, (grad, input_shape))

@register_conversion(torch.ops.aten.tril.default)
def Tril(self, x, diagonal=0):
return self.get_proxy(ascend_op.Tril, (x, diagonal))
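Note (illustration only, not part of the diff): the new adaptiveavgpool2d conversion above normalizes output_size before building the AdaptiveAvgPool2D op — an int k becomes [k, k], a tuple becomes a list, and a list passes through. A small sketch of that logic with a hypothetical helper name normalize_output_size:

```python
def normalize_output_size(output_size):
    # Mirrors the inline normalization in adaptiveavgpool2d above.
    if isinstance(output_size, list):
        return output_size
    if isinstance(output_size, tuple):
        return list(output_size)
    if isinstance(output_size, int):
        return [output_size, output_size]
    raise RuntimeError("not supported output size!")

assert normalize_output_size(7) == [7, 7]
assert normalize_output_size((5, 5)) == [5, 5]
assert normalize_output_size([1, 1]) == [1, 1]
```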