diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 09c8f302c..346d3cc0e 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -65,9 +65,9 @@ def __init__(self): super().__init__("Range") def infer_result(self, start, limit=None, delta=None): - start, start_dtype, _ = get_op_const_arg_kwarg(start) - limit, limit_dtype, _ = get_op_const_arg_kwarg(limit) - delta, delta_dtype, _ = get_op_const_arg_kwarg(delta) + start, start_dtype, _, _ = get_op_const_arg_kwarg(start) + limit, limit_dtype, _, _ = get_op_const_arg_kwarg(limit) + delta, delta_dtype, _, _ = get_op_const_arg_kwarg(delta) assert start is not None, ( self.__class__.__name__ + ": input 'start' can't be None!" @@ -427,7 +427,7 @@ def __init__(self): def infer_result(self, base, expo): base, base_shape, base_dim, base_dtype = get_fake_tensor_meta_val(base) if isinstance(expo, Tuple): # Const - expo, _, expo_shape = get_op_const_arg_kwarg(expo) + expo, _, expo_shape, _ = get_op_const_arg_kwarg(expo) expo_dtype = type(expo[0]) if len(expo) > 0 else base_dtype else: # fake Tensor expo, expo_shape, expo_dim, expo_dtype = get_fake_tensor_meta_val(expo) @@ -564,7 +564,7 @@ def __init__(self): def infer_result( self, shape, dtype, layout, device, memory_format=torch.contiguous_format ): - shape, _, _ = get_op_const_arg_kwarg(shape) + shape, _, _, _ = get_op_const_arg_kwarg(shape) return torch.empty( shape, dtype=dtype, @@ -609,8 +609,8 @@ def __init__(self): super().__init__("Fill") def infer_result(self, dims, value): - _, value_dtype, _ = get_op_const_arg_kwarg(value) - shape, _, _ = get_op_const_arg_kwarg(dims) + _, value_dtype, _, _ = get_op_const_arg_kwarg(value) + shape, _, _, _ = get_op_const_arg_kwarg(dims) return torch.empty( shape, dtype=value_dtype, memory_format=torch.contiguous_format ) @@ -718,8 +718,8 @@ def __init__(self): def infer_result(self, x, offset, size): x, x_shape, _, x_dtype = 
get_fake_tensor_meta_val(x) - new_shape, _, _ = get_op_const_arg_kwarg(size) - offset, _, _ = get_op_const_arg_kwarg(offset) + new_shape, _, _, _ = get_op_const_arg_kwarg(size) + offset, _, _, _ = get_op_const_arg_kwarg(offset) _, storage_offset = cal_stride_offset(new_shape, offset, x) res = torch.as_strided(x, new_shape, x.stride(), storage_offset) return res @@ -764,7 +764,7 @@ def __init__(self): def infer_result(self, x, shape_const_op, ori_op=None, params_passed=None): x, _, _, x_dtype = get_fake_tensor_meta_val(x) - re_shape, _, _ = get_op_const_arg_kwarg(shape_const_op) + re_shape, _, _, _ = get_op_const_arg_kwarg(shape_const_op) x_stride = list(x.stride()) res = torch.empty(re_shape, dtype=x_dtype, memory_format=get_memory_format(x)) if ori_op == "Select": diff --git a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py index 10cd5c167..0f7db62c0 100644 --- a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py +++ b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py @@ -50,12 +50,11 @@ def get_op_const_arg_kwarg( """ new_args = const_arg[0] len_args = len(new_args) - assert ( - len_args >= 2 and len_args <= 3 - ), " :currently, op 'Const' support only 2 or 3 params passed!" + assert (len_args >= 2 and len_args <= 4), "op 'Const' supports only 2 to 4 params passed!" arg0, dtype = new_args[0], new_args[1] - shape = new_args[2] if len(new_args) == 3 else None - return arg0, dtype, shape + shape = new_args[2] if len(new_args) >= 3 else None + ascend_format = new_args[3] if len(new_args) == 4 else None + return arg0, dtype, shape, ascend_format """analyze dtype,format"""