diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py
index 09c8f302c..346d3cc0e 100644
--- a/dicp/dicp/vendor/AscendGraph/ascend_op.py
+++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -65,9 +65,9 @@ def __init__(self):
         super().__init__("Range")

     def infer_result(self, start, limit=None, delta=None):
-        start, start_dtype, _ = get_op_const_arg_kwarg(start)
-        limit, limit_dtype, _ = get_op_const_arg_kwarg(limit)
-        delta, delta_dtype, _ = get_op_const_arg_kwarg(delta)
+        start, start_dtype, _, _ = get_op_const_arg_kwarg(start)
+        limit, limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
+        delta, delta_dtype, _, _ = get_op_const_arg_kwarg(delta)
         assert start is not None, (
             self.__class__.__name__
             + ": input 'start' can't be None!"
@@ -427,7 +427,7 @@ def __init__(self):
     def infer_result(self, base, expo):
         base, base_shape, base_dim, base_dtype = get_fake_tensor_meta_val(base)
         if isinstance(expo, Tuple):  # Const
-            expo, _, expo_shape = get_op_const_arg_kwarg(expo)
+            expo, _, expo_shape, _ = get_op_const_arg_kwarg(expo)
             expo_dtype = type(expo[0]) if len(expo) > 0 else base_dtype
         else:  # fake Tensor
             expo, expo_shape, expo_dim, expo_dtype = get_fake_tensor_meta_val(expo)
@@ -564,7 +564,7 @@ def __init__(self):
     def infer_result(
         self, shape, dtype, layout, device, memory_format=torch.contiguous_format
     ):
-        shape, _, _ = get_op_const_arg_kwarg(shape)
+        shape, _, _, _ = get_op_const_arg_kwarg(shape)
         return torch.empty(
             shape,
             dtype=dtype,
@@ -609,8 +609,8 @@ def __init__(self):
         super().__init__("Fill")

     def infer_result(self, dims, value):
-        _, value_dtype, _ = get_op_const_arg_kwarg(value)
-        shape, _, _ = get_op_const_arg_kwarg(dims)
+        _, value_dtype, _, _ = get_op_const_arg_kwarg(value)
+        shape, _, _, _ = get_op_const_arg_kwarg(dims)
         return torch.empty(
             shape, dtype=value_dtype, memory_format=torch.contiguous_format
         )
@@ -718,8 +718,8 @@ def __init__(self):

     def infer_result(self, x, offset, size):
         x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
-        new_shape, _, _ = get_op_const_arg_kwarg(size)
-        offset, _, _ = get_op_const_arg_kwarg(offset)
+        new_shape, _, _, _ = get_op_const_arg_kwarg(size)
+        offset, _, _, _ = get_op_const_arg_kwarg(offset)
         _, storage_offset = cal_stride_offset(new_shape, offset, x)
         res = torch.as_strided(x, new_shape, x.stride(), storage_offset)
         return res
@@ -764,7 +764,7 @@ def __init__(self):

     def infer_result(self, x, shape_const_op, ori_op=None, params_passed=None):
         x, _, _, x_dtype = get_fake_tensor_meta_val(x)
-        re_shape, _, _ = get_op_const_arg_kwarg(shape_const_op)
+        re_shape, _, _, _ = get_op_const_arg_kwarg(shape_const_op)
         x_stride = list(x.stride())
         res = torch.empty(re_shape, dtype=x_dtype, memory_format=get_memory_format(x))
         if ori_op == "Select":
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h
index 2cbacf3bc..a40b94319 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h
+++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h
@@ -148,6 +148,7 @@ ge::DataType get_ascend_datatype(const std::string& data_type) {
       {"FLOAT", ge::DataType::DT_FLOAT},   {"FLOAT16", ge::DataType::DT_FLOAT16},
       {"INT32", ge::DataType::DT_INT32},   {"INT64", ge::DataType::DT_INT64},
       {"BOOL", ge::DataType::DT_BOOL},     {"UINT8", ge::DataType::DT_UINT8},
+      {"BF16", ge::DataType::DT_BF16},
   };
   if (datatype_map.count(data_type) > 0) {
     return datatype_map[data_type];
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/utils.py b/dicp/dicp/vendor/AscendGraph/codegen/utils.py
index 24ec294a6..29718f084 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/utils.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/utils.py
@@ -27,8 +27,10 @@ def get_ascend_dtype_num(dtype: str):
         return 4
     elif dtype == "UINT64":
         return 10
+    elif dtype == "BF16":
+        return 27
     else:
-        raise RuntimeError("unknow torch data tyep type in get_ascend_dtype!")
+        raise RuntimeError(f"unknown torch data type ({dtype}) in get_ascend_dtype_num!")


 def get_ascend_dtype(dtype: torch.dtype) -> str:
@@ -36,7 +38,7 @@ def get_ascend_dtype(dtype: torch.dtype) -> str:
         return "BOOL"
     elif dtype == torch.int64:
         return "INT64"
-    elif dtype == torch.float32:
+    elif dtype in [torch.float32, torch.float]:
         return "FLOAT"
     elif dtype == torch.float16:
         return "FLOAT16"
@@ -44,8 +46,10 @@ def get_ascend_dtype(dtype: torch.dtype) -> str:
         return "INT32"
     elif dtype == torch.complex64:
         return "COMPLEX64"
+    elif dtype == torch.bfloat16:
+        return "BF16"
     else:
-        raise RuntimeError("unknow torch data tyep type in get_ascend_dtype!")
+        raise RuntimeError(f"unknown torch data type ({dtype}) in get_ascend_dtype!")


 def get_cpp_dtype(dtype: torch.dtype) -> str:
@@ -56,4 +60,4 @@ def get_cpp_dtype(dtype: torch.dtype) -> str:
     elif dtype == torch.int32:
         return "INT32"
     else:
-        raise RuntimeError("unknow torch data tyep type in get_cpp_dtype!")
+        raise RuntimeError(f"unknown torch data type ({dtype}) in get_cpp_dtype!")
diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index 129f183fb..66bb39f75 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -37,6 +37,21 @@ def get_reduction_str(r):
         raise RuntimeError("not supported yet!")


+def try_to_get_dtype(x):
+    if isinstance(x, torch.fx.proxy.Proxy):
+        if hasattr(x.node, "meta") and "val" in x.node.meta.keys():
+            return x.node.meta['val'].dtype
+        else:
+            return None
+    return None
+
+
+def is_dicp_cpp_support_dtype(dtype):
+    if dtype in [torch.float32, torch.float, torch.int32, torch.int64]:
+        return True
+    return False
+
+
 def register_conversion(aten_fn):
     """
     Shim to support decorator syntax.
@@ -124,19 +139,45 @@ def get_shape_proxy(self, shape):
         return self.get_proxy(
             ascend_op.Const, (shape, torch.int32, [len(shape)]))

-    def get_param_proxy(self, param, type, target_shape):
+    def get_const_proxy(self, param, dtype, format=None, target_shape=None):
         if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor):
+            format = "ND" if format is None else format
+            if target_shape is None:
+                shape = [len(param)] if isinstance(param, list) else []
+            else:
+                shape = target_shape
             param = param if isinstance(param, list) else [param]
-            param = self.get_proxy(
-                ascend_op.Const, (param, type, [len(param)]))
+            if is_dicp_cpp_support_dtype(dtype):
+                param = self.get_proxy(
+                    ascend_op.Const, (param, dtype, shape, format))
+            else:
+                const = self.get_proxy(
+                    ascend_op.Const, (param, torch.float32, shape, format))
+                param = self.get_proxy(ascend_op.Cast, (const, get_ascend_dtype(dtype)))
         return param

+    def promote_dtype(self, *args, target_dtype):
+        result = []
+        ascend_dtype = get_ascend_dtype(target_dtype)
+        for arg in args:
+            if isinstance(arg, torch.fx.proxy.Proxy):
+                current_dtype = try_to_get_dtype(arg)
+                if current_dtype and current_dtype == target_dtype:
+                    result.append(arg)
+                    continue
+                # do cast if:
+                # 1. unable to get tensor dtype
+                # 2. current_dtype != target_dtype
+                result.append(self.get_proxy(ascend_op.Cast, (arg, ascend_dtype)))
+            else:
+                raise RuntimeError("Not implemented")
+        return tuple(result) if len(result) > 1 else result[0]
+
     def mul_scalar(self, x, y):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
         # Muls support bfloat16, int32, int16, float16, float32, complex32, complex64.
         if out_dtype not in [torch.float, torch.float16, torch.int32]:
-            y_shape = list(x.node.meta['val'].shape)
-            y_op = self.get_param_proxy(y, out_dtype, y_shape)
+            y_op = self.get_const_proxy(y, out_dtype)
             return self.get_proxy(ascend_op.Mul, (x, y_op))
         return self.get_proxy(ascend_op.Muls, (x, y))
@@ -161,15 +202,10 @@ def mul_complex64(self, x, y):
         return out

     def binary_cmp_cast_input(self, x, y):
+        x_dtype = x.node.meta["val"].dtype
         if not isinstance(y, torch.fx.proxy.Proxy):
-            x_dtype = x.node.meta["val"].dtype
-            const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
-            y_shape = list(x.node.meta["val"].shape)
-            y = self.get_param_proxy(y, const_dtype, y_shape)
-            if x_dtype == torch.float16:
-                y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
+            y = self.get_const_proxy(y, x_dtype)
         else:
-            x_dtype = x.node.meta["val"].dtype
             y_dtype = y.node.meta["val"].dtype
             if x_dtype != y_dtype:
                 y = self.get_proxy(ascend_op.Cast, (y, get_ascend_dtype(x_dtype)))
@@ -193,19 +229,7 @@ def mul(self, x, y):
             return self.mul_scalar(x, y)
         x_shape = list(x.node.meta['val'].shape)
         y_shape = list(y.node.meta['val'].shape)
-        x_dtype = x.node.meta['val'].dtype
-        y_dtype = y.node.meta['val'].dtype
-        # handling with broadcasting cases
-        if np.prod(x_shape) < np.prod(y_shape):
-            x = self.get_param_proxy(x, None, y_shape)
-        elif np.prod(x_shape) > np.prod(y_shape):
-            y = self.get_param_proxy(y, None, x_shape)
-        if x_dtype != out_dtype:
-            x = self.get_proxy(
-                ascend_op.Cast, (x, get_ascend_dtype(out_dtype)), {})
-        if y_dtype != out_dtype:
-            y = self.get_proxy(
-                ascend_op.Cast, (y, get_ascend_dtype(out_dtype)), {})
+        x, y = self.promote_dtype(x, y, target_dtype=out_dtype)
         return self.get_proxy(ascend_op.Mul, (x, y), {})

     @register_conversion(torch.ops.aten.add.Tensor)
@@ -213,20 +237,12 @@ def add(self, x, y, alpha: Optional[Number] = 1):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
         if not isinstance(y, torch.fx.proxy.Proxy):
             y = y * alpha
-            if out_dtype == torch.float or out_dtype == torch.float16:
+            if out_dtype in [torch.float, torch.float16]:
                 return self.get_proxy(ascend_op.Adds, (x, float(y)), {})
-            else:
-                y = self.get_proxy(ascend_op.Const, ([y], out_dtype, []))
+            y = self.get_const_proxy(y, out_dtype)
         else:
-            x_dtype = x.node.meta['val'].dtype
-            y_dtype = y.node.meta['val'].dtype
             y = self.mul(y, alpha)
-            if x_dtype != out_dtype:
-                x = self.get_proxy(
-                    ascend_op.Cast, (x, get_ascend_dtype(out_dtype)), {})
-            if y_dtype != out_dtype:
-                y = self.get_proxy(
-                    ascend_op.Cast, (y, get_ascend_dtype(out_dtype)), {})
+            x, y = self.promote_dtype(x, y, target_dtype=out_dtype)
         return self.get_proxy(ascend_op.AddV2, (x, y), {})

     @register_conversion(torch.ops.aten.add.Scalar)
@@ -279,7 +295,7 @@ def rsqrt(self, x):
     def inge(self, x, y):
         if not isinstance(y, torch.fx.proxy.Proxy):
             assert isinstance(y, int)
-            y = self.get_proxy(ascend_op.Const, ([y], torch.int32, []))
+            y = self.get_const_proxy(y, torch.int32)
         return self.get_proxy(ascend_op.GreaterEqual, (x, y))

     @register_conversion(aten.div)
@@ -288,11 +304,7 @@ def div(self, x, y):
             return self.get_proxy(ascend_op.DivNoNan, (x, y))
         assert y != 0
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
-        const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
-        y_shape = list(x.node.meta['val'].shape)
-        y_op = self.get_param_proxy(y, const_dtype, y_shape)
-        if out_dtype == torch.float16:
-            y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"), {})
+        y_op = self.get_const_proxy(y, out_dtype)
         return self.get_proxy(ascend_op.Div, (x, y_op), {})

     @register_conversion(aten.slice.Tensor)
@@ -316,8 +328,7 @@ def Bernoulli(self, x, p, generator=None):
         assert generator is None
         dtype = x.node.meta['val'].dtype
         shape_op = self.get_proxy(ascend_op.Shape, (x,))
-        prop_op = self.get_proxy(
-            ascend_op.Const, ([float(p)], torch.float32, []))
+        prop_op = self.get_const_proxy(float(p), torch.float32)
         seed_op = self.get_proxy(ascend_op.Const, ([-1], torch.int64, []))
         offset_op = self.get_proxy(ascend_op.Const, ([0], torch.int64, []))
         return self.get_proxy(ascend_op.StatelessBernoulli, (shape_op, prop_op, seed_op, offset_op, dtype))
@@ -444,15 +455,15 @@ def arange(self, end, start=0, step=1, dtype=None, device='xpu', layout=None, pi
         assert isinstance(step, torch.fx.proxy.Proxy) or type(step) in [int, float]

         if not isinstance(start, torch.fx.proxy.Proxy):  # scalar const
-            start = self.get_proxy(ascend_op.Const, (start, out_dtype))
+            start = self.get_const_proxy(start, out_dtype)
         elif start.node.meta['val'] != out_dtype:  # align tensor dtype
             start = self.get_proxy(ascend_op.Cast, (start, get_ascend_dtype(out_dtype)), {})
         if not isinstance(end, torch.fx.proxy.Proxy):
-            end = self.get_proxy(ascend_op.Const, (end, out_dtype))
+            end = self.get_const_proxy(end, out_dtype)
         elif end.node.meta['val'] != out_dtype:
             end = self.get_proxy(ascend_op.Cast, (end, get_ascend_dtype(out_dtype)), {})
         if not isinstance(step, torch.fx.proxy.Proxy):
-            step = self.get_proxy(ascend_op.Const, (step, out_dtype))
+            step = self.get_const_proxy(step, out_dtype)
         elif step.node.meta['val'] != out_dtype:
             step = self.get_proxy(ascend_op.Cast, (step, get_ascend_dtype(out_dtype)), {})
         return self.get_proxy(ascend_op.Range, (start, end, step))
@@ -489,14 +500,10 @@ def lt(self, x, y):

     @register_conversion(aten.masked_fill.Scalar)
     def masked_fill(self, x, mask, value):
-        x_dtype = x.node.meta['val'].dtype
-        const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
         if str(value) == "-inf":
             value = -3.4028234663852886e+38
-        mask_shape = list(mask.node.meta['val'].shape)
-        value = self.get_param_proxy(value, const_dtype, mask_shape)
-        if x_dtype == torch.float16:
-            value = self.get_proxy(ascend_op.Cast, (value, "FLOAT16"))
+        x_dtype = x.node.meta['val'].dtype
+        value = self.get_const_proxy(value, x_dtype)
         return self.get_proxy(ascend_op.MaskedFill, (x, mask, value))

     @register_conversion([torch.ops.aten.scatter.src, torch.ops.aten.scatter.value])
@@ -508,7 +515,7 @@ def scatter(self, var, dim, index, value):
             value = self.get_proxy(ascend_op.Reshape, (value, preprocess))
         else:
             out_dtype = fx_traceback.get_current_meta()['val'].dtype
-            value = self.get_proxy(ascend_op.Const, (value, out_dtype))
+            value = self.get_const_proxy(value, out_dtype)
             shape = self.get_proxy(ascend_op.Shape, (index,))
             value = self.get_proxy(ascend_op.BroadcastTo, (value, shape))
         return self.get_proxy(ascend_op.ScatterElements, (var, index, value, dim))
@@ -577,12 +584,7 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
         if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
             value = value.node.meta['val']
         dims = self.get_shape_proxy(dims)
-
-        # temporarily split the path for dynamic/static shape cases
-        if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
-            value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
-        else:
-            value = self.common_process_scalar(value, torch_dtype)
+        value = self.get_const_proxy(value, torch_dtype)
         return self.get_proxy(ascend_op.Fill, (dims, value))

     @register_conversion(torch.ops.aten.fill.Scalar)
@@ -592,7 +594,7 @@ def fills(self, x, value):
     @register_conversion(torch.ops.aten.topk.default)
     def topk(self, x, k, dim=-1, largest=True, sorted=True):
         if not isinstance(k, torch.fx.proxy.Proxy):
-            k = self.get_proxy(ascend_op.Const, ([k], torch.int32, []))
+            k = self.get_const_proxy(k, torch.int32)
         return self.get_proxy(ascend_op.TopK, (x, k, dim, largest, sorted))

     @register_conversion(torch.ops.aten.sort.default)
@@ -716,8 +718,7 @@ def maxpool2dbackward(self, grad, x, kernel_size, stride, padding, dilation,
         if padding != [0, 0]:
             padding = [0, 0, 0, 0, padding[0],
                        padding[0], padding[1], padding[1]]
-            padding_const = self.get_proxy(
-                ascend_op.Const, (padding, torch.int32, [8], "NCHW"))
+            padding_const = self.get_const_proxy(padding, torch.int32, format="NCHW")
             pad_op = self.get_proxy(ascend_op.PadV3, (x, padding_const, ))
             fwd_out = self.get_proxy(ascend_op.MaxPool, (pad_op, kernel_size,
                                                          stride, "VALID", "NCHW"))
@@ -739,8 +740,7 @@ def maxpool2d(self, x, ksize, strides, padding=[0, 0]):
         if padding != [0, 0]:
             padding = [0, 0, 0, 0, padding[0],
                        padding[0], padding[1], padding[1]]
-            paddings_const = self.get_proxy(
-                ascend_op.Const, (padding, [4, 2], torch.int32, "NCHW"))
+            paddings_const = self.get_const_proxy(padding, torch.int32, format="NCHW")
             x = self.get_proxy(ascend_op.PadV3, (x, paddings_const))
         fwd_out_op = self.get_proxy(
             ascend_op.MaxPool, (x, ksize, strides, "VALID", "NCHW"))
@@ -778,8 +778,7 @@ def slice_backward(self, grad, input_shape, dim, start, end, step):
             if i == dim:
                 pad[i][0] = start
                 pad[i][1] = v - end
-        padding_const = self.get_proxy(ascend_op.Const,
-                                       (pad.flatten().tolist(), torch.int32, [rank, 2]))
+        padding_const = self.get_const_proxy(pad.flatten().tolist(), torch.int32, target_shape=[rank, 2])
         return self.get_proxy(ascend_op.Pad, (grad, padding_const))

     @register_conversion(torch.ops.aten.var)
@@ -792,8 +791,7 @@ def var(self, x, axes=[], correction=1, keepdim=True):
             raise RuntimeError("not supported yet!")
         if not isinstance(axes, list):
             axes = [axes]
-        axes_op = self.get_proxy(ascend_op.Const, (axes, torch.int32, [
-            len(axes)] if len(axes) > 0 else []))
+        axes_op = self.get_const_proxy(axes, torch.int32)
         mean_op = self.get_proxy(ascend_op.ReduceMean, (x, axes_op))
         input_shape_op = self.get_proxy(ascend_op.Shape, (x,))
         broadcast_op = self.get_proxy(
@@ -806,39 +804,26 @@ def pow(self, x, exp):
             return self.get_proxy(ascend_op.Pow, (x, exp))
         # exp is scalar
         dtype = fx_traceback.get_current_meta()['val'].dtype
-        exp_const = self.get_proxy(ascend_op.Const, ([exp], dtype, []))
+        exp_const = self.get_const_proxy(exp, dtype)
         return self.get_proxy(ascend_op.Pow, (x, exp_const))

     @register_conversion(aten.maximum.default)
     def maximum(self, a, b):
         a_shape = list(a.node.meta['val'].shape)
         b_shape = list(b.node.meta['val'].shape)
-        if np.prod(b_shape) < np.prod(a_shape):
-            b = self.get_param_proxy(b, None, a_shape)
-            if a.node.meta['val'].dtype == torch.float16:
-                b = self.get_proxy(ascend_op.Cast, (b, "FLOAT16"))
+        b = self.promote_dtype(b, target_dtype=a.node.meta['val'].dtype)
         return self.get_proxy(ascend_op.Maximum, (a, b))

-    def common_process_scalar(self, y, dtype):
-        need_cast = False
-        if dtype == torch.float16:
-            dtype = torch.float32
-            need_cast = True
-        y = self.get_proxy(ascend_op.Const, (y, dtype))
-        if need_cast:
-            y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
-        return y
-
     @register_conversion(aten.sub)
     def sub(self, x, y):
         if not isinstance(y, torch.fx.proxy.Proxy):
-            y = self.common_process_scalar(y, x.node.meta['val'].dtype)
+            y = self.get_const_proxy(y, x.node.meta['val'].dtype)
         return self.get_proxy(ascend_op.Sub, (x, y))

     @register_conversion(aten.rsub)
     def rsub(self, x, y):
         if not isinstance(y, torch.fx.proxy.Proxy):
-            y = self.common_process_scalar(y, x.node.meta['val'].dtype)
+            y = self.get_const_proxy(y, x.node.meta['val'].dtype)
         return self.get_proxy(ascend_op.Sub, (y, x))

     @register_conversion(aten.transpose.int)
@@ -873,16 +858,15 @@ def convolution(self, input, weight, bias, stride, padding,
     @register_conversion(_operator.mul)
     def inmul(self, x, y):
         assert (not isinstance(y, torch.fx.proxy.Proxy))
-        y = self.get_proxy(ascend_op.Const, ([y], torch.int32, []))
+        y = self.get_const_proxy(y, torch.int32)
         return self.get_proxy(ascend_op.Mul, (x, y))

     @register_conversion(torch.ops.aten.sym_size)
     def symsize(self, x, dim):
         dim = [dim] if not isinstance(dim, list) else dim
         shape = self.get_proxy(ascend_op.Shape, (x,))
-        axis = self.get_proxy(ascend_op.Const, ([0], torch.int32, [1]))
-        indices = self.get_proxy(
-            ascend_op.Const, (dim, torch.int32, [len(dim)]))
+        axis = self.get_const_proxy(0, torch.int32, target_shape=[1])
+        indices = self.get_const_proxy(dim, torch.int32)
         return self.get_proxy(ascend_op.GatherV2, (shape, indices, axis))

     @register_conversion(torch.ops.aten.mm.default)
@@ -914,9 +898,8 @@ def bmm(self, x, y):

     @register_conversion(torch.torch.ops.aten.addmm)
     def addmm(self, c, a, b, beta=1.0, alpha=1.0):
-        beta_op = self.get_proxy(ascend_op.Const, ([beta], torch.float32, []))
-        alpha_op = self.get_proxy(
-            ascend_op.Const, ([alpha], torch.float32, []))
+        beta_op = self.get_const_proxy(beta, torch.float32)
+        alpha_op = self.get_const_proxy(alpha, torch.float32)
         c_beta_op = self.get_proxy(ascend_op.Mul, (c, beta_op))
         a_alpha_op = self.get_proxy(ascend_op.Mul, (a, alpha_op))
         matmul_op = self.get_proxy(
@@ -931,7 +914,7 @@ def mean(self, x, dims=[], keepdim=False):

     @register_conversion(torch.ops.aten.cumsum.default)
     def cumsum(self, x, dim, dtype=None):
-        dim_const = self.get_proxy(ascend_op.Const, ([dim], torch.int32, [1]))
+        dim_const = self.get_const_proxy(dim, torch.int32, target_shape=[1])
         return self.get_proxy(ascend_op.Cumsum, (x, dim_const))

     @register_conversion(torch.ops.aten._log_softmax.default)
@@ -952,7 +935,7 @@ def repeat_interleave(self, repeats, output_size=1):
         assert x_shape[0] == 1
         # TODO! fix implementation of repeatinterleave
         # Consider situation for repeats > 1
-        return self.get_proxy(ascend_op.Const, ([0], torch.int64, [1]))
+        return self.get_const_proxy(0, torch.int64, target_shape=[1])

     @register_conversion([aten.lift_fresh_copy, aten.lift_fresh_copy.default])
     def lift_fresh_copy(self, tensor_constant):
@@ -990,7 +973,7 @@ def exp(self, a):
     def embedding(self, weight, indices, padding_idx=-1):
         # TODO! consider situation for padding_idx
         # during training stage
-        axis = self.get_proxy(ascend_op.Const, ([0], torch.int32, [1]))
+        axis = self.get_const_proxy(0, torch.int32, target_shape=[1])
         return self.get_proxy(ascend_op.GatherV2, (weight, indices, axis))

     @register_conversion(torch.ops.aten.gather)
@@ -1073,7 +1056,7 @@ def identity(self, x, idx):
     @register_conversion(torch.ops.aten.full_like)
     def fulllike(self, x, value, dtype=torch.float32, layout=torch.strided,
                  device='cpu', pin_memory=False, memory_format=torch.preserve_format):
-        return self.get_proxy(ascend_op.Fills, (x,float(value)))
+        return self.get_proxy(ascend_op.Fills, (x, float(value)))

     @register_conversion(torch.ops.aten.zeros_like.default)
     def zeros_like(self, x, dtype=torch.float32, layout=torch.strided,
@@ -1084,13 +1067,13 @@ def zeros_like(self, x, dtype=torch.float32, layout=torch.strided,
     def RandLike(self, x, dtype=torch.float32, layout=torch.strided,
                  device='cpu', pin_memory=False, memory_format=torch.preserve_format):
         ascend_dtype = get_ascend_dtype(x.node.meta['val'].dtype)
-        key_op = self.get_proxy(ascend_op.Const, ([0], torch.int32, [1]))
+        key_op = self.get_const_proxy(0, torch.int32)
         key_cast_op = self.get_proxy(ascend_op.Cast, (key_op, "UINT64"))
         counter_op = self.get_proxy(
             ascend_op.Const, ([0, 0], torch.int32, [2]))
         counter_cast_op = self.get_proxy(
             ascend_op.Cast, (counter_op, "UINT64"))
-        alg_op = self.get_proxy(ascend_op.Const, ([1], torch.int32, []))
+        alg_op = self.get_const_proxy(1, torch.int32)
         shape_op = self.get_proxy(ascend_op.Shape, (x,))
         return self.get_proxy(ascend_op.StatelessRandomUniformV2,
                               (shape_op, key_cast_op, counter_cast_op, alg_op,
@@ -1099,28 +1082,14 @@ def RandLike(self, x, dtype=torch.float32, layout=torch.strided,
     @register_conversion(torch.ops.aten.gt.Scalar)
     def GtScalar(self, x, y):
         dtype = get_ascend_dtype(x.node.meta['val'].dtype)
-        scalar_op = self.get_proxy(
-            ascend_op.Const, ([float(y)], torch.float, []))
+        scalar_op = self.get_const_proxy(float(y), torch.float32)
         cast_op = self.get_proxy(ascend_op.Cast, (scalar_op, dtype))
         return self.get_proxy(ascend_op.Greater, (x, cast_op))

     @register_conversion(torch.ops.aten.addcmul.default)
     def AddCMul(self, a, b, c, value=1):
         dtype = a.node.meta['val'].dtype
-        not_support_type = False
-        orig_ascend_dtype = get_ascend_dtype(dtype)
-        try:
-            cpp_dtype = get_cpp_dtype(dtype)
-        except Exception:
-            not_support_type = True
-        value_op = None
-        if not_support_type:
-            const_op = self.get_proxy(
-                ascend_op.Const, ([float(value)], torch.float32, []))
-            value_op = self.get_proxy(
-                ascend_op.Cast, (const_op, orig_ascend_dtype))
-        else:
-            value_op = self.get_proxy(ascend_op.Const, ([value], dtype, []))
+        value_op = self.get_const_proxy(float(value), dtype)
         return self.get_proxy(ascend_op.Addcmul, (a, b, c, value_op))

     @register_conversion(torch.ops.aten.reciprocal.default)
@@ -1133,7 +1102,7 @@ def NativeDropout(self, x, p, train):
         dtype = x.node.meta['val'].dtype
         p = 1. - p
         shape = self.get_proxy(ascend_op.Shape, (x,))
-        prob = self.get_proxy(ascend_op.Const, ([float(p)], torch.float, []))
+        prob = self.get_const_proxy(float(p), torch.float32)
        mask = self.get_proxy(ascend_op.DropOutGenMaskV4, (shape, prob))
         prob_op = prob
         if dtype == torch.float16:
@@ -1146,9 +1115,5 @@ def NativeDropoutBackward(self, grad_output, mask, scale):
         dtype = grad_output.node.meta['val'].dtype
         p = 1. - scale
-        prob = self.get_proxy(ascend_op.Const, ([float(p)], torch.float, []))
-        prob_op = prob
-        if dtype == torch.float16:
-            cast = self.get_proxy(ascend_op.Cast, (prob, "FLOAT16"))
-            prob_op = cast
+        prob_op = self.get_const_proxy(float(p), dtype)
         return self.get_proxy(ascend_op.DropOutDoMaskV3, (grad_output, mask, prob_op))
diff --git a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
index 10cd5c167..0f7db62c0 100644
--- a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
+++ b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
@@ -50,12 +50,11 @@ def get_op_const_arg_kwarg(
     """
     new_args = const_arg[0]
     len_args = len(new_args)
-    assert (
-        len_args >= 2 and len_args <= 3
-    ), " :currently, op 'Const' support only 2 or 3 params passed!"
+    assert (len_args >= 2 and len_args <= 4)
     arg0, dtype = new_args[0], new_args[1]
-    shape = new_args[2] if len(new_args) == 3 else None
-    return arg0, dtype, shape
+    shape = new_args[2] if len(new_args) >= 3 else None
+    ascend_format = new_args[3] if len(new_args) == 4 else None
+    return arg0, dtype, shape, ascend_format


"""analyze dtype,format"""
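For reference, a small standalone sketch of the updated Const contract in this patch: get_op_const_arg_kwarg now yields a 4-tuple that also carries the Ascend format, and dtypes the C++ side cannot materialize directly (e.g. float16/bfloat16) are lowered as a float32 Const followed by a Cast, which is what the new get_const_proxy does. The code below is a simplified re-implementation with made-up sample values, not an import from dicp.

# Standalone illustration only: simplified re-implementations of the helpers
# touched by this patch, with hypothetical sample values.
import torch


def get_op_const_arg_kwarg(const_arg):
    # Const op args now carry up to four fields: value, dtype, shape, ascend_format.
    new_args = const_arg[0]
    assert 2 <= len(new_args) <= 4
    value, dtype = new_args[0], new_args[1]
    shape = new_args[2] if len(new_args) >= 3 else None
    ascend_format = new_args[3] if len(new_args) == 4 else None
    return value, dtype, shape, ascend_format


def is_dicp_cpp_support_dtype(dtype):
    # Dtypes that can be emitted directly as a Const; anything else is emitted
    # as a float32 Const followed by a Cast (the fallback in get_const_proxy).
    return dtype in (torch.float32, torch.float, torch.int32, torch.int64)


# Callers now unpack four values, even when the format field is unused:
args_of_const_node = ([2, 3], torch.int32, [2], "ND")
value, dtype, shape, fmt = get_op_const_arg_kwarg((args_of_const_node, {}))
print(value, dtype, shape, fmt)                   # [2, 3] torch.int32 [2] ND

print(is_dicp_cpp_support_dtype(torch.bfloat16))  # False -> Const(float32) + Cast("BF16")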