
Commit

update
tangzhiyi11 committed Jan 17, 2024
1 parent 6b06013 commit 65eba28
Showing 2 changed files with 14 additions and 15 deletions.
20 changes: 10 additions & 10 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -65,9 +65,9 @@ def __init__(self):
         super().__init__("Range")
 
     def infer_result(self, start, limit=None, delta=None):
-        start, start_dtype, _ = get_op_const_arg_kwarg(start)
-        limit, limit_dtype, _ = get_op_const_arg_kwarg(limit)
-        delta, delta_dtype, _ = get_op_const_arg_kwarg(delta)
+        start, start_dtype, _, _ = get_op_const_arg_kwarg(start)
+        limit, limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
+        delta, delta_dtype, _, _ = get_op_const_arg_kwarg(delta)
 
         assert start is not None, (
             self.__class__.__name__ + ": input 'start' can't be None!"
@@ -427,7 +427,7 @@ def __init__(self):
     def infer_result(self, base, expo):
         base, base_shape, base_dim, base_dtype = get_fake_tensor_meta_val(base)
         if isinstance(expo, Tuple):  # Const
-            expo, _, expo_shape = get_op_const_arg_kwarg(expo)
+            expo, _, expo_shape, _ = get_op_const_arg_kwarg(expo)
             expo_dtype = type(expo[0]) if len(expo) > 0 else base_dtype
         else:  # fake Tensor
             expo, expo_shape, expo_dim, expo_dtype = get_fake_tensor_meta_val(expo)
@@ -564,7 +564,7 @@ def __init__(self):
     def infer_result(
         self, shape, dtype, layout, device, memory_format=torch.contiguous_format
     ):
-        shape, _, _ = get_op_const_arg_kwarg(shape)
+        shape, _, _, _ = get_op_const_arg_kwarg(shape)
         return torch.empty(
             shape,
             dtype=dtype,
@@ -609,8 +609,8 @@ def __init__(self):
         super().__init__("Fill")
 
     def infer_result(self, dims, value):
-        _, value_dtype, _ = get_op_const_arg_kwarg(value)
-        shape, _, _ = get_op_const_arg_kwarg(dims)
+        _, value_dtype, _, _ = get_op_const_arg_kwarg(value)
+        shape, _, _, _ = get_op_const_arg_kwarg(dims)
         return torch.empty(
             shape, dtype=value_dtype, memory_format=torch.contiguous_format
         )
@@ -718,8 +718,8 @@ def __init__(self):
 
     def infer_result(self, x, offset, size):
         x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
-        new_shape, _, _ = get_op_const_arg_kwarg(size)
-        offset, _, _ = get_op_const_arg_kwarg(offset)
+        new_shape, _, _, _ = get_op_const_arg_kwarg(size)
+        offset, _, _, _ = get_op_const_arg_kwarg(offset)
         _, storage_offset = cal_stride_offset(new_shape, offset, x)
         res = torch.as_strided(x, new_shape, x.stride(), storage_offset)
         return res
@@ -764,7 +764,7 @@ def __init__(self):
 
     def infer_result(self, x, shape_const_op, ori_op=None, params_passed=None):
         x, _, _, x_dtype = get_fake_tensor_meta_val(x)
-        re_shape, _, _ = get_op_const_arg_kwarg(shape_const_op)
+        re_shape, _, _, _ = get_op_const_arg_kwarg(shape_const_op)
         x_stride = list(x.stride())
         res = torch.empty(re_shape, dtype=x_dtype, memory_format=get_memory_format(x))
         if ori_op == "Select":
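
Every call site in ascend_op.py now unpacks four values from get_op_const_arg_kwarg and discards the new fourth slot with an extra `_`. A minimal sketch of that unpacking pattern, assuming the package path mirrors the file layout and using a made-up Const argument tuple for illustration (the real tuples are produced by the AscendGraph converter and are not part of this commit):

# Sketch only: `const_arg` is a hypothetical Const argument of the assumed
# layout (values, dtype, shape, format) wrapped in a one-element tuple.
import torch
from dicp.vendor.AscendGraph.infer_res_utils import get_op_const_arg_kwarg

const_arg = (([1.0], torch.float32, [1], "ND"),)  # hypothetical example
values, dtype, shape, ascend_format = get_op_const_arg_kwarg(const_arg)
# Callers that only need the value and dtype ignore the trailing slots,
# e.g. `start, start_dtype, _, _ = get_op_const_arg_kwarg(start)` above.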
9 changes: 4 additions & 5 deletions dicp/dicp/vendor/AscendGraph/infer_res_utils.py
@@ -50,12 +50,11 @@ def get_op_const_arg_kwarg(
     """
     new_args = const_arg[0]
     len_args = len(new_args)
-    assert (
-        len_args >= 2 and len_args <= 3
-    ), " :currently, op 'Const' support only 2 or 3 params passed!"
+    assert (len_args >= 2 and len_args <= 4)
     arg0, dtype = new_args[0], new_args[1]
-    shape = new_args[2] if len(new_args) == 3 else None
-    return arg0, dtype, shape
+    shape = new_args[2] if len(new_args) >= 3 else None
+    ascend_format = new_args[3] if len(new_args) == 4 else None
+    return arg0, dtype, shape, ascend_format
 
 
 """analyze dtype,format"""
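
After this change the helper tolerates Const nodes carrying two, three, or four packed parameters and always returns a fixed 4-tuple. A self-contained sketch of the post-commit behavior (the real function's docstring and signature details are not shown in this diff, so they are omitted here):

import torch

def get_op_const_arg_kwarg(const_arg):
    # Sketch of the updated helper: accept 2-4 packed params and pad the
    # missing trailing slots (shape, Ascend format) with None.
    new_args = const_arg[0]
    len_args = len(new_args)
    assert 2 <= len_args <= 4
    arg0, dtype = new_args[0], new_args[1]
    shape = new_args[2] if len_args >= 3 else None
    ascend_format = new_args[3] if len_args == 4 else None
    return arg0, dtype, shape, ascend_format

# A two-element Const still works, with shape and format defaulting to None:
print(get_op_const_arg_kwarg((([0], torch.int32),)))  # ([0], torch.int32, None, None)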
