diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py index 050102ad4..9d8f34189 100644 --- a/dicp/dicp/dynamo_bridge/utils.py +++ b/dicp/dicp/dynamo_bridge/utils.py @@ -7,6 +7,13 @@ from torch.fx.node import Argument, Target +def not_all_num_shape(shape): + for elem in shape: + if not isinstance(elem, int): + return True + return False + + def symint_in_shape(shape): for elem in shape: if isinstance(elem, torch.SymInt): @@ -14,6 +21,55 @@ def symint_in_shape(shape): return False +def neg_in_shape(shape): + for elem in shape: + if isinstance(elem, int): + if elem < 0: + return True + return False + + +def process_sym_name(st): + # dynamic shape feature + # return string wrapper in new version + # node.str() will not fallback SymInt value form + if isinstance(st, torch.SymInt): + return st.node.str() + return str(st) + + +def preprocess_expression(expr): + elem_str = process_sym_name(expr) + elem_str = elem_str.replace('+', ' + ') + elem_str = elem_str.replace('-', ' - ') + elem_str = elem_str.replace('*', ' * ') + elem_str = elem_str.replace('//', ' // ') + elem_str = elem_str.replace('(', ' ( ') + elem_str = elem_str.replace(')', ' ) ') + elems = elem_str.split(' ') + elems = [e for e in elems if e != ''] + return elems + + +def find_root_num(set_num, num): + while set_num[num] != num: + num = set_num[num] + return num + + +def merge_disjoint_set(set_num, idx_a, idx_b): + root_a = find_root_num(set_num, idx_a) + root_b = find_root_num(set_num, idx_b) + # an example for (s5 / 8) - (s5 / 16) + # num: 0 1 2 3 + # step1 - > set_num: 0 1 2 3 + # step2 - > set_num: 0 0 2 2 + # step3 - > set_num: 0 0 0 0 + + # return merged set from root_b to root_a + return [root_a if find_root_num(set_num, s) == root_b else s for s in set_num] + + def save_cpu_gm(gm: torch.fx.GraphModule, folder: str): Path(folder).mkdir(exist_ok=True) cpu_gm = copy_gm_to_cpu(gm) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 201187101..c1c85cab0 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -217,6 +217,14 @@ def infer_result(self, x): return common_unary_op_infer(x) +class Triu(Operator): + def __init__(self): + super().__init__("Triu") + + def infer_result(self, x, diag): + return common_unary_op_infer(x) + + class Sqrt(Operator): def __init__(self): super().__init__("Sqrt") @@ -567,6 +575,27 @@ def __init__(self): super().__init__("CastToCpu") +class SequenceAt(Operator): + def __init__(self): + super().__init__("SequenceAt") + + def infer_result(self, x, idx=None): + x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x) + if isinstance(x, (List, Tuple)): + return x[idx] + out_dtype = x_dtype + if x_dtype == torch.complex64: # for complex64 + out_shape = list(x_shape) + if idx == 0 or idx == 1: + out_dtype = torch.float32 + out_shape.append(1) + else: + out_shape = [x_shape[idx]] if idx is not None else list(x_shape) + return torch.empty( + out_shape, dtype=out_dtype, memory_format=get_memory_format(x) + ) + + class Identity(Operator): def __init__(self): super().__init__("Identity") @@ -750,7 +779,6 @@ def __init__(self, x): def infer_result(self, x): return common_unary_op_infer(x) - class SplitD(Operator): def __init__(self): super().__init__("SplitD") @@ -770,6 +798,12 @@ def infer_result(self, x, split_dim, num_split, y, from_view_complex=False): memory_format=get_memory_format(x), ) +class SplitToSequence(Operator): + def __init__(self): + super().__init__("SplitToSequence") + 
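# Editorial sketch, not part of the patch: how the new disjoint-set helpers in
# dicp/dynamo_bridge/utils.py fit together. It assumes the patch above is applied
# so the helpers are importable, and walks the (s5 // 8) - (s5 // 16) example from
# the merge_disjoint_set comment.
from dicp.dynamo_bridge.utils import (
    preprocess_expression,
    find_root_num,
    merge_disjoint_set,
)

# Tokenize the symbolic expression into operands, operators and parentheses.
tokens = preprocess_expression("(s5//8) - (s5//16)")
# tokens == ['(', 's5', '//', '8', ')', '-', '(', 's5', '//', '16', ')']

# Each operand starts as its own set root; operators and parentheses are marked -1,
# mirroring how generate_not_num (added to conversion.py later in this patch)
# initializes set_num.
set_num = [idx if tok not in ['+', '-', '*', '//', '(', ')'] else -1
           for idx, tok in enumerate(tokens)]

# Reducing the first '//' merges the sets of 's5' (index 1) and '8' (index 3),
# so both operands now share root 1.
set_num = merge_disjoint_set(set_num, 1, 3)
assert find_root_num(set_num, 3) == 1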
+ def infer_result(self, x, split_dim, split_size): + torch.split(x, split_size, split_dim) class Slice(Operator): def __init__(self): diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index a7930c4b7..2e9fb5d3a 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -6,7 +6,7 @@ from torch.fx.node import Node from torch.utils._pytree import tree_map_only from torch._inductor.utils import IndentedBuffer -from dicp.dynamo_bridge.utils import symint_in_shape +from dicp.dynamo_bridge.utils import symint_in_shape, process_sym_name from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype, get_cpp_dtype, @@ -219,58 +219,11 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2): ) return self.import_code.getvalue() - def process_sym_name(self, st): - # dynamic shape feature - if st.isdigit(): - return st - elif '+' in st: - sp = st.split('+') - if len(sp) > 2: - sp = [sp[0], '+'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0].isdigit(): - (sp[1], sp[0]) = (sp[0], sp[1]) - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '+' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '+' + sp[1] - else: - return self.process_sym_name(sp[0]) + '+' + sp[1] - elif '-' in st: - sp = st.split('-') - if len(sp) > 2: - sp = [sp[0], '-'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '-' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '-' + sp[1] - else: - return self.process_sym_name(sp[0]) + '-' + sp[1] - elif '*' in st: - sp = st.split('*') - if len(sp) > 2: - sp = [sp[0], '*'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0].isdigit(): - (sp[1], sp[0]) = (sp[0], sp[1]) - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '*' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '*' + sp[1] - else: - return self.process_sym_name(sp[0]) + '*' + sp[1] - else: - if st in self.sym_in_args: - arg, idx = self.sym_in_args[st] - return "{}.shape[{}]".format(arg, idx) - return self.sym_to_inputs[st] + def operator_in_str(self, st): + for op in ['+', '-', '*', '/']: + if op in st: + return True + return False def gen_call_func(self): # TODO check scalar input @@ -283,16 +236,23 @@ def gen_call_func(self): args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args] call_body.writeline(f"({','.join(args)}) = args") + # assign SymInt to InputArgs relationship + if len(self.sym_in_args) > 0: + for key in self.sym_in_args.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]") + if len(self.sym_to_inputs) > 0: + for key in self.sym_to_inputs.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_to_inputs[key]}") + # generate input dims if len(self.dynamic_inputs) > 0: - dim_len = 0 - for shape in self.actual_shape: - dim_len += len(shape) dims = 'dims = {' for idx, elem in enumerate(self.actual_shape): if len(elem) == 0: continue - elem = [self.process_sym_name(dim) 
for dim in elem] + elem = [process_sym_name(dim) for dim in elem] dims += str(self.dynamic_index[idx]) + \ ":[" + ','.join(map(str, elem)) + '],' dims = dims[:-1] + '}' @@ -315,7 +275,7 @@ def gen_call_func(self): shape = list(elem.shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(str(dim)) for dim in shape] + shape = [process_sym_name(dim) for dim in shape] shape_str += "[" + ','.join(map(str, shape)) + "]," # process output_shape with modified args @@ -323,12 +283,12 @@ def gen_call_func(self): shape = list(self.input_args[elem[1]].meta['val'].shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(str(dim)) for dim in shape] + shape = [process_sym_name(dim) for dim in shape] shape_str += "[" + ','.join(map(str, shape)) + "]," stride = list(self.input_args[elem[1]].meta['val'].stride()) if len(stride) == 0: raise RuntimeError("Error handling empty output_stride") - stride = [self.process_sym_name(str(dim)) for dim in stride] + stride = [process_sym_name(dim) for dim in stride] extra_stride_str += '[' + ','.join(map(str, stride)) + '],' extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' shape_str = shape_str[:-1] + ''']''' @@ -342,21 +302,23 @@ def gen_call_func(self): for elem in self.output_args: if hasattr(elem, 'meta'): elem = elem.meta['val'] - if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): - out_strides.append('[1]') - out_storage_offsets.append('0') - continue - if elem.dim() == 0: # temporary solution for sum.default(a) whose result is a scalar(no dim no stride) + + # temporary solution for sum.default(a) whose result is scalar or with no dim no stride + if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool) or elem.dim() == 0: out_strides.append('[1]') out_storage_offsets.append('0') continue stride = list(elem.stride()) - stride = [self.process_sym_name(str(dim)) for dim in stride] - out_strides.append(str(stride)) + stride = [process_sym_name(dim) for dim in stride] + out_strides.append('[' + ','.join(map(str, stride)) + ']') out_storage_offsets.append(elem.storage_offset()) - call_body.writeline(f'out_stride = {out_strides}') + call_body.writeline(f'''out_stride = [{','.join(out_strides)}]''') call_body.writeline(f'out_storage_offset = {out_storage_offsets}') + # In precision debug mode, modified array recording InputArgs integer needed + if precision_check and self.aten_graph is not None: + call_body.writeline(f"modified = [idx for idx in range(len(args))] if isinstance(args[idx], int)") + call_body.splice(""" import torch_dipu dipu_device_str = torch_dipu.dipu.device.__diputype__ @@ -827,6 +789,13 @@ def Rsqrt(name, x): op.set_input("x", x) return op.to_node() + @staticmethod + def Triu(name, x, diag): + op = OP(name, "Triu") + op.set_input("x", x) + op.set_attr_int("diagonal", diag) + return op.to_node() + @staticmethod def Conv2D(name, input, weight, stride, padding, dilation, groups, format, bias): @@ -915,6 +884,14 @@ def Identity(name, input, index=None): op.set_input("x", input) return op.to_node() + @staticmethod + def SequenceAt(name, input, index=None): + op = OP(name, "SequenceAt") + assert index is not None + op.set_input("handle", input) + op.set_input("index", index) + return op.to_node() + @staticmethod def IdentityInp(name, input, dst=None): op = OP(name, "Identity") @@ -1377,6 +1354,14 @@ def SplitD(name, x, dim, num_split, y, from_view_complex): 
split_op.set_dynamic_output("y", y) return split_op.to_node() + @staticmethod + def SplitToSequence(name, x, dim, split_size): + split_op = OP(name, "SplitToSequence") + split_op.set_input("x", x) + split_op.set_input("split", split_size) + split_op.set_attr_int("axis", dim) + return split_op.to_node() + @staticmethod def Pack(name, x, axis): x_name = [] diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index a56674fba..ce593bf84 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -10,11 +10,13 @@ Number, ) import numpy as np +import sympy import torch.fx.traceback as fx_traceback from torch.fx.immutable_collections import immutable_list from torch._subclasses import FakeTensor import dicp.vendor.AscendGraph.ascend_op as ascend_op -from dicp.dynamo_bridge.utils import symint_in_shape +from dicp.dynamo_bridge.utils import symint_in_shape, neg_in_shape, not_all_num_shape, process_sym_name +from dicp.dynamo_bridge.utils import preprocess_expression, find_root_num, merge_disjoint_set from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype ) @@ -44,8 +46,20 @@ def try_to_get_dtype(x): if isinstance(x, torch.fx.proxy.Proxy): if hasattr(x.node, "meta") and "val" in x.node.meta.keys(): return x.node.meta['val'].dtype + elif isinstance(x.node.target, ascend_op.Const): + # handle with const proxy dtype + assert len(x.node.args) > 1 + return x.node.args[1] else: return None + + # handle with basic scalar type + if isinstance(x, bool): + return torch.bool + elif isinstance(x, int): + return torch.int32 + elif isinstance(x, float): + return torch.float32 return None @@ -74,79 +88,214 @@ def process_dynamic_shape(self, shape): x_names = [] def generate_digits_op(shapes): - const_op = self.get_proxy( - ascend_op.Const, (shapes, torch.int32, [len(shapes)])) + const_op = self.get_const_proxy(shapes, torch.int32) x_names.append(const_op) - def generate_sym_int(elem): - elem = elem.node.str() - elems = elem.strip().split(' ') - - arg = None - # dynamic shape feature - if elems[0] in self.sym_in_args: - arg, idx = self.sym_in_args[elems[0]] + def replace_elem_proxy(elem_str): + # exit if already a proxy + if isinstance(elem_str, torch.fx.proxy.Proxy): + return elem_str + assert not elem_str in ['+', '-', '*', '//', '(', ')'] + + # handle with integer + if elem_str.isdigit(): + return self.get_const_proxy(int(elem_str), torch.int32) + + # handle with NodeProxy string + if 'Proxy' in elem_str: + # recover '()' from '[]' + elem_str = elem_str.replace('[', '(') + elem_str = elem_str.replace(']', ')') + + # split for sym_in_args candidate + idx = -1 + if elem_str[0] == '(' and elem_str[-1] == ')': + elem_str, idx = elem_str.strip('()').split(',') + + # search & replace + replace_proxy = None + arg_symint_candidate = [value[0] for value in self.sym_in_args.values()] + list(self.sym_to_inputs.values()) + for arg_sym in arg_symint_candidate: + if elem_str == str(arg_sym): + # is sym_in_args candidate + if int(idx) > -1: + shape = self.get_proxy(ascend_op.Shape, (arg_sym,)) + axis = self.get_const_proxy(0, torch.int32) + indice = self.get_const_proxy(int(idx), torch.int32) + replace_proxy = self.get_proxy( + ascend_op.GatherV2, (shape, indice, axis)) + else: + replace_proxy = arg_sym + break + assert replace_proxy is not None + return replace_proxy + + # handle if elem in shape of InputArgs + if elem_str in self.sym_in_args: + arg, idx = self.sym_in_args[elem_str] shape = 
self.get_proxy(ascend_op.Shape, (arg,)) - axis = self.get_proxy( - ascend_op.Const, ([0], torch.int32, [1])) - indice = self.get_proxy( - ascend_op.Const, ([idx], torch.int32, [1])) + axis = self.get_const_proxy(0, torch.int32) + indice = self.get_const_proxy(idx, torch.int32) gather = self.get_proxy( ascend_op.GatherV2, (shape, indice, axis)) + return gather + + # handle if SymInt InputArg needed + return self.sym_to_inputs[elem_str] + + def generate_not_num(elem): + # dynamic shape feature + # situation for NodeProxy + if isinstance(elem, torch.fx.proxy.Proxy): + x_names.append(elem) + return + + # case for NodeProxy string or SymInt + elems = preprocess_expression(elem) + # prepare for expression calculation if len(elems) > 1: - assert len(elems) == 3 - assert elems[2].isdigit() - assert elems[1] == '+' or elems[1] == '-' - const_op = self.get_proxy( - ascend_op.Const, ([int(elems[2])], torch.int32, [1])) - if arg is not None: - args = (gather, const_op) - else: - args = (self.sym_to_inputs[elems[0]], const_op) - if elems[1] == '+': - x_names.append(self.get_proxy(ascend_op.Add, args)) - else: - x_names.append(self.get_proxy(ascend_op.Sub, args)) + set_num = [] + priority = [] + nest = 0 + + # calculate priority for each operator + # set initial set number + for idx, e in enumerate(elems): + if e == '+' or e =='-': + priority.append(nest * 3 + 0) + elif e == '*' or e == '//': + priority.append(nest * 3 + 1) + else: + if e == '(': + nest += 1 + elif e == ')': + nest -= 1 + priority.append(-1) + + # init set number + if not e in ['+', '-', '*', '//', '(', ')']: + set_num.append(idx) + else: + set_num.append(-1) + + # start merge disjoint-set + if len(set_num) > 1: + while len(set(set_num)) > 2: + # seek the highest priority operator + max = -1 + m_idx = -1 + for idx, prio in enumerate(priority): + if prio > max: + max = prio + m_idx = idx + + # merge the highest priority two elements calculation + # find left & right element + left_idx = m_idx - 1 + while left_idx > 0 and str(elems[left_idx]) in ['(', ')']: + left_idx -= 1 + right_idx = m_idx + 1 + while right_idx < len(elems) - 1 and str(elems[right_idx]) in ['(', ')']: + right_idx += 1 + left_idx = find_root_num(set_num, set_num[left_idx]) + right_idx = find_root_num(set_num, set_num[right_idx]) + left_elem = replace_elem_proxy(elems[left_idx]) + right_elem = replace_elem_proxy(elems[right_idx]) + + # generate calculation operator + if elems[m_idx] == '+': + elems[left_idx] = self.get_proxy(ascend_op.Add, (left_elem, right_elem)) + elif elems[m_idx] == '-': + elems[left_idx] = self.get_proxy(ascend_op.Sub, (left_elem, right_elem)) + elif elems[m_idx] == '*': + elems[left_idx] = self.get_proxy(ascend_op.Mul, (left_elem, right_elem)) + else: + elems[left_idx] = self.get_proxy(ascend_op.Div, (left_elem, right_elem)) + + # merge set number and priority + set_num = merge_disjoint_set(set_num, left_idx, right_idx) + priority[m_idx] = -1 + + # add final element proxy + final_idx = 0 + while final_idx < len(elems) - 1 and str(elems[final_idx]) in ['(', ')']: + final_idx += 1 + final_elem = replace_elem_proxy(elems[final_idx]) + x_names.append(final_elem) else: - if arg is not None: - x_names.append(gather) - else: - x_names.append(self.sym_to_inputs[elems[0]]) + # only one not num element + node = replace_elem_proxy(elems[0]) + x_names.append(node) dims = [] for elem in shape: - if not isinstance(elem, torch.SymInt): + # process number + if isinstance(elem, int): dims.append(elem) continue - st = elem.node.str() + st = str(elem) if 
st.isdigit(): dims.append(int(st)) continue + # add number block if len(dims) > 0: generate_digits_op(dims) dims = [] - generate_sym_int(elem) + generate_not_num(elem) + + # add last number block if len(dims) > 0: generate_digits_op(dims) + # concat all ops - return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) + if len(x_names) > 1: + return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) + return self.get_proxy(ascend_op.Unsqueeze, (x_names[0], [0])) def get_shape_proxy(self, shape, dtype=torch.int32): + def symint_to_str(shape): + result_shape = [] + for dim in shape: + if isinstance(dim, torch.SymInt): + # split expression elements to compare SymInt string + dim_str = dim.node.str() + elems = preprocess_expression(dim_str) + + # replace SymInt in expression using sympy function + for elem in elems: + if 's' in elem: + # cover sym_to_inputs for higher priority to sym_in_args + sym_keys = {**self.sym_to_inputs, **self.sym_in_args} + replace_str = str(sym_keys[elem]).replace(' ', '') + + # '[]' will not mixed with expression calculation priority + replace_str = replace_str.replace('(', '[') + replace_str = replace_str.replace(')', ']') + dim_str = dim_str.replace(elem, replace_str) + result_shape.append(dim_str) + else: + result_shape.append(dim) + return result_shape + if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor): return shape - elif isinstance(shape, list) and symint_in_shape(shape): - return self.process_dynamic_shape(shape) - else: - return self.get_proxy( - ascend_op.Const, (shape, dtype, [len(shape)])) + elif isinstance(shape, list): + # handle SymInt alone + if symint_in_shape(shape): + shape = symint_to_str(shape) + + # both fit for SymInt & NodeProxy, pass all number cases + if not_all_num_shape(shape): + return self.process_dynamic_shape(shape) + return self.get_const_proxy(shape, dtype) def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): format = "ND" if format is None else format if target_shape is None: - shape = [len(param)] if isinstance(param, list) else [] + shape = [len(param)] if isinstance(param, list) and len(param) > 1 else [] else: shape = target_shape param = param if isinstance(param, list) else [param] @@ -160,8 +309,31 @@ def get_const_proxy(self, param, dtype, format=None, target_shape=None): return param def promote_dtype(self, *args, target_dtype): + priority = {torch.bool:0, + torch.int32:1, + torch.int64:2, + torch.float16:3, + torch.float32:4, + torch.float64:5, + None:-1} result = [] + + # align maximum arg priority + max_prio = -1 + for arg in args: + cur_prio = priority[try_to_get_dtype(arg)] + if cur_prio > max_prio: + max_prio = cur_prio + + # align priority between arg max and target_dtype + cur_prio = priority[target_dtype] + if cur_prio > max_prio: + max_prio = cur_prio + assert max_prio > -1 + target_dtype = list(priority.keys())[max_prio] ascend_dtype = get_ascend_dtype(target_dtype) + + # core of dtype conversion for arg in args: if isinstance(arg, torch.fx.proxy.Proxy): current_dtype = try_to_get_dtype(arg) @@ -172,6 +344,9 @@ def promote_dtype(self, *args, target_dtype): # 1. unable to get tensor dtype # 2. 
current_dtype != target_dtype result.append(self.get_proxy(ascend_op.Cast, (arg, ascend_dtype))) + elif try_to_get_dtype(arg) is not None: + # handle with scalar case + result.append(self.get_const_proxy(arg, target_dtype)) else: raise RuntimeError("Not implemented") return tuple(result) if len(result) > 1 else result[0] @@ -299,6 +474,10 @@ def rsqrt(self, x): cond_op = self.get_proxy(ascend_op.Less, (x, zero_op)) return self.get_proxy(ascend_op.Select, (cond_op, nan_op, rsqrt_op)) + @register_conversion(aten.triu) + def triu(self, x, diag): + return self.get_proxy(ascend_op.Triu, (x, diag)) + @register_conversion(_operator.ge) def inge(self, x, y): if not isinstance(y, torch.fx.proxy.Proxy): @@ -306,12 +485,16 @@ def inge(self, x, y): y = self.get_const_proxy(y, torch.int32) return self.get_proxy(ascend_op.GreaterEqual, (x, y)) - @register_conversion(aten.div) + @register_conversion([aten.div, _operator.floordiv]) def div(self, x, y): if isinstance(y, torch.fx.proxy.Proxy): return self.get_proxy(ascend_op.DivNoNan, (x, y)) assert y != 0 - out_dtype = fx_traceback.get_current_meta()['val'].dtype + out = fx_traceback.get_current_meta()['val'] + if not isinstance(out, torch.SymInt): + out_dtype = out.dtype + else: + out_dtype = torch.int32 y_op = self.get_const_proxy(y, out_dtype) return self.get_proxy(ascend_op.Div, (x, y_op), {}) @@ -322,7 +505,8 @@ def split(self, x, split_size, dim=0): if dim < 0: dim += len(shape) assert shape[dim] > 0 - num_split = int((shape[dim] + split_size - 1) / split_size) + + num_split = len(fx_traceback.get_current_meta()['val']) return self.get_proxy(ascend_op.SplitD, (x, dim, num_split, num_split), splitD_kw) @register_conversion(aten.slice.Tensor) @@ -330,15 +514,14 @@ def slice(self, x, dim=0, start=None, end=None, step=1): # TODO(tangzhiyi): miss step parameter x_shape = list(x.node.meta['val'].shape) y_shape = list(fx_traceback.get_current_meta()['val'].shape) - # y_shape = fx_traceback.get_current_meta()['val'].shape dim = int(dim) start = int(start) if start is not None else 0 start = start if start >= 0 else x_shape[dim] + start - assert dim == -1 or dim >= 0 and dim < len(x_shape) assert start is None or start >= 0 and start < x_shape[dim] + + assert dim == -1 or dim >= 0 and dim < len(x_shape) offset = [0] * len(x_shape) offset[dim] = start - # import pdb; pdb.set_trace() offset = self.get_shape_proxy(offset) size = self.get_shape_proxy(y_shape) @@ -361,8 +544,7 @@ def NewEmptyStrided(self, x, size, stride, dtype=torch.float32, layout=torch.str @register_conversion(aten.empty) def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format, pin_memory=False): - shape_op = self.get_proxy( - ascend_op.Const, (size, torch.int32, [len(size)])) + shape_op = self.get_shape_proxy(size) return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, memory_format)) @register_conversion(aten.empty_like.default) @@ -425,13 +607,10 @@ def view(self, x, size): shape = list(result_val.shape) if x.node.meta["val"].dtype == torch.complex64: shape.append(1) + size_tmp = [s for s in size] + [1] + size = immutable_list(size_tmp) numel = result_val.numel() - neg = False - for i in shape: - if not isinstance(i, torch.SymInt): - if i == -1: - neg = True - break + neg = neg_in_shape(shape) if neg: prod = 1 for i in shape: @@ -453,6 +632,8 @@ def view(self, x, size): raise RuntimeError( "cannot handle with both negative and symint!") shape = real_shape + elif not neg_in_shape(size) and not_all_num_shape(shape): + 
shape = size shape = self.get_shape_proxy(shape) if x.node.meta["val"].dtype == torch.complex64: real = self.get_proxy(ascend_op.Identity, (x, 0)) @@ -516,8 +697,8 @@ def lt(self, x, y): out = list(fx_traceback.get_current_meta()['val'].shape) out_shape = self.get_shape_proxy(out) x, y = self.binary_cmp_cast_input(x, y) - dynamic_shape = symint_in_shape(x_shape) or symint_in_shape( - y_shape) or symint_in_shape(out) + dynamic_shape = not_all_num_shape(x_shape) or not_all_num_shape( + y_shape) or not_all_num_shape(out) if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)): x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape)) if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)): @@ -552,7 +733,8 @@ def nll_loss_forward(self, x, target, weight, reduction, ignore_index): assert ignore_index == -100 reduction_str = get_reduction_str(reduction) csize = [list(x.node.meta['val'].shape)[1]] - target = self.get_proxy(ascend_op.Cast, (target, "INT32")) + if target.node.meta['val'].dtype != torch.int32: + target = self.get_proxy(ascend_op.Cast, (target, "INT32")) weight = self.get_proxy(ascend_op.FillV2D, (1.0, csize)) return self.get_proxy(ascend_op.NLLLoss, (x, target, weight, reduction_str, ignore_index)) @@ -562,7 +744,8 @@ def nll_loss_backward(self, grad_output, x, target, weight, reduction, ignore_in assert ignore_index == -100 reduction_str = get_reduction_str(reduction) csize = [list(x.node.meta['val'].shape)[1]] - target = self.get_proxy(ascend_op.Cast, (target, "INT32")) + if target.node.meta['val'].dtype != torch.int32: + target = self.get_proxy(ascend_op.Cast, (target, "INT32")) weight = self.get_proxy(ascend_op.FillV2D, (1.0, csize)) return self.get_proxy(ascend_op.NLLLossGrad, (x, grad_output, target, weight, total_weight, @@ -616,8 +799,6 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided, if len(dims) == 0: return self.get_const_proxy(value, torch_dtype) - dims = [dim.node.meta['val'] if isinstance(dim, torch.fx.proxy.Proxy) and hasattr( - dim.node, 'meta') else dim for dim in dims] if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'): value = value.node.meta['val'] dims = self.get_shape_proxy(dims) @@ -790,7 +971,7 @@ def compute_stacked_indices(self, indices, src_shape): tensor_unsqueeze_len = none_count_in_indices - i if contiguous_flag \ else none_count_in_indices indice_i_shape = index.node.meta['val'].shape - assert not symint_in_shape(indice_i_shape) + assert not not_all_num_shape(indice_i_shape) tensor_reshape_shape.append(list(indice_i_shape) + [1] * tensor_unsqueeze_len) assert first_tensor_pos != -1, "all elements of indices is None, unsupported" tensor_broadcast_shape = list(torch.broadcast_shapes(*tensor_reshape_shape)) @@ -860,7 +1041,7 @@ def index_put_default(self, x, indices, values): return self.masked_fill(x, index, values) reshape_shape = index_shape + [1] * \ (x_shape_size - index_shape_size) - reshape_op = self.get_const_proxy(reshape_shape, torch.int32) + reshape_op = self.get_shape_proxy(reshape_shape) index = self.get_proxy(ascend_op.Reshape, (index, reshape_op)) return self.masked_fill(x, index, values) @@ -1055,13 +1236,14 @@ def expand(self, x, shape): y_shape = list(fx_traceback.get_current_meta()['val'].shape) if x_shape == y_shape: return self.get_proxy(ascend_op.Identity, (x, None)) + + # Cast needed only when x_dtype is int64 if x.node.meta['val'].dtype == torch.int64: x = self.get_proxy(ascend_op.Cast, (x, "INT32")) - shape = [dim.node.meta['val'] if hasattr( - dim, 
'node') else dim for dim in shape] - if isinstance(shape, list) and symint_in_shape(shape): - preprocess_shape = self.process_dynamic_shape(shape) - return self.get_proxy(ascend_op.Expand, (x, preprocess_shape)) + + if isinstance(shape, list) and not_all_num_shape(shape): + shape = self.get_shape_proxy(shape) + return self.get_proxy(ascend_op.Expand, (x, shape)) else: return self.get_proxy(ascend_op.ExpandD, (x, shape)) @@ -1158,28 +1340,31 @@ def convolution(self, input, weight, bias, stride, padding, @register_conversion(_operator.mul) def inmul(self, x, y): - assert (not isinstance(y, torch.fx.proxy.Proxy)) - y = self.get_const_proxy(y, torch.int32) + if not isinstance(y, torch.fx.proxy.Proxy): + y = self.get_const_proxy(y, torch.int32) return self.get_proxy(ascend_op.Mul, (x, y)) @register_conversion(torch.ops.aten.sym_size) def symsize(self, x, dim): dim = [dim] if not isinstance(dim, list) else dim shape = self.get_proxy(ascend_op.Shape, (x,)) - axis = self.get_const_proxy(0, torch.int32, target_shape=[1]) + axis = self.get_const_proxy(0, torch.int32) indices = self.get_const_proxy(dim, torch.int32) return self.get_proxy(ascend_op.GatherV2, (shape, indices, axis)) @register_conversion(torch.ops.aten.mm.default) def mm(self, x, y): + out_dtype = fx_traceback.get_current_meta()['val'].dtype + # TODO! MatMul not support fp32 input # for higher precision in some cases if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: x = self.get_proxy(ascend_op.Unsqueeze, (x, [0])) y = self.get_proxy(ascend_op.Unsqueeze, (y, [0])) - mm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False)) - return self.get_proxy(ascend_op.Squeeze, (mm, [0])) - out_dtype = fx_traceback.get_current_meta()['val'].dtype + bmm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False)) + cast = self.get_proxy(ascend_op.Cast, (bmm, get_ascend_dtype(out_dtype))) + return self.get_proxy(ascend_op.Squeeze, (cast, [0])) + trans_x = False trans_y = False if isinstance(x.node.target, ascend_op.Permute) and x.node.args[1] == [1, 0]: @@ -1189,12 +1374,16 @@ def mm(self, x, y): y = self.get_proxy_from_node(y.node.args[0]) trans_y = True mm = self.get_proxy(ascend_op.MatMul, (x, y, trans_x, trans_y)) + + # TODO! complicated logic in MatMul output dtype return self.get_proxy(ascend_op.Cast, (mm, get_ascend_dtype(out_dtype))) @register_conversion(aten.bmm.default) def bmm(self, x, y): out_dtype = fx_traceback.get_current_meta()['val'].dtype bmm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False, sd_fp16 ^ 1)) + + # TODO! complicated logic in BatchMatMul output dtype return self.get_proxy(ascend_op.Cast, (bmm, get_ascend_dtype(out_dtype))) @register_conversion(torch.torch.ops.aten.addmm) @@ -1288,7 +1477,7 @@ def exp(self, a): def embedding(self, weight, indices, padding_idx=-1): # TODO! 
consider situation for padding_idx # during training stage - axis = self.get_const_proxy(0, torch.int32, target_shape=[1]) + axis = self.get_const_proxy(0, torch.int32) return self.get_proxy(ascend_op.GatherV2, (weight, indices, axis)) @register_conversion(torch.ops.aten.gather) @@ -1404,9 +1593,8 @@ def RandLike(self, x, dtype=torch.float32, layout=torch.strided, @register_conversion(torch.ops.aten.gt.Scalar) def GtScalar(self, x, y): dtype = get_ascend_dtype(x.node.meta['val'].dtype) - scalar_op = self.get_const_proxy(float(y), torch.float32) - cast_op = self.get_proxy(ascend_op.Cast, (scalar_op, dtype)) - return self.get_proxy(ascend_op.Greater, (x, cast_op)) + x, scalar_op = self.promote_dtype(x, scalar_op, target_dtype=dtype) + return self.get_proxy(ascend_op.Greater, (x, scalar_op)) @register_conversion(torch.ops.aten.addcmul.default) def AddCMul(self, a, b, c, value=1): @@ -1468,9 +1656,11 @@ def Repeat(self, x, repeats): @register_conversion([torch.ops.aten.ge.Scalar, torch.ops.aten.ge.Tensor]) def Ge(self, x, y): - if not isinstance(y, torch.fx.proxy.Proxy): - dtype = x.node.meta['val'].dtype - y = self.get_const_proxy(y, dtype) + x_dtype = x.node.meta['val'].dtype + if isinstance(y, torch.fx.proxy.Proxy) and isinstance(y.node.meta['val'], torch.SymInt): + y = self.get_shape_proxy([y]) + y = self.get_proxy(ascend_op.Squeeze, (y, [0])) + x, y = self.promote_dtype(x, y, target_dtype=x_dtype) return self.get_proxy(ascend_op.GreaterEqual, (x, y)) @register_conversion(torch.ops.aten.logical_or.default) @@ -1529,7 +1719,7 @@ def lightllm_rotary_emb(self, x, cos, sin): seq_len = x_shape[0] dim = x_shape[2] - cos_sin_shape = self.get_const_proxy([seq_len, 1, dim // 2], torch.int32) + cos_sin_shape = self.get_shape_proxy([seq_len, 1, dim // 2]) cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape)) sin = self.get_proxy(ascend_op.Reshape, (sin, cos_sin_shape)) @@ -1548,14 +1738,21 @@ def lightllm_rms_norm(self, x, weight, eps): @register_conversion(torch.ops.lightllm.prompt_attention_inference.default) def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): q_shape = list(q.node.meta['val'].shape) + q_dtype = q.node.meta['val'].dtype seq_len = q_shape[1] shape = [seq_len, seq_len] - shape = self.get_proxy(ascend_op.Const, (shape, torch.int32, [len(shape)])) + shape = self.get_shape_proxy(shape) mask = self.get_proxy(ascend_op.Empty, (shape, torch.bool)) mask = self.get_proxy(ascend_op.OnesLike, (mask,)) mask = self.get_proxy(ascend_op.Tril, (mask,)) mask = self.get_proxy(ascend_op.LogicalNot, (mask,)) + if q_dtype != torch.float16: + q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16))) fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim)) + + # cast back fa to float32 + if q_dtype != torch.float16: + return self.get_proxy(ascend_op.Cast, (fa, get_ascend_dtype(q_dtype))) return fa def incre_flash_attention(self, q, k, v, head_num, kv_head_num, dim): @@ -1588,28 +1785,58 @@ def select_scatter(self, x, src, dim, index): @register_conversion(torch.ops.lightllm.copy_with_offset.default) def copy_with_offset(self, x, src, start_dim, end_dim): - dims = [x for x in range(start_dim, end_dim)] - dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1]) + if isinstance(start_dim, int) and isinstance(end_dim, int): + dims = [x for x in range(start_dim, end_dim)] + dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1]) + return self.get_proxy(ascend_op.ScatterNdUpdate, (x, 
dims, src)) + + step = self.get_const_proxy(1, torch.int32) + if isinstance(start_dim, int): + start_dim = self.get_const_proxy(start_dim, torch.int32) + if isinstance(end_dim, int): + end_dim = self.get_const_proxy(end_dim, torch.int32) + dims = self.get_proxy(ascend_op.Range, (start_dim, end_dim, step)) + dims = self.get_proxy(ascend_op.Unsqueeze, (dims, [-1])) + + x_dtype = x.node.meta['val'].dtype + src_dtype = src.node.meta['val'].dtype + if x_dtype != src_dtype: + src = self.get_proxy(ascend_op.Cast, (src, get_ascend_dtype(src_dtype))) return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) @register_conversion(torch.ops.lightllm.flash_attention_inference.default) - def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): + def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhead=-1, head=-1, dim=-1): q_shape = list(q.node.meta['val'].shape) - batch, head, dim = q_shape[0], q_shape[1], q_shape[2] + batch = q_shape[0] + if head < 0: + head = q_shape[1] + if dim < 0: + dim = q_shape[2] + k_shape = list(all_k.node.meta['val'].shape) - kvhead = k_shape[1] + if kvhead < 0: + kvhead = k_shape[1] res = [] compute_batch = 1 select_axis = self.get_const_proxy(0, torch.int32) for i in range(batch): - current_len = current_len[i] + current_len = current_lens[i] select_index = self.get_const_proxy(i, torch.int32) xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis)) kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32) - kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) + + # split path for dynamic & static 7B + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + imax_const = self.get_const_proxy(i * max_len, torch.int32) + curlen_const = self.get_const_proxy(current_len, torch.int32) + end_proxy = self.get_proxy(ascend_op.Add, (curlen_const, imax_const)) + kv_end_index = self.get_shape_proxy([end_proxy, kvhead, dim], torch.int32) + else: + kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) + kv_seq_len = current_len kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim]) diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index 324d2a9b2..d2dfcc6ed 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -88,21 +88,23 @@ def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim @torch._custom_op.impl.custom_op('lightllm::flash_attention_inference') -def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int) -> Tensor: +def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int) -> Tensor: ... 
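# Editorial sketch, not part of the patch: the extended flash_attention_inference
# signature uses -1 as an "infer from the tensor" sentinel for kvhead/head/dim.
# A minimal standalone helper showing the same fallback rule used by the
# AscendGraph conversion above and the CPU/CUDA impl below; the function name is
# hypothetical, and only the shape convention (q: [batch, head, dim],
# all_k: [batch * max_len, kvhead, dim]) is taken from the surrounding slicing code.
import torch

def resolve_attention_dims(q: torch.Tensor, all_k: torch.Tensor,
                           kvhead: int = -1, head: int = -1, dim: int = -1):
    # Negative values mean "not provided"; fall back to the tensor shapes.
    if head < 0:
        head = q.shape[1]
    if dim < 0:
        dim = q.shape[2]
    if kvhead < 0:
        kvhead = all_k.shape[1]
    return kvhead, head, dim

# Example: a 2-batch query with 32 heads of size 128 and a packed KV cache.
q = torch.randn(2, 32, 128)
all_k = torch.randn(2 * 1024, 32, 128)
assert resolve_attention_dims(q, all_k) == (32, 32, 128)
assert resolve_attention_dims(q, all_k, kvhead=8) == (8, 32, 128)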
@flash_attention_inference.impl_abstract() -def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int): +def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int): return torch.empty_like(q) @flash_attention_inference.impl(['cpu', 'cuda']) -def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len): +def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len, kvhead=-1, head=-1, dim=-1): # q: batch, head, dim batch = q.shape[0] - head = q.shape[1] - dim = q.shape[2] + if head < 0: + head = q.shape[1] + if dim < 0: + dim = q.shape[2] res = [] compute_batch = 1 for i in range(batch): diff --git a/dicp/test/ascend_scripts/ops/dynamic.ini b/dicp/test/ascend_scripts/ops/dynamic.ini index 9e0c79387..50aa2a9cd 100644 --- a/dicp/test/ascend_scripts/ops/dynamic.ini +++ b/dicp/test/ascend_scripts/ops/dynamic.ini @@ -1,4 +1,5 @@ [pytest] testpaths = ../../op python_files = + test_split_dynamic.py \ No newline at end of file diff --git a/dicp/test/ascend_scripts/ops/static.ini b/dicp/test/ascend_scripts/ops/static.ini index c73cb4f71..5077e5f0e 100644 --- a/dicp/test/ascend_scripts/ops/static.ini +++ b/dicp/test/ascend_scripts/ops/static.ini @@ -71,6 +71,7 @@ python_files = test_repeat_interleave.py ; test_transpose.py test_tril.py + test_triu.py test_unsqueeze.py test_view_as_complex.py test_view_as_real.py diff --git a/dicp/test/op/test_lightllm_incre_attention.py b/dicp/test/op/test_lightllm_incre_attention.py index f2c35ca1b..a8bd1d4da 100644 --- a/dicp/test/op/test_lightllm_incre_attention.py +++ b/dicp/test/op/test_lightllm_incre_attention.py @@ -14,7 +14,7 @@ class OpModule(torch.nn.Module): def forward(self, q, k, v, int_index_list, max_seq_length): - res = torch.ops.lightllm.flash_attention_inference.default(q, k, v, int_index_list, max_seq_length) + res = torch.ops.lightllm.flash_attention_inference.default(q, k, v, int_index_list, max_seq_length, -1, -1, -1) return res diff --git a/dicp/test/op/test_split_dynamic.py b/dicp/test/op/test_split_dynamic.py new file mode 100644 index 000000000..73535ee8f --- /dev/null +++ b/dicp/test/op/test_split_dynamic.py @@ -0,0 +1,45 @@ +import pytest +import operator +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a, b): + split = torch.ops.aten.split.Tensor(a, 1) + res = operator.getitem(split, b) + return res + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + +torch._dynamo.config.dynamic_shapes = True +torch._dynamo.config.assume_static_by_default = True + +class TestSequenceAt(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("dim", [0]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_operator_sequence_at(self, sizes, dim, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + + in_tensor = torch.randn(size, dtype=dtype) + dicp_input = in_tensor.to(device) + + output = model(in_tensor, dim) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input, dim) + 
+ assert torch.allclose(output, dicp_output.cpu(), equal_nan=True) diff --git a/dicp/test/op/test_triu.py b/dicp/test/op/test_triu.py new file mode 100644 index 000000000..91b0f1f81 --- /dev/null +++ b/dicp/test/op/test_triu.py @@ -0,0 +1,40 @@ +import pytest +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.triu(a, 1) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestTriu(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((3, 5), (3, 5)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_triu(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.ones(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
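# Editorial sketch, not part of the patch: CPU reference for the behaviour the new
# Triu lowering (and test_triu.py above) is checked against. aten.triu with
# diagonal=1 keeps only the elements strictly above the main diagonal; the Ascend
# "Triu" op receives the same value through its "diagonal" attribute.
import torch

a = torch.ones(3, 5)
expected = torch.triu(a, diagonal=1)
assert torch.equal(torch.ops.aten.triu(a, 1), expected)

# For batched inputs the mask is applied to the last two dimensions.
b = torch.ones(2, 3, 4)
assert torch.equal(torch.ops.aten.triu(b, 1), torch.triu(b, diagonal=1))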