[DICP][ascend] use get_const_proxy to construct const op (#641)
tangzhiyi11 authored Jan 19, 2024
1 parent 0cb19a2 commit b17aaee
Showing 5 changed files with 109 additions and 140 deletions.
20 changes: 10 additions & 10 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -65,9 +65,9 @@ def __init__(self):
         super().__init__("Range")

     def infer_result(self, start, limit=None, delta=None):
-        start, start_dtype, _ = get_op_const_arg_kwarg(start)
-        limit, limit_dtype, _ = get_op_const_arg_kwarg(limit)
-        delta, delta_dtype, _ = get_op_const_arg_kwarg(delta)
+        start, start_dtype, _, _ = get_op_const_arg_kwarg(start)
+        limit, limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
+        delta, delta_dtype, _, _ = get_op_const_arg_kwarg(delta)

         assert start is not None, (
             self.__class__.__name__ + ": input 'start' can't be None!"
@@ -427,7 +427,7 @@ def __init__(self):
     def infer_result(self, base, expo):
         base, base_shape, base_dim, base_dtype = get_fake_tensor_meta_val(base)
         if isinstance(expo, Tuple):  # Const
-            expo, _, expo_shape = get_op_const_arg_kwarg(expo)
+            expo, _, expo_shape, _ = get_op_const_arg_kwarg(expo)
             expo_dtype = type(expo[0]) if len(expo) > 0 else base_dtype
         else:  # fake Tensor
             expo, expo_shape, expo_dim, expo_dtype = get_fake_tensor_meta_val(expo)
@@ -564,7 +564,7 @@ def __init__(self):
     def infer_result(
         self, shape, dtype, layout, device, memory_format=torch.contiguous_format
     ):
-        shape, _, _ = get_op_const_arg_kwarg(shape)
+        shape, _, _, _ = get_op_const_arg_kwarg(shape)
         return torch.empty(
             shape,
             dtype=dtype,
@@ -609,8 +609,8 @@ def __init__(self):
         super().__init__("Fill")

     def infer_result(self, dims, value):
-        _, value_dtype, _ = get_op_const_arg_kwarg(value)
-        shape, _, _ = get_op_const_arg_kwarg(dims)
+        _, value_dtype, _, _ = get_op_const_arg_kwarg(value)
+        shape, _, _, _ = get_op_const_arg_kwarg(dims)
         return torch.empty(
             shape, dtype=value_dtype, memory_format=torch.contiguous_format
         )
@@ -718,8 +718,8 @@ def __init__(self):

     def infer_result(self, x, offset, size):
         x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
-        new_shape, _, _ = get_op_const_arg_kwarg(size)
-        offset, _, _ = get_op_const_arg_kwarg(offset)
+        new_shape, _, _, _ = get_op_const_arg_kwarg(size)
+        offset, _, _, _ = get_op_const_arg_kwarg(offset)
         _, storage_offset = cal_stride_offset(new_shape, offset, x)
         res = torch.as_strided(x, new_shape, x.stride(), storage_offset)
         return res
@@ -764,7 +764,7 @@ def __init__(self):

     def infer_result(self, x, shape_const_op, ori_op=None, params_passed=None):
         x, _, _, x_dtype = get_fake_tensor_meta_val(x)
-        re_shape, _, _ = get_op_const_arg_kwarg(shape_const_op)
+        re_shape, _, _, _ = get_op_const_arg_kwarg(shape_const_op)
         x_stride = list(x.stride())
         res = torch.empty(re_shape, dtype=x_dtype, memory_format=get_memory_format(x))
         if ori_op == "Select":
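Every call site in this file changes the same way: get_op_const_arg_kwarg now returns a 4-tuple rather than a 3-tuple, so each unpacking gains a trailing `_`. The helper itself is not part of this diff; the sketch below uses a hypothetical stand-in (get_op_const_arg_kwarg_v4, with an assumed fourth field) purely to show why every 3-way unpacking has to be updated in the same commit.

def get_op_const_arg_kwarg_v4(const_arg):
    # Hypothetical stand-in for the real helper (not shown in this diff);
    # the fourth field is an assumption made for illustration only.
    value, dtype, shape = const_arg
    return value, dtype, shape, None

args = ([0], "INT64", [1])

# Old-style 3-way unpacking fails against the new 4-tuple, which is why
# every call site above gains a trailing underscore:
try:
    start, start_dtype, _ = get_op_const_arg_kwarg_v4(args)
except ValueError as err:
    print(err)  # too many values to unpack (expected 3)

start, start_dtype, _, _ = get_op_const_arg_kwarg_v4(args)  # new pattern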
1 change: 1 addition & 0 deletions dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h
@@ -148,6 +148,7 @@ ge::DataType get_ascend_datatype(const std::string& data_type) {
       {"FLOAT", ge::DataType::DT_FLOAT}, {"FLOAT16", ge::DataType::DT_FLOAT16},
       {"INT32", ge::DataType::DT_INT32}, {"INT64", ge::DataType::DT_INT64},
       {"BOOL", ge::DataType::DT_BOOL},   {"UINT8", ge::DataType::DT_UINT8},
+      {"BF16", ge::DataType::DT_BF16},
   };
   if (datatype_map.count(data_type) > 0) {
     return datatype_map[data_type];
12 changes: 8 additions & 4 deletions dicp/dicp/vendor/AscendGraph/codegen/utils.py
@@ -27,25 +27,29 @@ def get_ascend_dtype_num(dtype: str):
         return 4
     elif dtype == "UINT64":
         return 10
+    elif dtype == "BF16":
+        return 27
     else:
-        raise RuntimeError("unknow torch data tyep type in get_ascend_dtype!")
+        raise RuntimeError(f"unknow torch data type ({dtype}) in get_ascend_dtype_num!")


 def get_ascend_dtype(dtype: torch.dtype) -> str:
     if dtype == torch.bool:
         return "BOOL"
     elif dtype == torch.int64:
         return "INT64"
-    elif dtype == torch.float32:
+    elif dtype in [torch.float32, torch.float]:
         return "FLOAT"
     elif dtype == torch.float16:
         return "FLOAT16"
     elif dtype == torch.int32:
         return "INT32"
     elif dtype == torch.complex64:
         return "COMPLEX64"
+    elif dtype == torch.bfloat16:
+        return "BF16"
     else:
-        raise RuntimeError("unknow torch data tyep type in get_ascend_dtype!")
+        raise RuntimeError(f"unknow torch data type ({dtype}) in get_ascend_dtype!")


 def get_cpp_dtype(dtype: torch.dtype) -> str:
@@ -56,4 +60,4 @@ def get_cpp_dtype(dtype: torch.dtype) -> str:
     elif dtype == torch.int32:
         return "INT32"
     else:
-        raise RuntimeError("unknow torch data tyep type in get_cpp_dtype!")
+        raise RuntimeError(f"unknow torch data type ({dtype}) in get_cpp_dtype!")
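Together with the DT_BF16 entry added to graph_utils.h above, these changes give bfloat16 an end-to-end mapping: torch.bfloat16 -> "BF16" -> Ascend type id 27. A minimal sanity check, assuming the helpers are importable as below (the module path is inferred from the file layout, not confirmed by this diff):

import torch

# Assumed import path, inferred from dicp/dicp/vendor/AscendGraph/codegen/utils.py:
from dicp.vendor.AscendGraph.codegen.utils import (
    get_ascend_dtype,
    get_ascend_dtype_num,
)

assert get_ascend_dtype(torch.bfloat16) == "BF16"
assert get_ascend_dtype(torch.float) == "FLOAT"  # torch.float aliases torch.float32
assert get_ascend_dtype_num("BF16") == 27        # id added in get_ascend_dtype_num above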