From 39d38b0bcebfef25c7191927eb4418af97b3909b Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 23 Apr 2024 10:19:49 +0000 Subject: [PATCH 01/42] Refine code structure of dynamic shape handling. --- dicp/dicp/dynamo_bridge/compile_fx.py | 2 + dicp/dicp/dynamo_bridge/utils.py | 7 +++ dicp/dicp/vendor/AscendGraph/conversion.py | 67 ++++++++++------------ 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 5eaee327c..ba4a8f7c9 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -47,6 +47,8 @@ def get_fake_mode_from_tensors(input_tensors): def used_nodes_all_symint(nodes): for node in nodes: + # if 'view' in str(node) or 'reshape' in str(node): + # import pdb; pdb.set_trace() if node.op == 'placeholder' and len(node.users) > 0: if hasattr(node, 'meta'): node = node.meta['val'] diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py index 050102ad4..ddeabf086 100644 --- a/dicp/dicp/dynamo_bridge/utils.py +++ b/dicp/dicp/dynamo_bridge/utils.py @@ -7,6 +7,13 @@ from torch.fx.node import Argument, Target +def proxy_in_shape(shape): + for elem in shape: + if isinstance(elem, torch.fx.proxy.Proxy): + return True + return False + + def symint_in_shape(shape): for elem in shape: if isinstance(elem, torch.SymInt): diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 11e72be2e..9cb410b2a 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -14,7 +14,7 @@ from torch.fx.immutable_collections import immutable_list from torch._subclasses import FakeTensor import dicp.vendor.AscendGraph.ascend_op as ascend_op -from dicp.dynamo_bridge.utils import symint_in_shape +from dicp.dynamo_bridge.utils import proxy_in_shape from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype, get_cpp_dtype @@ -70,7 +70,7 @@ def register_conversion(aten_fn): class AtenToAscendTransformer(SingleOpTransformer): def __init__(self, gm): super().__init__(gm, conversions) - + def process_dynamic_shape(self, shape): x_names = [] @@ -80,47 +80,34 @@ def generate_digits_op(shapes): x_names.append(const_op) def generate_sym_int(elem): - elem = elem.node.str() - elems = elem.strip().split(' ') + elem_str = str(elem) + elems = elem_str.strip().split(' ') - arg = None # dynamic shape feature - if elems[0] in self.sym_in_args: - arg, idx = self.sym_in_args[elems[0]] - shape = self.get_proxy(ascend_op.Shape, (arg,)) - axis = self.get_proxy( - ascend_op.Const, ([0], torch.int32, [1])) - indice = self.get_proxy( - ascend_op.Const, ([idx], torch.int32, [1])) - gather = self.get_proxy( - ascend_op.GatherV2, (shape, indice, axis)) + assert 'gather' in elems[0] or 'arg' in elems[0] + replace_node = elem if len(elems) > 1: + import pdb; pdb.set_trace() assert len(elems) == 3 assert elems[2].isdigit() assert elems[1] == '+' or elems[1] == '-' const_op = self.get_proxy( ascend_op.Const, ([int(elems[2])], torch.int32, [1])) - if arg is not None: - args = (gather, const_op) - else: - args = (self.sym_to_inputs[elems[0]], const_op) + args = (replace_node, const_op) if elems[1] == '+': x_names.append(self.get_proxy(ascend_op.Add, args)) else: x_names.append(self.get_proxy(ascend_op.Sub, args)) else: - if arg is not None: - x_names.append(gather) - else: - x_names.append(self.sym_to_inputs[elems[0]]) + x_names.append(replace_node) dims = [] for elem in shape: - if not 
isinstance(elem, torch.SymInt): + if not isinstance(elem, torch.fx.proxy.Proxy): dims.append(elem) continue - st = elem.node.str() + st = str(elem) if st.isdigit(): dims.append(int(st)) continue @@ -137,11 +124,12 @@ def generate_sym_int(elem): def get_shape_proxy(self, shape): if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor): return shape - elif isinstance(shape, list) and symint_in_shape(shape): - return self.process_dynamic_shape(shape) - else: - return self.get_proxy( - ascend_op.Const, (shape, torch.int32, [len(shape)])) + elif isinstance(shape, list): + shape = [self.sym_to_inputs[dim.node.str()] if isinstance(dim, torch.SymInt) else dim for dim in shape] + if proxy_in_shape(shape): + return self.process_dynamic_shape(shape) + return self.get_proxy( + ascend_op.Const, (shape, torch.int32, [len(shape)])) def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): @@ -418,6 +406,7 @@ def view(self, x, size): shape = list(result_val.shape) if x.node.meta["val"].dtype == torch.complex64: shape.append(1) + size.append(1) numel = result_val.numel() neg = False for i in shape: @@ -437,7 +426,7 @@ def view(self, x, size): real_shape = [] for i in shape: - if not isinstance(i, torch.SymInt): + if not isinstance(i, torch.fx.proxy.Proxy): if i > 0: real_shape.append(str(i)) else: @@ -446,6 +435,8 @@ def view(self, x, size): raise RuntimeError( "cannot handle with both negative and symint!") shape = real_shape + else: + shape = size shape = self.get_shape_proxy(shape) if x.node.meta["val"].dtype == torch.complex64: real = self.get_proxy(ascend_op.Identity, (x, 0)) @@ -509,8 +500,8 @@ def lt(self, x, y): out = list(fx_traceback.get_current_meta()['val'].shape) out_shape = self.get_shape_proxy(out) x, y = self.binary_cmp_cast_input(x, y) - dynamic_shape = symint_in_shape(x_shape) or symint_in_shape( - y_shape) or symint_in_shape(out) + dynamic_shape = proxy_in_shape(x_shape) or proxy_in_shape( + y_shape) or proxy_in_shape(out) if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)): x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape)) if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)): @@ -609,8 +600,8 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided, if len(dims) == 0: return self.get_const_proxy(value, torch_dtype) - dims = [dim.node.meta['val'] if isinstance(dim, torch.fx.proxy.Proxy) and hasattr( - dim.node, 'meta') else dim for dim in dims] + # dims = [dim.node.meta['val'] if isinstance(dim, torch.fx.proxy.Proxy) and hasattr( + # dim.node, 'meta') else dim for dim in dims] if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'): value = value.node.meta['val'] dims = self.get_shape_proxy(dims) @@ -783,7 +774,7 @@ def compute_stacked_indices(self, indices, src_shape): tensor_unsqueeze_len = none_count_in_indices - i if contiguous_flag \ else none_count_in_indices indice_i_shape = index.node.meta['val'].shape - assert not symint_in_shape(indice_i_shape) + assert not proxy_in_shape(indice_i_shape) tensor_reshape_shape.append(list(indice_i_shape) + [1] * tensor_unsqueeze_len) assert first_tensor_pos != -1, "all elements of indices is None, unsupported" tensor_broadcast_shape = list(torch.broadcast_shapes(*tensor_reshape_shape)) @@ -1039,9 +1030,9 @@ def expand(self, x, shape): return self.get_proxy(ascend_op.Identity, (x, None)) if x.node.meta['val'].dtype == torch.int64: x = 
self.get_proxy(ascend_op.Cast, (x, "INT32")) - shape = [dim.node.meta['val'] if hasattr( - dim, 'node') else dim for dim in shape] - if isinstance(shape, list) and symint_in_shape(shape): + # shape = [dim.node.meta['val'] if hasattr( + # dim, 'node') else dim for dim in shape] + if isinstance(shape, list) and proxy_in_shape(shape): preprocess_shape = self.process_dynamic_shape(shape) return self.get_proxy(ascend_op.Expand, (x, preprocess_shape)) else: From 5a97b9b646d80e203fb8e68ac3fd8298d0e7329f Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 24 Apr 2024 04:40:20 +0000 Subject: [PATCH 02/42] Adjust symint_to_args relationship code logic. --- dicp/dicp/dynamo_bridge/utils.py | 4 +- dicp/dicp/vendor/AscendGraph/conversion.py | 69 ++++++++++++++++------ 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py index ddeabf086..cb62f3db3 100644 --- a/dicp/dicp/dynamo_bridge/utils.py +++ b/dicp/dicp/dynamo_bridge/utils.py @@ -7,9 +7,9 @@ from torch.fx.node import Argument, Target -def proxy_in_shape(shape): +def not_all_num_shape(shape): for elem in shape: - if isinstance(elem, torch.fx.proxy.Proxy): + if not isinstance(elem, int): return True return False diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 9cb410b2a..882e40cc4 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -10,11 +10,12 @@ Number, ) import numpy as np +import sympy import torch.fx.traceback as fx_traceback from torch.fx.immutable_collections import immutable_list from torch._subclasses import FakeTensor import dicp.vendor.AscendGraph.ascend_op as ascend_op -from dicp.dynamo_bridge.utils import proxy_in_shape +from dicp.dynamo_bridge.utils import not_all_num_shape, symint_in_shape from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype, get_cpp_dtype @@ -79,22 +80,32 @@ def generate_digits_op(shapes): ascend_op.Const, (shapes, torch.int32, [len(shapes)])) x_names.append(const_op) - def generate_sym_int(elem): + def generate_sym_or_proxy(elem): elem_str = str(elem) elems = elem_str.strip().split(' ') # dynamic shape feature - assert 'gather' in elems[0] or 'arg' in elems[0] - replace_node = elem - + replace_node = None + if isinstance(elem, torch.fx.proxy.Proxy): + assert 'gather' in elems[0] or 'arg' in elems[0] + replace_node = elem + else: + for convert in self.sym_to_inputs.values(): + if elems[0] == str(convert): + replace_node = convert + break + assert replace_node is not None + + # process simple expression including +/- + # or single node for else block if len(elems) > 1: - import pdb; pdb.set_trace() assert len(elems) == 3 assert elems[2].isdigit() assert elems[1] == '+' or elems[1] == '-' const_op = self.get_proxy( ascend_op.Const, ([int(elems[2])], torch.int32, [1])) args = (replace_node, const_op) + if elems[1] == '+': x_names.append(self.get_proxy(ascend_op.Add, args)) else: @@ -104,7 +115,8 @@ def generate_sym_int(elem): dims = [] for elem in shape: - if not isinstance(elem, torch.fx.proxy.Proxy): + # process number + if isinstance(elem, int): dims.append(elem) continue st = str(elem) @@ -112,21 +124,46 @@ def generate_sym_int(elem): dims.append(int(st)) continue + # process SymInt or NodeProxy if len(dims) > 0: generate_digits_op(dims) dims = [] - generate_sym_int(elem) + generate_sym_or_proxy(elem) + + # last number block if len(dims) > 0: generate_digits_op(dims) + # concat all ops return 
self.get_proxy(ascend_op.ConcatD, (x_names, 0)) def get_shape_proxy(self, shape): + def symint_to_str(shape): + result_shape = [] + for dim in shape: + if isinstance(dim, torch.SymInt): + # split expression elements to compare SymInt string + dim_str = dim.node.str() + elems = dim_str.strip().split(' ') + + # replace SymInt in expression using sympy function + for elem in elems: + if 's' in elem: + dim_str = str(sympy.simplify(dim_str).subs(elem, str(self.sym_to_inputs[elem]))) + result_shape.append(dim_str) + else: + result_shape.append(dim) + return result_shape + if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor): return shape elif isinstance(shape, list): - shape = [self.sym_to_inputs[dim.node.str()] if isinstance(dim, torch.SymInt) else dim for dim in shape] - if proxy_in_shape(shape): + # handle SymInt alone + if symint_in_shape(shape): + shape = symint_to_str(shape) + + # both fit for SymInt & NodeProxy, pass all number cases + if not_all_num_shape(shape): return self.process_dynamic_shape(shape) return self.get_proxy( ascend_op.Const, (shape, torch.int32, [len(shape)])) @@ -500,8 +537,8 @@ def lt(self, x, y): out = list(fx_traceback.get_current_meta()['val'].shape) out_shape = self.get_shape_proxy(out) x, y = self.binary_cmp_cast_input(x, y) - dynamic_shape = proxy_in_shape(x_shape) or proxy_in_shape( - y_shape) or proxy_in_shape(out) + dynamic_shape = not_all_num_shape(x_shape) or not_all_num_shape( + y_shape) or not_all_num_shape(out) if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)): x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape)) if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)): @@ -600,8 +637,6 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided, if len(dims) == 0: return self.get_const_proxy(value, torch_dtype) - # dims = [dim.node.meta['val'] if isinstance(dim, torch.fx.proxy.Proxy) and hasattr( - # dim.node, 'meta') else dim for dim in dims] if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'): value = value.node.meta['val'] dims = self.get_shape_proxy(dims) @@ -774,7 +809,7 @@ def compute_stacked_indices(self, indices, src_shape): tensor_unsqueeze_len = none_count_in_indices - i if contiguous_flag \ else none_count_in_indices indice_i_shape = index.node.meta['val'].shape - assert not proxy_in_shape(indice_i_shape) + assert not not_all_num_shape(indice_i_shape) tensor_reshape_shape.append(list(indice_i_shape) + [1] * tensor_unsqueeze_len) assert first_tensor_pos != -1, "all elements of indices is None, unsupported" tensor_broadcast_shape = list(torch.broadcast_shapes(*tensor_reshape_shape)) @@ -1030,9 +1065,7 @@ def expand(self, x, shape): return self.get_proxy(ascend_op.Identity, (x, None)) if x.node.meta['val'].dtype == torch.int64: x = self.get_proxy(ascend_op.Cast, (x, "INT32")) - # shape = [dim.node.meta['val'] if hasattr( - # dim, 'node') else dim for dim in shape] - if isinstance(shape, list) and proxy_in_shape(shape): + if isinstance(shape, list) and not_all_num_shape(shape): preprocess_shape = self.process_dynamic_shape(shape) return self.get_proxy(ascend_op.Expand, (x, preprocess_shape)) else: From e6f4e43595c73ac88d65ec845a1c6f548599bb5e Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 24 Apr 2024 07:11:20 +0000 Subject: [PATCH 03/42] Remove redundant code. 
--- dicp/dicp/dynamo_bridge/compile_fx.py | 2 -- dicp/dicp/vendor/AscendGraph/conversion.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index ba4a8f7c9..5eaee327c 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -47,8 +47,6 @@ def get_fake_mode_from_tensors(input_tensors): def used_nodes_all_symint(nodes): for node in nodes: - # if 'view' in str(node) or 'reshape' in str(node): - # import pdb; pdb.set_trace() if node.op == 'placeholder' and len(node.users) > 0: if hasattr(node, 'meta'): node = node.meta['val'] diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 882e40cc4..e3f1f386e 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -71,7 +71,7 @@ def register_conversion(aten_fn): class AtenToAscendTransformer(SingleOpTransformer): def __init__(self, gm): super().__init__(gm, conversions) - + def process_dynamic_shape(self, shape): x_names = [] From d2333fce47e1e1134cd8a3397e57ac0cef17422f Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 24 Apr 2024 12:11:31 +0000 Subject: [PATCH 04/42] Enable 70B get_qkv stage dynamic shape. --- dicp/dicp/dynamo_bridge/utils.py | 7 + .../dicp/vendor/AscendGraph/codegen/ascend.py | 15 ++ dicp/dicp/vendor/AscendGraph/conversion.py | 179 ++++++++++++++---- 3 files changed, 167 insertions(+), 34 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py index 050102ad4..46b6857d8 100644 --- a/dicp/dicp/dynamo_bridge/utils.py +++ b/dicp/dicp/dynamo_bridge/utils.py @@ -14,6 +14,13 @@ def symint_in_shape(shape): return False +def not_all_num_shape(shape): + for elem in shape: + if not isinstance(elem, int): + return True + return False + + def save_cpu_gm(gm: torch.fx.GraphModule, folder: str): Path(folder).mkdir(exist_ok=True) cpu_gm = copy_gm_to_cpu(gm) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 9b9fc24f4..b5b85bbd7 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -264,6 +264,21 @@ def process_sym_name(self, st): return self.sym_to_inputs[sp[0]] + '*' + sp[1] else: return self.process_sym_name(sp[0]) + '*' + sp[1] + elif '//' in st: + sp = st.strip('()').split('//') + if len(sp) > 2: + sp = [sp[0], '//'.join(sp[1:])] + assert (len(sp) == 2) + sp = [elem.strip() for elem in sp] + if sp[0].isdigit(): + (sp[1], sp[0]) = (sp[0], sp[1]) + if sp[0] in self.sym_in_args: + arg, idx = self.sym_in_args[sp[0]] + return "{}.shape[{}]".format(arg, idx) + '//' + sp[1] + if sp[0] in self.sym_to_inputs.keys(): + return self.sym_to_inputs[sp[0]] + '//' + sp[1] + else: + return self.process_sym_name(sp[0]) + '//' + sp[1] else: if st in self.sym_in_args: arg, idx = self.sym_in_args[st] diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 11e72be2e..adfce1c8c 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -14,7 +14,7 @@ from torch.fx.immutable_collections import immutable_list from torch._subclasses import FakeTensor import dicp.vendor.AscendGraph.ascend_op as ascend_op -from dicp.dynamo_bridge.utils import symint_in_shape +from dicp.dynamo_bridge.utils import symint_in_shape, not_all_num_shape from 
dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype, get_cpp_dtype @@ -79,14 +79,38 @@ def generate_digits_op(shapes): ascend_op.Const, (shapes, torch.int32, [len(shapes)])) x_names.append(const_op) - def generate_sym_int(elem): - elem = elem.node.str() - elems = elem.strip().split(' ') + def find_root_num(set_num, num): + while set_num[num] != num: + num = set_num[num] + return num + + def merge_disjoint_set(set_num, idx_a, idx_b): + root_a = find_root_num(set_num, idx_a) + root_b = find_root_num(set_num, idx_b) + # an example for (s5 / 8) - (s5 / 16) + # num: 0 1 2 3 + # step1 - > set_num: 0 1 2 3 + # step2 - > set_num: 0 0 2 2 + # step3 - > set_num: 0 0 0 0 + + # return merged set from root_b to root_a + return [root_a if find_root_num(set_num, s) == root_b else s for s in set_num] + + def replace_elem_proxy(elem_str): + # exit if already a proxy + if isinstance(elem_str, torch.fx.proxy.Proxy): + return elem_str + assert not elem_str in ['+', '-', '*', '//', '(', ')'] + + # handle with integer + if elem_str.isdigit(): + const_op = self.get_proxy( + ascend_op.Const, ([int(elem_str)], torch.int32, [1])) + return const_op - arg = None - # dynamic shape feature - if elems[0] in self.sym_in_args: - arg, idx = self.sym_in_args[elems[0]] + # handle if elem in shape of InputArgs + if elem_str in self.sym_in_args: + arg, idx = self.sym_in_args[elem_str] shape = self.get_proxy(ascend_op.Shape, (arg,)) axis = self.get_proxy( ascend_op.Const, ([0], torch.int32, [1])) @@ -94,50 +118,131 @@ def generate_sym_int(elem): ascend_op.Const, ([idx], torch.int32, [1])) gather = self.get_proxy( ascend_op.GatherV2, (shape, indice, axis)) + return gather + # handle if SymInt InputArg needed + return self.sym_to_inputs[elem_str] + + def generate_not_num(elem): + # situation for NodeProxy + if isinstance(elem, torch.fx.proxy.Proxy): + x_names.append(elem) + return + + elem_str = elem.node.str() + elem_str = elem_str.replace('+', ' + ') + elem_str = elem_str.replace('-', ' - ') + elem_str = elem_str.replace('*', ' * ') + elem_str = elem_str.replace('//', ' // ') + elem_str = elem_str.replace('(', ' ( ') + elem_str = elem_str.replace(')', ' ) ') + elems = elem_str.split(' ') + elems = [e for e in elems if e != ''] + + # dynamic shape feature if len(elems) > 1: - assert len(elems) == 3 - assert elems[2].isdigit() - assert elems[1] == '+' or elems[1] == '-' - const_op = self.get_proxy( - ascend_op.Const, ([int(elems[2])], torch.int32, [1])) - if arg is not None: - args = (gather, const_op) - else: - args = (self.sym_to_inputs[elems[0]], const_op) - if elems[1] == '+': - x_names.append(self.get_proxy(ascend_op.Add, args)) - else: - x_names.append(self.get_proxy(ascend_op.Sub, args)) + set_num = [] + priority = [] + nest = 0 + + # calculate priority for each operator + # set initial set number + for idx, e in enumerate(elems): + if e == '+' or e =='-': + priority.append(nest * 3 + 0) + elif e == '*' or e == '//': + priority.append(nest * 3 + 1) + else: + if e == '(': + nest += 1 + elif e == ')': + nest -= 1 + priority.append(-1) + + # init set number + if not e in ['+', '-', '*', '//', '(', ')']: + set_num.append(idx) + else: + set_num.append(-1) + + # start merge disjoint-set + if len(set_num) > 1: + while len(set(set_num)) > 2: + # seek the highest priority operator + max = -1 + m_idx = -1 + for idx, prio in enumerate(priority): + if prio > max: + max = prio + m_idx = idx + + # merge the highest priority two elements calculation + # find left & right element + left_idx = m_idx - 1 + while left_idx > 0 and 
str(elems[left_idx]) in ['(', ')']: + left_idx -= 1 + right_idx = m_idx + 1 + while right_idx < len(elems) - 1 and str(elems[right_idx]) in ['(', ')']: + right_idx += 1 + left_idx = find_root_num(set_num, set_num[left_idx]) + right_idx = find_root_num(set_num, set_num[right_idx]) + left_elem = replace_elem_proxy(elems[left_idx]) + right_elem = replace_elem_proxy(elems[right_idx]) + + # generate calculation operator + if elems[m_idx] == '+': + elems[left_idx] = self.get_proxy(ascend_op.Add, (left_elem, right_elem)) + elif elems[m_idx] == '-': + elems[left_idx] = self.get_proxy(ascend_op.Sub, (left_elem, right_elem)) + elif elems[m_idx] == '*': + elems[left_idx] = self.get_proxy(ascend_op.Mul, (left_elem, right_elem)) + else: + elems[left_idx] = self.get_proxy(ascend_op.Div, (left_elem, right_elem)) + + # merge set number and priority + set_num = merge_disjoint_set(set_num, left_idx, right_idx) + priority[m_idx] = -1 + + # add final element proxy + final_idx = 0 + while final_idx < len(elems) - 1 and str(elems[final_idx]) in ['(', ')']: + final_idx += 1 + final_elem = replace_elem_proxy(elems[final_idx]) + x_names.append(final_elem) else: - if arg is not None: - x_names.append(gather) - else: - x_names.append(self.sym_to_inputs[elems[0]]) + # only one not num element + node = replace_elem_proxy(elems[0]) + x_names.append(node) dims = [] for elem in shape: - if not isinstance(elem, torch.SymInt): + # process number + if isinstance(elem, int): dims.append(elem) continue - st = elem.node.str() + st = str(elem) if st.isdigit(): dims.append(int(st)) continue + # add number block if len(dims) > 0: generate_digits_op(dims) dims = [] - generate_sym_int(elem) + generate_not_num(elem) + + # add last number block if len(dims) > 0: generate_digits_op(dims) + # concat all ops return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) def get_shape_proxy(self, shape): if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor): return shape - elif isinstance(shape, list) and symint_in_shape(shape): + elif isinstance(shape, list) and not_all_num_shape(shape): + # include both SymInt & NodeProxy return self.process_dynamic_shape(shape) else: return self.get_proxy( @@ -307,12 +412,16 @@ def inge(self, x, y): y = self.get_const_proxy(y, torch.int32) return self.get_proxy(ascend_op.GreaterEqual, (x, y)) - @register_conversion(aten.div) + @register_conversion([aten.div, _operator.floordiv]) def div(self, x, y): if isinstance(y, torch.fx.proxy.Proxy): return self.get_proxy(ascend_op.DivNoNan, (x, y)) assert y != 0 - out_dtype = fx_traceback.get_current_meta()['val'].dtype + out = fx_traceback.get_current_meta()['val'] + if not isinstance(out, torch.SymInt): + out_dtype = out.dtype + else: + out_dtype = torch.int32 y_op = self.get_const_proxy(y, out_dtype) return self.get_proxy(ascend_op.Div, (x, y_op), {}) @@ -332,10 +441,12 @@ def slice(self, x, dim=0, start=None, end=None, step=1): x_shape = list(x.node.meta['val'].shape) y_shape = list(fx_traceback.get_current_meta()['val'].shape) dim = int(dim) - start = int(start) if start is not None else 0 - start = start if start >= 0 else x_shape[dim] + start + if not isinstance(start, torch.fx.proxy.Proxy): + start = int(start) if start is not None else 0 + start = start if start >= 0 else x_shape[dim] + start + assert start is None or start >= 0 and start < x_shape[dim] + assert dim == -1 or dim >= 0 and dim < len(x_shape) - assert start is None or start >= 0 and start < x_shape[dim] offset = [0] * len(x_shape) offset[dim] = start offset = 
self.get_shape_proxy(offset) From 54167a9ea28ed69db2bcaf90d6bfad2db6f73174 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 25 Apr 2024 06:28:26 +0000 Subject: [PATCH 05/42] Fix complex size append. --- dicp/dicp/vendor/AscendGraph/conversion.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index e3f1f386e..c1ee13694 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -443,7 +443,9 @@ def view(self, x, size): shape = list(result_val.shape) if x.node.meta["val"].dtype == torch.complex64: shape.append(1) - size.append(1) + size_tmp = [s for s in size] + size_tmp.append(1) + size = immutable_list(size_tmp) numel = result_val.numel() neg = False for i in shape: @@ -463,7 +465,7 @@ def view(self, x, size): real_shape = [] for i in shape: - if not isinstance(i, torch.fx.proxy.Proxy): + if not isinstance(i, torch.SymInt): if i > 0: real_shape.append(str(i)) else: @@ -472,7 +474,7 @@ def view(self, x, size): raise RuntimeError( "cannot handle with both negative and symint!") shape = real_shape - else: + elif not_all_num_shape(shape): shape = size shape = self.get_shape_proxy(shape) if x.node.meta["val"].dtype == torch.complex64: From 3dec4354d60573bba049693dc0eb55f09042d4d9 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 25 Apr 2024 07:25:02 +0000 Subject: [PATCH 06/42] Change load_and_run in/out shape assignment. --- .../dicp/vendor/AscendGraph/codegen/ascend.py | 95 +++++-------------- 1 file changed, 26 insertions(+), 69 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index b5b85bbd7..46ea900e0 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -217,73 +217,19 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2): ) return self.import_code.getvalue() + def operator_in_str(self, st): + for op in ['+', '-', '*', '/']: + if op in st: + return True + return False + def process_sym_name(self, st): # dynamic shape feature - if st.isdigit(): - return st - elif '+' in st: - sp = st.split('+') - if len(sp) > 2: - sp = [sp[0], '+'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0].isdigit(): - (sp[1], sp[0]) = (sp[0], sp[1]) - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '+' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '+' + sp[1] - else: - return self.process_sym_name(sp[0]) + '+' + sp[1] - elif '-' in st: - sp = st.split('-') - if len(sp) > 2: - sp = [sp[0], '-'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '-' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '-' + sp[1] - else: - return self.process_sym_name(sp[0]) + '-' + sp[1] - elif '*' in st: - sp = st.split('*') - if len(sp) > 2: - sp = [sp[0], '*'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0].isdigit(): - (sp[1], sp[0]) = (sp[0], sp[1]) - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '*' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '*' + sp[1] - else: - return 
self.process_sym_name(sp[0]) + '*' + sp[1] - elif '//' in st: - sp = st.strip('()').split('//') - if len(sp) > 2: - sp = [sp[0], '//'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0].isdigit(): - (sp[1], sp[0]) = (sp[0], sp[1]) - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '//' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '//' + sp[1] - else: - return self.process_sym_name(sp[0]) + '//' + sp[1] - else: - if st in self.sym_in_args: - arg, idx = self.sym_in_args[st] - return "{}.shape[{}]".format(arg, idx) - return self.sym_to_inputs[st] + # return string wrapper in new version + # node.str() will not fallback SymInt value form + if isinstance(st, torch.SymInt): + return st.node.str() + return str(st) def gen_call_func(self): # TODO check scalar input @@ -293,9 +239,20 @@ def gen_call_func(self): # dynamic shape feature if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + # import args needed for map assignment args = ['_' if not arg in shape_symint and not arg in self.sym_to_inputs.values() else arg for arg in self.args] call_body.writeline(f"({','.join(args)}) = args") + # assign SymInt to InputArgs relationship + if len(self.sym_in_args) > 0: + for key in self.sym_in_args.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]") + if len(self.sym_to_inputs) > 0: + for key in self.sym_to_inputs.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_to_inputs[key]}") + # generate input dims if len(self.dynamic_inputs) > 0: dim_len = 0 @@ -328,7 +285,7 @@ def gen_call_func(self): shape = list(elem.shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(str(dim)) for dim in shape] + shape = [self.process_sym_name(dim) for dim in shape] shape_str += "[" + ','.join(map(str, shape)) + "]," # process output_shape with modified args @@ -336,12 +293,12 @@ def gen_call_func(self): shape = list(self.input_args[elem[1]].meta['val'].shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(str(dim)) for dim in shape] + shape = [self.process_sym_name(dim) for dim in shape] shape_str += "[" + ','.join(map(str, shape)) + "]," stride = list(self.input_args[elem[1]].meta['val'].stride()) if len(stride) == 0: raise RuntimeError("Error handling empty output_stride") - stride = [self.process_sym_name(str(dim)) for dim in stride] + stride = [self.process_sym_name(dim) for dim in stride] extra_stride_str += '[' + ','.join(map(str, stride)) + '],' extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' shape_str = shape_str[:-1] + f''']''' @@ -364,7 +321,7 @@ def gen_call_func(self): out_storage_offsets.append('0') continue stride = list(elem.stride()) - stride = [self.process_sym_name(str(dim)) for dim in stride] + stride = [self.process_sym_name(dim) for dim in stride] out_strides.append(str(stride)) out_storage_offsets.append(elem.storage_offset()) call_body.writeline(f'out_stride = {out_strides}') From 2666a3f7a9b4140f5de6e3fe2be94c36289bfd1f Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 25 Apr 2024 08:07:10 +0000 Subject: [PATCH 07/42] Refine variable replacement in in/out shape structure. 
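process_sym_name no longer parses '+', '-', '*' and '//' expressions by hand: a torch.SymInt is returned as node.str(), and gen_call_func emits one assignment per free symbol so the emitted shape strings evaluate as plain Python. A minimal standalone sketch of that idea follows; the names s0, s1, arg0 and arg2 are made up for illustration and are not taken from the patch.

    # illustration only: bind free symbols before emitting symbolic shape strings
    sym_in_args = {"s0": ("arg0", 1)}    # pretend s0 is arg0.shape[1]
    sym_to_inputs = {"s1": "arg2"}       # pretend s1 arrives directly as a graph input
    lines = []
    for key, (arg, idx) in sym_in_args.items():
        lines.append(f"{key} = {arg}.shape[{idx}]")   # -> "s0 = arg0.shape[1]"
    for key, val in sym_to_inputs.items():
        lines.append(f"{key} = {val}")                # -> "s1 = arg2"
    # once these bindings are emitted, a shape string such as "[s0*4 + 1, 128]"
    # is valid Python in the generated call body and needs no further parsing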
--- .../dicp/vendor/AscendGraph/codegen/ascend.py | 80 ++++++------------- 1 file changed, 26 insertions(+), 54 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 9b9fc24f4..46ea900e0 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -217,58 +217,19 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2): ) return self.import_code.getvalue() + def operator_in_str(self, st): + for op in ['+', '-', '*', '/']: + if op in st: + return True + return False + def process_sym_name(self, st): # dynamic shape feature - if st.isdigit(): - return st - elif '+' in st: - sp = st.split('+') - if len(sp) > 2: - sp = [sp[0], '+'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0].isdigit(): - (sp[1], sp[0]) = (sp[0], sp[1]) - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '+' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '+' + sp[1] - else: - return self.process_sym_name(sp[0]) + '+' + sp[1] - elif '-' in st: - sp = st.split('-') - if len(sp) > 2: - sp = [sp[0], '-'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '-' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '-' + sp[1] - else: - return self.process_sym_name(sp[0]) + '-' + sp[1] - elif '*' in st: - sp = st.split('*') - if len(sp) > 2: - sp = [sp[0], '*'.join(sp[1:])] - assert (len(sp) == 2) - sp = [elem.strip() for elem in sp] - if sp[0].isdigit(): - (sp[1], sp[0]) = (sp[0], sp[1]) - if sp[0] in self.sym_in_args: - arg, idx = self.sym_in_args[sp[0]] - return "{}.shape[{}]".format(arg, idx) + '*' + sp[1] - if sp[0] in self.sym_to_inputs.keys(): - return self.sym_to_inputs[sp[0]] + '*' + sp[1] - else: - return self.process_sym_name(sp[0]) + '*' + sp[1] - else: - if st in self.sym_in_args: - arg, idx = self.sym_in_args[st] - return "{}.shape[{}]".format(arg, idx) - return self.sym_to_inputs[st] + # return string wrapper in new version + # node.str() will not fallback SymInt value form + if isinstance(st, torch.SymInt): + return st.node.str() + return str(st) def gen_call_func(self): # TODO check scalar input @@ -278,9 +239,20 @@ def gen_call_func(self): # dynamic shape feature if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + # import args needed for map assignment args = ['_' if not arg in shape_symint and not arg in self.sym_to_inputs.values() else arg for arg in self.args] call_body.writeline(f"({','.join(args)}) = args") + # assign SymInt to InputArgs relationship + if len(self.sym_in_args) > 0: + for key in self.sym_in_args.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]") + if len(self.sym_to_inputs) > 0: + for key in self.sym_to_inputs.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_to_inputs[key]}") + # generate input dims if len(self.dynamic_inputs) > 0: dim_len = 0 @@ -313,7 +285,7 @@ def gen_call_func(self): shape = list(elem.shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(str(dim)) for dim in shape] + shape = [self.process_sym_name(dim) for dim in shape] 
shape_str += "[" + ','.join(map(str, shape)) + "]," # process output_shape with modified args @@ -321,12 +293,12 @@ def gen_call_func(self): shape = list(self.input_args[elem[1]].meta['val'].shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(str(dim)) for dim in shape] + shape = [self.process_sym_name(dim) for dim in shape] shape_str += "[" + ','.join(map(str, shape)) + "]," stride = list(self.input_args[elem[1]].meta['val'].stride()) if len(stride) == 0: raise RuntimeError("Error handling empty output_stride") - stride = [self.process_sym_name(str(dim)) for dim in stride] + stride = [self.process_sym_name(dim) for dim in stride] extra_stride_str += '[' + ','.join(map(str, stride)) + '],' extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' shape_str = shape_str[:-1] + f''']''' @@ -349,7 +321,7 @@ def gen_call_func(self): out_storage_offsets.append('0') continue stride = list(elem.stride()) - stride = [self.process_sym_name(str(dim)) for dim in stride] + stride = [self.process_sym_name(dim) for dim in stride] out_strides.append(str(stride)) out_storage_offsets.append(elem.storage_offset()) call_body.writeline(f'out_stride = {out_strides}') From 0389a642d177608134a0481fa444f0a85183dccc Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 25 Apr 2024 08:52:23 +0000 Subject: [PATCH 08/42] Fix merge bugs. --- dicp/dicp/dynamo_bridge/utils.py | 12 ++++--- .../dicp/vendor/AscendGraph/codegen/ascend.py | 20 ++++------- dicp/dicp/vendor/AscendGraph/conversion.py | 33 +++++++++++++++++-- 3 files changed, 43 insertions(+), 22 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py index c64d61a9f..c241227c0 100644 --- a/dicp/dicp/dynamo_bridge/utils.py +++ b/dicp/dicp/dynamo_bridge/utils.py @@ -21,11 +21,13 @@ def symint_in_shape(shape): return False -def not_all_num_shape(shape): - for elem in shape: - if not isinstance(elem, int): - return True - return False +def process_sym_name(st): + # dynamic shape feature + # return string wrapper in new version + # node.str() will not fallback SymInt value form + if isinstance(st, torch.SymInt): + return st.node.str() + return str(st) def save_cpu_gm(gm: torch.fx.GraphModule, folder: str): diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 46ea900e0..aa4dd5640 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -5,7 +5,7 @@ from typing import Any, List from torch.fx.node import Node from torch._inductor.utils import IndentedBuffer -from dicp.dynamo_bridge.utils import symint_in_shape +from dicp.dynamo_bridge.utils import symint_in_shape, process_sym_name from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype, get_cpp_dtype, @@ -223,14 +223,6 @@ def operator_in_str(self, st): return True return False - def process_sym_name(self, st): - # dynamic shape feature - # return string wrapper in new version - # node.str() will not fallback SymInt value form - if isinstance(st, torch.SymInt): - return st.node.str() - return str(st) - def gen_call_func(self): # TODO check scalar input call_body = IndentedBuffer() @@ -262,7 +254,7 @@ def gen_call_func(self): for idx, elem in enumerate(self.actual_shape): if len(elem) == 0: continue - elem = [self.process_sym_name(dim) for dim in elem] + elem = [process_sym_name(dim) for dim in elem] dims += str(self.dynamic_index[idx]) + \ ":[" + 
','.join(map(str, elem)) + '],' dims = dims[:-1] + '}' @@ -285,7 +277,7 @@ def gen_call_func(self): shape = list(elem.shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(dim) for dim in shape] + shape = [process_sym_name(dim) for dim in shape] shape_str += "[" + ','.join(map(str, shape)) + "]," # process output_shape with modified args @@ -293,12 +285,12 @@ def gen_call_func(self): shape = list(self.input_args[elem[1]].meta['val'].shape) if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - shape = [self.process_sym_name(dim) for dim in shape] + shape = [process_sym_name(dim) for dim in shape] shape_str += "[" + ','.join(map(str, shape)) + "]," stride = list(self.input_args[elem[1]].meta['val'].stride()) if len(stride) == 0: raise RuntimeError("Error handling empty output_stride") - stride = [self.process_sym_name(dim) for dim in stride] + stride = [process_sym_name(dim) for dim in stride] extra_stride_str += '[' + ','.join(map(str, stride)) + '],' extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' shape_str = shape_str[:-1] + f''']''' @@ -321,7 +313,7 @@ def gen_call_func(self): out_storage_offsets.append('0') continue stride = list(elem.stride()) - stride = [self.process_sym_name(dim) for dim in stride] + stride = [process_sym_name(dim) for dim in stride] out_strides.append(str(stride)) out_storage_offsets.append(elem.storage_offset()) call_body.writeline(f'out_stride = {out_strides}') diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 3f7267df0..c9406dd88 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -15,7 +15,7 @@ from torch.fx.immutable_collections import immutable_list from torch._subclasses import FakeTensor import dicp.vendor.AscendGraph.ascend_op as ascend_op -from dicp.dynamo_bridge.utils import symint_in_shape, not_all_num_shape +from dicp.dynamo_bridge.utils import symint_in_shape, not_all_num_shape, process_sym_name from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype, get_cpp_dtype @@ -109,6 +109,22 @@ def replace_elem_proxy(elem_str): ascend_op.Const, ([int(elem_str)], torch.int32, [1])) return const_op + # handle with NodeProxy string + if 'Proxy' in elem_str: + # recover '()' from '[]' + elem_str = elem_str.replace('[', '(') + elem_str = elem_str.replace(']', ')') + + # search & replace + replace_proxy = None + arg_symint_candidate = [value[0] for value in self.sym_in_args.values()] + list(self.sym_to_inputs.values()) + for convert in arg_symint_candidate: + if elem_str == str(convert): + replace_proxy = convert + break + assert replace_proxy is not None + return replace_proxy + # handle if elem in shape of InputArgs if elem_str in self.sym_in_args: arg, idx = self.sym_in_args[elem_str] @@ -125,12 +141,23 @@ def replace_elem_proxy(elem_str): return self.sym_to_inputs[elem_str] def generate_not_num(elem): + # dynamic shape feature # situation for NodeProxy if isinstance(elem, torch.fx.proxy.Proxy): x_names.append(elem) return - elem_str = elem.node.str() + # string form of NodeProxy, convert it + if isinstance(elem, str) and 'Proxy' in elem: + # special case handling '()' in NodeProxy string + # '[]' will not mixed with expression calculation priority + elem = elem.replace('(', '[') + elem = elem.replace(')', ']') + elif not isinstance(elem, torch.SymInt): + raise RuntimeError("Not num objects only include SymInt 
or NodeProxy!") + + # case for NodeProxy string or SymInt + elem_str = process_sym_name(elem) elem_str = elem_str.replace('+', ' + ') elem_str = elem_str.replace('-', ' - ') elem_str = elem_str.replace('*', ' * ') @@ -140,7 +167,7 @@ def generate_not_num(elem): elems = elem_str.split(' ') elems = [e for e in elems if e != ''] - # dynamic shape feature + # prepare for expression calculation if len(elems) > 1: set_num = [] priority = [] From 5a2fd6ab1f233b8dac90e82290666ef670da87a6 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 25 Apr 2024 09:14:32 +0000 Subject: [PATCH 09/42] Change one comment and variable name. --- dicp/dicp/vendor/AscendGraph/conversion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 74395e192..cd5ba91c5 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -117,9 +117,9 @@ def replace_elem_proxy(elem_str): # search & replace replace_proxy = None arg_symint_candidate = [value[0] for value in self.sym_in_args.values()] + list(self.sym_to_inputs.values()) - for convert in arg_symint_candidate: - if elem_str == str(convert): - replace_proxy = convert + for arg_sym in arg_symint_candidate: + if elem_str == str(arg_sym): + replace_proxy = arg_sym break assert replace_proxy is not None return replace_proxy @@ -146,7 +146,7 @@ def generate_not_num(elem): x_names.append(elem) return - # string form of NodeProxy, convert it + # string form of NodeProxy if isinstance(elem, str) and 'Proxy' in elem: # special case handling '()' in NodeProxy string # '[]' will not mixed with expression calculation priority From 03ba1a4ed4aa1045c3cff95273b70e981ef3e7f5 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 25 Apr 2024 10:47:21 +0000 Subject: [PATCH 10/42] Fix an array assignment change. --- dicp/dicp/vendor/AscendGraph/conversion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index cd5ba91c5..e0d8eadc6 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -585,8 +585,7 @@ def view(self, x, size): shape = list(result_val.shape) if x.node.meta["val"].dtype == torch.complex64: shape.append(1) - size_tmp = [s for s in size] - size_tmp.append(1) + size_tmp = [s for s in size] + [1] size = immutable_list(size_tmp) numel = result_val.numel() neg = False From 71f6c612c290730b24f74480c2cf9ab74dc048f3 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 26 Apr 2024 03:33:53 +0000 Subject: [PATCH 11/42] Code refinement including: 1.Remove redundant Cast operator. 2.Change logic of Expand shape proxy. 3.Merge output stride executing path. 
--- .../dicp/vendor/AscendGraph/codegen/ascend.py | 12 +++++----- dicp/dicp/vendor/AscendGraph/conversion.py | 22 +++++++++++++------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 3e77397ad..5c176e99e 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -305,11 +305,9 @@ def gen_call_func(self): for elem in self.output_args: if hasattr(elem, 'meta'): elem = elem.meta['val'] - if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): - out_strides.append('[1]') - out_storage_offsets.append('0') - continue - if elem.dim() == 0: # temporary solution for sum.default(a) whose result is a scalar(no dim no stride) + + # temporary solution for sum.default(a) whose result is scalar or with no dim no stride + if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool) or elem.dim() == 0: out_strides.append('[1]') out_storage_offsets.append('0') continue @@ -320,6 +318,10 @@ def gen_call_func(self): call_body.writeline(f'out_stride = {out_strides}') call_body.writeline(f'out_storage_offset = {out_storage_offsets}') + # In precision debug mode, modified array recording InputArgs integer needed + if precision_check and self.aten_graph is not None: + call_body.writeline(f"modified = [idx for idx in range(len(args))] if isinstance(args[idx], int)") + call_body.splice(""" import torch_dipu dipu_device_str = torch_dipu.dipu.device.__diputype__ diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index e0d8eadc6..8b11aab14 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -716,7 +716,8 @@ def nll_loss_forward(self, x, target, weight, reduction, ignore_index): assert ignore_index == -100 reduction_str = get_reduction_str(reduction) csize = [list(x.node.meta['val'].shape)[1]] - target = self.get_proxy(ascend_op.Cast, (target, "INT32")) + if target.node.meta['val'].dtype != torch.int32: + target = self.get_proxy(ascend_op.Cast, (target, "INT32")) weight = self.get_proxy(ascend_op.FillV2D, (1.0, csize)) return self.get_proxy(ascend_op.NLLLoss, (x, target, weight, reduction_str, ignore_index)) @@ -726,7 +727,8 @@ def nll_loss_backward(self, grad_output, x, target, weight, reduction, ignore_in assert ignore_index == -100 reduction_str = get_reduction_str(reduction) csize = [list(x.node.meta['val'].shape)[1]] - target = self.get_proxy(ascend_op.Cast, (target, "INT32")) + if target.node.meta['val'].dtype != torch.int32: + target = self.get_proxy(ascend_op.Cast, (target, "INT32")) weight = self.get_proxy(ascend_op.FillV2D, (1.0, csize)) return self.get_proxy(ascend_op.NLLLossGrad, (x, grad_output, target, weight, total_weight, @@ -1219,11 +1221,12 @@ def expand(self, x, shape): return self.get_proxy(ascend_op.Identity, (x, None)) if x.node.meta['val'].dtype == torch.int64: x = self.get_proxy(ascend_op.Cast, (x, "INT32")) - if isinstance(shape, list) and not_all_num_shape(shape): - preprocess_shape = self.process_dynamic_shape(shape) - return self.get_proxy(ascend_op.Expand, (x, preprocess_shape)) else: - return self.get_proxy(ascend_op.ExpandD, (x, shape)) + # check situation other than integer type + assert x.node.meta['val'].dtype == torch.int32 + + shape = self.get_shape_proxy(shape) + return self.get_proxy(ascend_op.Expand, (x, shape)) @register_conversion(torch.ops.aten.slice_backward.default) def 
slice_backward(self, grad, input_shape, dim, start, end, step): @@ -1349,12 +1352,16 @@ def mm(self, x, y): y = self.get_proxy_from_node(y.node.args[0]) trans_y = True mm = self.get_proxy(ascend_op.MatMul, (x, y, trans_x, trans_y)) + + # TODO! complicated logic in MatMul output dtype return self.get_proxy(ascend_op.Cast, (mm, get_ascend_dtype(out_dtype))) @register_conversion(aten.bmm.default) def bmm(self, x, y): out_dtype = fx_traceback.get_current_meta()['val'].dtype bmm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False, sd_fp16 ^ 1)) + + # TODO! complicated logic in BatchMatMul output dtype return self.get_proxy(ascend_op.Cast, (bmm, get_ascend_dtype(out_dtype))) @register_conversion(torch.torch.ops.aten.addmm) @@ -1565,7 +1572,8 @@ def RandLike(self, x, dtype=torch.float32, layout=torch.strided, def GtScalar(self, x, y): dtype = get_ascend_dtype(x.node.meta['val'].dtype) scalar_op = self.get_const_proxy(float(y), torch.float32) - cast_op = self.get_proxy(ascend_op.Cast, (scalar_op, dtype)) + if dtype != torch.float32: + cast_op = self.get_proxy(ascend_op.Cast, (scalar_op, dtype)) return self.get_proxy(ascend_op.Greater, (x, cast_op)) @register_conversion(torch.ops.aten.addcmul.default) From 25c5c5694d99448cb498726f4ef093bb08ef41ee Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 26 Apr 2024 03:47:17 +0000 Subject: [PATCH 12/42] Get clear idea for expand Cast situation. --- dicp/dicp/vendor/AscendGraph/conversion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 8b11aab14..b941b267d 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1219,12 +1219,10 @@ def expand(self, x, shape): y_shape = list(fx_traceback.get_current_meta()['val'].shape) if x_shape == y_shape: return self.get_proxy(ascend_op.Identity, (x, None)) + + # Cast needed only when x_dtype is int64 if x.node.meta['val'].dtype == torch.int64: x = self.get_proxy(ascend_op.Cast, (x, "INT32")) - else: - # check situation other than integer type - assert x.node.meta['val'].dtype == torch.int32 - shape = self.get_shape_proxy(shape) return self.get_proxy(ascend_op.Expand, (x, shape)) From 90250199087be2f1116145cf1b96e25a1d06b996 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 26 Apr 2024 10:48:26 +0000 Subject: [PATCH 13/42] Apply some idea from Gpt AI. 
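gen_call_func is split into nested helpers for output shapes and for strides/offsets, and the hand-rolled string building for the input dims is replaced with a dict comprehension. The snippet below is an illustrative, standalone rendering of that rewrite; dynamic_index and actual_shape hold made-up values.

    # illustration only: build the generated "dims = {...}" line from per-input dims
    dynamic_index = [0, 2]
    actual_shape = [["s0", "4"], []]     # the empty entry is skipped by the filter
    dims = {dynamic_index[idx]: elem
            for idx, elem in enumerate(actual_shape) if len(elem) > 0}
    line = f"dims = {dims}".replace("'", "")   # -> "dims = {0: [s0, 4]}"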
--- .../dicp/vendor/AscendGraph/codegen/ascend.py | 166 +++++++++--------- dicp/dicp/vendor/AscendGraph/conversion.py | 88 +++++----- 2 files changed, 127 insertions(+), 127 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 5c176e99e..33f276a8a 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -220,108 +220,104 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2): return self.import_code.getvalue() def operator_in_str(self, st): - for op in ['+', '-', '*', '/']: - if op in st: - return True - return False + return any(op in st for op in ['+', '-', '*', '/']) def gen_call_func(self): + def _generate_output_shapes(call_body): + # dynamic shape feature + # generate output shapes + extra_stride_str = '' + extra_storage_offset_str = '' + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + # process if has dynamic shape + shape_str = '''output_shape = [''' + for elem in self.output_args: + if hasattr(elem, 'meta'): + elem = elem.meta['val'] + if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): + shape_str += '[1],' + continue + shape = list(elem.shape) + if len(shape) == 0: + raise RuntimeError("Error handling empty output_shape") + shape = [process_sym_name(dim) for dim in shape] + shape_str += "[" + ','.join(map(str, shape)) + "]," + + # process output_shape with modified args + for elem in self.assign_args: + shape = list(self.input_args[elem[1]].meta['val'].shape) + if len(shape) == 0: + raise RuntimeError("Error handling empty output_shape") + shape = [process_sym_name(dim) for dim in shape] + shape_str += "[" + ','.join(map(str, shape)) + "]," + stride = list(self.input_args[elem[1]].meta['val'].stride()) + if len(stride) == 0: + raise RuntimeError("Error handling empty output_stride") + stride = [process_sym_name(dim) for dim in stride] + extra_stride_str += '[' + ','.join(map(str, stride)) + '],' + extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' + shape_str = shape_str[:-1] + ''']''' + call_body.writeline(shape_str) + else: + call_body.writeline('''output_shape = None''') + + def _handle_output_strides_and_offsets(call_body): + # add stride & storage_offset info + out_strides = [] + out_storage_offsets = [] + for elem in self.output_args: + if hasattr(elem, 'meta'): + elem = elem.meta['val'] + + # temporary solution for sum.default(a) whose result is scalar or with no dim no stride + if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool) or elem.dim() == 0: + out_strides.append('[1]') + out_storage_offsets.append('0') + continue + stride = list(elem.stride()) + stride = [process_sym_name(dim) for dim in stride] + out_strides.append(str(stride)) + out_storage_offsets.append(elem.storage_offset()) + call_body.writeline(f'out_stride = {out_strides}') + call_body.writeline(f'out_storage_offset = {out_storage_offsets}') + # TODO check scalar input call_body = IndentedBuffer() self.args = [self.args_dict[x.name] for x in self.input_args] shape_symint = [value[0] for value in self.sym_in_args.values()] # dynamic shape feature - if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args] - call_body.writeline(f"({','.join(args)}) = args") - - # assign SymInt to InputArgs relationship - if len(self.sym_in_args) > 0: - for key in self.sym_in_args.keys(): - if 
not key.isdigit() and not self.operator_in_str(key): - call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]") - if len(self.sym_to_inputs) > 0: - for key in self.sym_to_inputs.keys(): - if not key.isdigit() and not self.operator_in_str(key): - call_body.writeline(f"{key} = {self.sym_to_inputs[key]}") + # assign SymInt to InputArgs relationship + args_condition = lambda arg: '_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg + args = map(args_condition, self.args) + call_body.writeline(f"({','.join(args)}) = args") + + # combine dicts for simplicity + # cover sym_to_inputs for higher priority to sym_in_args + sym_keys = {**self.sym_to_inputs, **self.sym_in_args} + for key, val in sym_keys.items(): + # skip if key is a digit or contains an operator + if key.isdigit() or self.operator_in_str(key): + continue + shape = f"{val[0]}.shape[{val[1]}]" if key in self.sym_in_args.keys() else val + call_body.writeline(f"{key} = {shape}") # generate input dims - if len(self.dynamic_inputs) > 0: - dim_len = 0 - for shape in self.actual_shape: - dim_len += len(shape) - dims = 'dims = {' - for idx, elem in enumerate(self.actual_shape): - if len(elem) == 0: - continue - elem = [process_sym_name(dim) for dim in elem] - dims += str(self.dynamic_index[idx]) + \ - ":[" + ','.join(map(str, elem)) + '],' - dims = dims[:-1] + '}' - call_body.writeline(dims) + if self.dynamic_inputs: + dims = {self.dynamic_index[idx]: [process_sym_name(dim) for dim in elem] + for idx, elem in enumerate(self.actual_shape) if len(elem) > 0} + call_body.writeline(f"dims = {dims}".replace("'","")) else: - call_body.writeline('''dims = None''') + call_body.writeline("dims = None") - # generate output shapes - # dynamic shape feature - extra_stride_str = '' - extra_storage_offset_str = '' - if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - shape_str = '''output_shape = [''' - for elem in self.output_args: - if hasattr(elem, 'meta'): - elem = elem.meta['val'] - if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): - shape_str += '[1],' - continue - shape = list(elem.shape) - if len(shape) == 0: - raise RuntimeError("Error handling empty output_shape") - shape = [process_sym_name(dim) for dim in shape] - shape_str += "[" + ','.join(map(str, shape)) + "]," - - # process output_shape with modified args - for elem in self.assign_args: - shape = list(self.input_args[elem[1]].meta['val'].shape) - if len(shape) == 0: - raise RuntimeError("Error handling empty output_shape") - shape = [process_sym_name(dim) for dim in shape] - shape_str += "[" + ','.join(map(str, shape)) + "]," - stride = list(self.input_args[elem[1]].meta['val'].stride()) - if len(stride) == 0: - raise RuntimeError("Error handling empty output_stride") - stride = [process_sym_name(dim) for dim in stride] - extra_stride_str += '[' + ','.join(map(str, stride)) + '],' - extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' - shape_str = shape_str[:-1] + ''']''' - call_body.writeline(shape_str) - else: - call_body.writeline('''output_shape = None''') - - # add stride & storage_offset info - out_strides = [] - out_storage_offsets = [] - for elem in self.output_args: - if hasattr(elem, 'meta'): - elem = elem.meta['val'] - - # temporary solution for sum.default(a) whose result is scalar or with no dim no stride - if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool) or elem.dim() == 0: - out_strides.append('[1]') - 
out_storage_offsets.append('0') - continue - stride = list(elem.stride()) - stride = [process_sym_name(dim) for dim in stride] - out_strides.append(str(stride)) - out_storage_offsets.append(elem.storage_offset()) - call_body.writeline(f'out_stride = {out_strides}') - call_body.writeline(f'out_storage_offset = {out_storage_offsets}') + _generate_output_shapes(call_body) + _handle_output_strides_and_offsets(call_body) # In precision debug mode, modified array recording InputArgs integer needed if precision_check and self.aten_graph is not None: call_body.writeline(f"modified = [idx for idx in range(len(args))] if isinstance(args[idx], int)") + # Conversion for integer args call_body.splice(""" import torch_dipu dipu_device_str = torch_dipu.dipu.device.__diputype__ diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index b941b267d..ce4e82f9d 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -75,8 +75,7 @@ def process_dynamic_shape(self, shape): x_names = [] def generate_digits_op(shapes): - const_op = self.get_proxy( - ascend_op.Const, (shapes, torch.int32, [len(shapes)])) + const_op = self.get_const_proxy(shapes, torch.int32) x_names.append(const_op) def find_root_num(set_num, num): @@ -104,8 +103,7 @@ def replace_elem_proxy(elem_str): # handle with integer if elem_str.isdigit(): - const_op = self.get_proxy( - ascend_op.Const, ([int(elem_str)], torch.int32, [1])) + const_op = self.get_const_proxy(int(elem_str), torch.int32) return const_op # handle with NodeProxy string @@ -128,10 +126,8 @@ def replace_elem_proxy(elem_str): if elem_str in self.sym_in_args: arg, idx = self.sym_in_args[elem_str] shape = self.get_proxy(ascend_op.Shape, (arg,)) - axis = self.get_proxy( - ascend_op.Const, ([0], torch.int32, [1])) - indice = self.get_proxy( - ascend_op.Const, ([idx], torch.int32, [1])) + axis = self.get_const_proxy(0, torch.int32) + indice = self.get_const_proxy(idx, torch.int32) gather = self.get_proxy( ascend_op.GatherV2, (shape, indice, axis)) return gather @@ -140,34 +136,31 @@ def replace_elem_proxy(elem_str): return self.sym_to_inputs[elem_str] def generate_not_num(elem): - # dynamic shape feature - # situation for NodeProxy - if isinstance(elem, torch.fx.proxy.Proxy): - x_names.append(elem) - return - - # string form of NodeProxy - if isinstance(elem, str) and 'Proxy' in elem: - # special case handling '()' in NodeProxy string - # '[]' will not mixed with expression calculation priority - elem = elem.replace('(', '[') - elem = elem.replace(')', ']') - elif not isinstance(elem, torch.SymInt): - raise RuntimeError("Not num objects only include SymInt or NodeProxy!") - - # case for NodeProxy string or SymInt - elem_str = process_sym_name(elem) - elem_str = elem_str.replace('+', ' + ') - elem_str = elem_str.replace('-', ' - ') - elem_str = elem_str.replace('*', ' * ') - elem_str = elem_str.replace('//', ' // ') - elem_str = elem_str.replace('(', ' ( ') - elem_str = elem_str.replace(')', ' ) ') - elems = elem_str.split(' ') - elems = [e for e in elems if e != ''] - - # prepare for expression calculation - if len(elems) > 1: + def _init_stage(elem): + # string form of NodeProxy + if isinstance(elem, str) and 'Proxy' in elem: + # special case handling '()' in NodeProxy string + # '[]' will not mixed with expression calculation priority + elem = elem.replace('(', '[') + elem = elem.replace(')', ']') + elif not isinstance(elem, torch.SymInt): + raise RuntimeError("Not num objects only include 
SymInt or NodeProxy!") + + # case for NodeProxy string or SymInt + replacements = { + '+': ' + ', + '-': ' - ', + '*': ' * ', + '//': ' // ', + '(': ' ( ', + ')': ' ) ' + } + elem_str = ''.join(replacements.get(c, c) for c in elem) + elems = [e for e in elem_str.split(' ') if e != ''] + return elems + + def _prepare_for_calc(elems): + # prepare for expression calculation set_num = [] priority = [] nest = 0 @@ -191,7 +184,9 @@ def generate_not_num(elem): set_num.append(idx) else: set_num.append(-1) + return set_num, priority + def _merge_disjoint_set_stage(elems, set_num, priority): # start merge disjoint-set if len(set_num) > 1: while len(set(set_num)) > 2: @@ -229,17 +224,27 @@ def generate_not_num(elem): # merge set number and priority set_num = merge_disjoint_set(set_num, left_idx, right_idx) priority[m_idx] = -1 + return elems + def _final_stage(elems): # add final element proxy final_idx = 0 while final_idx < len(elems) - 1 and str(elems[final_idx]) in ['(', ')']: final_idx += 1 final_elem = replace_elem_proxy(elems[final_idx]) x_names.append(final_elem) - else: - # only one not num element - node = replace_elem_proxy(elems[0]) - x_names.append(node) + + # dynamic shape feature + # situation for NodeProxy + if isinstance(elem, torch.fx.proxy.Proxy): + x_names.append(elem) + return + + # four stage splits + elems = _init_stage(elem) + set_num, priority = _prepare_for_calc(elems) + elems = _merge_disjoint_set_stage(elems, set_num, priority) + _final_stage(elems) dims = [] for elem in shape: @@ -293,8 +298,7 @@ def symint_to_str(shape): # both fit for SymInt & NodeProxy, pass all number cases if not_all_num_shape(shape): return self.process_dynamic_shape(shape) - return self.get_proxy( - ascend_op.Const, (shape, dtype, [len(shape)])) + return self.get_const_proxy(shape, dtype) def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): From 62a6b36b250e18fb83072de74d7583fc6b2c98ff Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 26 Apr 2024 11:07:27 +0000 Subject: [PATCH 14/42] Revert "Apply some idea from Gpt AI." This reverts commit 90250199087be2f1116145cf1b96e25a1d06b996. 
--- .../dicp/vendor/AscendGraph/codegen/ascend.py | 166 +++++++++--------- dicp/dicp/vendor/AscendGraph/conversion.py | 88 +++++----- 2 files changed, 127 insertions(+), 127 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 33f276a8a..5c176e99e 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -220,104 +220,108 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2): return self.import_code.getvalue() def operator_in_str(self, st): - return any(op in st for op in ['+', '-', '*', '/']) + for op in ['+', '-', '*', '/']: + if op in st: + return True + return False def gen_call_func(self): - def _generate_output_shapes(call_body): - # dynamic shape feature - # generate output shapes - extra_stride_str = '' - extra_storage_offset_str = '' - if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - # process if has dynamic shape - shape_str = '''output_shape = [''' - for elem in self.output_args: - if hasattr(elem, 'meta'): - elem = elem.meta['val'] - if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): - shape_str += '[1],' - continue - shape = list(elem.shape) - if len(shape) == 0: - raise RuntimeError("Error handling empty output_shape") - shape = [process_sym_name(dim) for dim in shape] - shape_str += "[" + ','.join(map(str, shape)) + "]," - - # process output_shape with modified args - for elem in self.assign_args: - shape = list(self.input_args[elem[1]].meta['val'].shape) - if len(shape) == 0: - raise RuntimeError("Error handling empty output_shape") - shape = [process_sym_name(dim) for dim in shape] - shape_str += "[" + ','.join(map(str, shape)) + "]," - stride = list(self.input_args[elem[1]].meta['val'].stride()) - if len(stride) == 0: - raise RuntimeError("Error handling empty output_stride") - stride = [process_sym_name(dim) for dim in stride] - extra_stride_str += '[' + ','.join(map(str, stride)) + '],' - extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' - shape_str = shape_str[:-1] + ''']''' - call_body.writeline(shape_str) - else: - call_body.writeline('''output_shape = None''') - - def _handle_output_strides_and_offsets(call_body): - # add stride & storage_offset info - out_strides = [] - out_storage_offsets = [] - for elem in self.output_args: - if hasattr(elem, 'meta'): - elem = elem.meta['val'] - - # temporary solution for sum.default(a) whose result is scalar or with no dim no stride - if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool) or elem.dim() == 0: - out_strides.append('[1]') - out_storage_offsets.append('0') - continue - stride = list(elem.stride()) - stride = [process_sym_name(dim) for dim in stride] - out_strides.append(str(stride)) - out_storage_offsets.append(elem.storage_offset()) - call_body.writeline(f'out_stride = {out_strides}') - call_body.writeline(f'out_storage_offset = {out_storage_offsets}') - # TODO check scalar input call_body = IndentedBuffer() self.args = [self.args_dict[x.name] for x in self.input_args] shape_symint = [value[0] for value in self.sym_in_args.values()] # dynamic shape feature - # assign SymInt to InputArgs relationship - args_condition = lambda arg: '_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg - args = map(args_condition, self.args) - call_body.writeline(f"({','.join(args)}) = args") - - # combine dicts for simplicity - # cover sym_to_inputs for higher priority to sym_in_args - 
sym_keys = {**self.sym_to_inputs, **self.sym_in_args} - for key, val in sym_keys.items(): - # skip if key is a digit or contains an operator - if key.isdigit() or self.operator_in_str(key): - continue - shape = f"{val[0]}.shape[{val[1]}]" if key in self.sym_in_args.keys() else val - call_body.writeline(f"{key} = {shape}") + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args] + call_body.writeline(f"({','.join(args)}) = args") + + # assign SymInt to InputArgs relationship + if len(self.sym_in_args) > 0: + for key in self.sym_in_args.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]") + if len(self.sym_to_inputs) > 0: + for key in self.sym_to_inputs.keys(): + if not key.isdigit() and not self.operator_in_str(key): + call_body.writeline(f"{key} = {self.sym_to_inputs[key]}") # generate input dims - if self.dynamic_inputs: - dims = {self.dynamic_index[idx]: [process_sym_name(dim) for dim in elem] - for idx, elem in enumerate(self.actual_shape) if len(elem) > 0} - call_body.writeline(f"dims = {dims}".replace("'","")) + if len(self.dynamic_inputs) > 0: + dim_len = 0 + for shape in self.actual_shape: + dim_len += len(shape) + dims = 'dims = {' + for idx, elem in enumerate(self.actual_shape): + if len(elem) == 0: + continue + elem = [process_sym_name(dim) for dim in elem] + dims += str(self.dynamic_index[idx]) + \ + ":[" + ','.join(map(str, elem)) + '],' + dims = dims[:-1] + '}' + call_body.writeline(dims) else: - call_body.writeline("dims = None") + call_body.writeline('''dims = None''') - _generate_output_shapes(call_body) - _handle_output_strides_and_offsets(call_body) + # generate output shapes + # dynamic shape feature + extra_stride_str = '' + extra_storage_offset_str = '' + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + shape_str = '''output_shape = [''' + for elem in self.output_args: + if hasattr(elem, 'meta'): + elem = elem.meta['val'] + if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): + shape_str += '[1],' + continue + shape = list(elem.shape) + if len(shape) == 0: + raise RuntimeError("Error handling empty output_shape") + shape = [process_sym_name(dim) for dim in shape] + shape_str += "[" + ','.join(map(str, shape)) + "]," + + # process output_shape with modified args + for elem in self.assign_args: + shape = list(self.input_args[elem[1]].meta['val'].shape) + if len(shape) == 0: + raise RuntimeError("Error handling empty output_shape") + shape = [process_sym_name(dim) for dim in shape] + shape_str += "[" + ','.join(map(str, shape)) + "]," + stride = list(self.input_args[elem[1]].meta['val'].stride()) + if len(stride) == 0: + raise RuntimeError("Error handling empty output_stride") + stride = [process_sym_name(dim) for dim in stride] + extra_stride_str += '[' + ','.join(map(str, stride)) + '],' + extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' + shape_str = shape_str[:-1] + ''']''' + call_body.writeline(shape_str) + else: + call_body.writeline('''output_shape = None''') + + # add stride & storage_offset info + out_strides = [] + out_storage_offsets = [] + for elem in self.output_args: + if hasattr(elem, 'meta'): + elem = elem.meta['val'] + + # temporary solution for sum.default(a) whose result is scalar or with no dim no stride + if isinstance(elem, torch.SymInt) or 
isinstance(elem, torch.SymBool) or elem.dim() == 0: + out_strides.append('[1]') + out_storage_offsets.append('0') + continue + stride = list(elem.stride()) + stride = [process_sym_name(dim) for dim in stride] + out_strides.append(str(stride)) + out_storage_offsets.append(elem.storage_offset()) + call_body.writeline(f'out_stride = {out_strides}') + call_body.writeline(f'out_storage_offset = {out_storage_offsets}') # In precision debug mode, modified array recording InputArgs integer needed if precision_check and self.aten_graph is not None: call_body.writeline(f"modified = [idx for idx in range(len(args))] if isinstance(args[idx], int)") - # Conversion for integer args call_body.splice(""" import torch_dipu dipu_device_str = torch_dipu.dipu.device.__diputype__ diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index ce4e82f9d..b941b267d 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -75,7 +75,8 @@ def process_dynamic_shape(self, shape): x_names = [] def generate_digits_op(shapes): - const_op = self.get_const_proxy(shapes, torch.int32) + const_op = self.get_proxy( + ascend_op.Const, (shapes, torch.int32, [len(shapes)])) x_names.append(const_op) def find_root_num(set_num, num): @@ -103,7 +104,8 @@ def replace_elem_proxy(elem_str): # handle with integer if elem_str.isdigit(): - const_op = self.get_const_proxy(int(elem_str), torch.int32) + const_op = self.get_proxy( + ascend_op.Const, ([int(elem_str)], torch.int32, [1])) return const_op # handle with NodeProxy string @@ -126,8 +128,10 @@ def replace_elem_proxy(elem_str): if elem_str in self.sym_in_args: arg, idx = self.sym_in_args[elem_str] shape = self.get_proxy(ascend_op.Shape, (arg,)) - axis = self.get_const_proxy(0, torch.int32) - indice = self.get_const_proxy(idx, torch.int32) + axis = self.get_proxy( + ascend_op.Const, ([0], torch.int32, [1])) + indice = self.get_proxy( + ascend_op.Const, ([idx], torch.int32, [1])) gather = self.get_proxy( ascend_op.GatherV2, (shape, indice, axis)) return gather @@ -136,31 +140,34 @@ def replace_elem_proxy(elem_str): return self.sym_to_inputs[elem_str] def generate_not_num(elem): - def _init_stage(elem): - # string form of NodeProxy - if isinstance(elem, str) and 'Proxy' in elem: - # special case handling '()' in NodeProxy string - # '[]' will not mixed with expression calculation priority - elem = elem.replace('(', '[') - elem = elem.replace(')', ']') - elif not isinstance(elem, torch.SymInt): - raise RuntimeError("Not num objects only include SymInt or NodeProxy!") - - # case for NodeProxy string or SymInt - replacements = { - '+': ' + ', - '-': ' - ', - '*': ' * ', - '//': ' // ', - '(': ' ( ', - ')': ' ) ' - } - elem_str = ''.join(replacements.get(c, c) for c in elem) - elems = [e for e in elem_str.split(' ') if e != ''] - return elems - - def _prepare_for_calc(elems): - # prepare for expression calculation + # dynamic shape feature + # situation for NodeProxy + if isinstance(elem, torch.fx.proxy.Proxy): + x_names.append(elem) + return + + # string form of NodeProxy + if isinstance(elem, str) and 'Proxy' in elem: + # special case handling '()' in NodeProxy string + # '[]' will not mixed with expression calculation priority + elem = elem.replace('(', '[') + elem = elem.replace(')', ']') + elif not isinstance(elem, torch.SymInt): + raise RuntimeError("Not num objects only include SymInt or NodeProxy!") + + # case for NodeProxy string or SymInt + elem_str = process_sym_name(elem) + elem_str = 
elem_str.replace('+', ' + ') + elem_str = elem_str.replace('-', ' - ') + elem_str = elem_str.replace('*', ' * ') + elem_str = elem_str.replace('//', ' // ') + elem_str = elem_str.replace('(', ' ( ') + elem_str = elem_str.replace(')', ' ) ') + elems = elem_str.split(' ') + elems = [e for e in elems if e != ''] + + # prepare for expression calculation + if len(elems) > 1: set_num = [] priority = [] nest = 0 @@ -184,9 +191,7 @@ def _prepare_for_calc(elems): set_num.append(idx) else: set_num.append(-1) - return set_num, priority - def _merge_disjoint_set_stage(elems, set_num, priority): # start merge disjoint-set if len(set_num) > 1: while len(set(set_num)) > 2: @@ -224,27 +229,17 @@ def _merge_disjoint_set_stage(elems, set_num, priority): # merge set number and priority set_num = merge_disjoint_set(set_num, left_idx, right_idx) priority[m_idx] = -1 - return elems - def _final_stage(elems): # add final element proxy final_idx = 0 while final_idx < len(elems) - 1 and str(elems[final_idx]) in ['(', ')']: final_idx += 1 final_elem = replace_elem_proxy(elems[final_idx]) x_names.append(final_elem) - - # dynamic shape feature - # situation for NodeProxy - if isinstance(elem, torch.fx.proxy.Proxy): - x_names.append(elem) - return - - # four stage splits - elems = _init_stage(elem) - set_num, priority = _prepare_for_calc(elems) - elems = _merge_disjoint_set_stage(elems, set_num, priority) - _final_stage(elems) + else: + # only one not num element + node = replace_elem_proxy(elems[0]) + x_names.append(node) dims = [] for elem in shape: @@ -298,7 +293,8 @@ def symint_to_str(shape): # both fit for SymInt & NodeProxy, pass all number cases if not_all_num_shape(shape): return self.process_dynamic_shape(shape) - return self.get_const_proxy(shape, dtype) + return self.get_proxy( + ascend_op.Const, (shape, dtype, [len(shape)])) def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): From 2f6bd528c143d11b96eefadb520b441f054c7bc7 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Sun, 28 Apr 2024 02:35:39 +0000 Subject: [PATCH 15/42] Remove dead use, replace const proxy. 
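For orientation, the change below folds repeated three-argument Const construction into the existing get_const_proxy helper. A minimal standalone sketch of that wrapper pattern (ConstNode and make_const are illustrative stand-ins, not dicp APIs):

    from dataclasses import dataclass
    from typing import List, Union

    @dataclass
    class ConstNode:
        values: List[int]
        dtype: str
        shape: List[int]

    def make_const(value: Union[int, List[int]], dtype: str = "int32") -> ConstNode:
        # Derive the 1-D shape from the value, so call sites stop spelling
        # out the (value, dtype, [len(value)]) triple by hand.
        values = value if isinstance(value, list) else [value]
        return ConstNode(values=values, dtype=dtype, shape=[len(values)])

    assert make_const(7).shape == [1]
    assert make_const([2, 3, 4]).shape == [3]
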
--- dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 3 --- dicp/dicp/vendor/AscendGraph/conversion.py | 16 +++++----------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 5c176e99e..39f7f13e8 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -248,9 +248,6 @@ def gen_call_func(self): # generate input dims if len(self.dynamic_inputs) > 0: - dim_len = 0 - for shape in self.actual_shape: - dim_len += len(shape) dims = 'dims = {' for idx, elem in enumerate(self.actual_shape): if len(elem) == 0: diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index b941b267d..34b88f54b 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -75,8 +75,7 @@ def process_dynamic_shape(self, shape): x_names = [] def generate_digits_op(shapes): - const_op = self.get_proxy( - ascend_op.Const, (shapes, torch.int32, [len(shapes)])) + const_op = self.get_const_proxy(shapes, torch.int32) x_names.append(const_op) def find_root_num(set_num, num): @@ -104,9 +103,7 @@ def replace_elem_proxy(elem_str): # handle with integer if elem_str.isdigit(): - const_op = self.get_proxy( - ascend_op.Const, ([int(elem_str)], torch.int32, [1])) - return const_op + return self.get_const_proxy(int(elem_str), torch.int32) # handle with NodeProxy string if 'Proxy' in elem_str: @@ -128,10 +125,8 @@ def replace_elem_proxy(elem_str): if elem_str in self.sym_in_args: arg, idx = self.sym_in_args[elem_str] shape = self.get_proxy(ascend_op.Shape, (arg,)) - axis = self.get_proxy( - ascend_op.Const, ([0], torch.int32, [1])) - indice = self.get_proxy( - ascend_op.Const, ([idx], torch.int32, [1])) + axis = self.get_const_proxy(0, torch.int32) + indice = self.get_const_proxy(idx, torch.int32) gather = self.get_proxy( ascend_op.GatherV2, (shape, indice, axis)) return gather @@ -293,8 +288,7 @@ def symint_to_str(shape): # both fit for SymInt & NodeProxy, pass all number cases if not_all_num_shape(shape): return self.process_dynamic_shape(shape) - return self.get_proxy( - ascend_op.Const, (shape, dtype, [len(shape)])) + return self.get_const_proxy(shape, dtype) def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): From 92521ebd372fcdfc42f012f0aa853c173d8bee02 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 10 May 2024 02:21:35 +0000 Subject: [PATCH 16/42] Support 7B dynamic shape version. 
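The core of the dynamic-shape handling added here is splitting a SymInt expression string into operand and operator tokens before the individual symbols are replaced with graph proxies. The same tokenisation idea as a standalone sketch with a worked example (tokenize_sym_expr is an illustrative name, not the dicp helper itself):

    def tokenize_sym_expr(expr: str):
        # Pad every operator and parenthesis with spaces, then split and
        # drop the empty fragments -- mirroring the preprocessing of
        # SymInt expression strings introduced in this patch.
        for op in ('+', '-', '*', '//', '(', ')'):
            expr = expr.replace(op, f' {op} ')
        return [tok for tok in expr.split(' ') if tok]

    assert tokenize_sym_expr('(s5//8) - (s5//16)') == [
        '(', 's5', '//', '8', ')', '-', '(', 's5', '//', '16', ')']
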
--- dicp/dicp/dynamo_bridge/utils.py | 8 ++ .../dicp/vendor/AscendGraph/codegen/ascend.py | 4 +- dicp/dicp/vendor/AscendGraph/conversion.py | 123 +++++++++++------- 3 files changed, 87 insertions(+), 48 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py index c241227c0..4116cbbe5 100644 --- a/dicp/dicp/dynamo_bridge/utils.py +++ b/dicp/dicp/dynamo_bridge/utils.py @@ -21,6 +21,14 @@ def symint_in_shape(shape): return False +def neg_in_shape(shape): + for elem in shape: + if isinstance(elem, int): + if elem < 0: + return True + return False + + def process_sym_name(st): # dynamic shape feature # return string wrapper in new version diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 39f7f13e8..be7bf49e0 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -310,9 +310,9 @@ def gen_call_func(self): continue stride = list(elem.stride()) stride = [process_sym_name(dim) for dim in stride] - out_strides.append(str(stride)) + out_strides.append('[' + ','.join(map(str, stride)) + ']') out_storage_offsets.append(elem.storage_offset()) - call_body.writeline(f'out_stride = {out_strides}') + call_body.writeline(f'''out_stride = [{','.join(out_strides)}]''') call_body.writeline(f'out_storage_offset = {out_storage_offsets}') # In precision debug mode, modified array recording InputArgs integer needed diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 34b88f54b..3f77681c3 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -15,7 +15,7 @@ from torch.fx.immutable_collections import immutable_list from torch._subclasses import FakeTensor import dicp.vendor.AscendGraph.ascend_op as ascend_op -from dicp.dynamo_bridge.utils import symint_in_shape, not_all_num_shape, process_sym_name +from dicp.dynamo_bridge.utils import symint_in_shape, neg_in_shape, not_all_num_shape, process_sym_name from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype ) @@ -71,6 +71,18 @@ class AtenToAscendTransformer(SingleOpTransformer): def __init__(self, gm): super().__init__(gm, conversions) + def preprocess_expression(self, expr): + elem_str = process_sym_name(expr) + elem_str = elem_str.replace('+', ' + ') + elem_str = elem_str.replace('-', ' - ') + elem_str = elem_str.replace('*', ' * ') + elem_str = elem_str.replace('//', ' // ') + elem_str = elem_str.replace('(', ' ( ') + elem_str = elem_str.replace(')', ' ) ') + elems = elem_str.split(' ') + elems = [e for e in elems if e != ''] + return elems + def process_dynamic_shape(self, shape): x_names = [] @@ -111,12 +123,25 @@ def replace_elem_proxy(elem_str): elem_str = elem_str.replace('[', '(') elem_str = elem_str.replace(']', ')') + # split for sym_in_args candidate + idx = -1 + if elem_str[0] == '(' and elem_str[-1] == ')': + elem_str, idx = elem_str.strip('()').split(',') + # search & replace replace_proxy = None arg_symint_candidate = [value[0] for value in self.sym_in_args.values()] + list(self.sym_to_inputs.values()) for arg_sym in arg_symint_candidate: if elem_str == str(arg_sym): - replace_proxy = arg_sym + # is sym_in_args candidate + if int(idx) > -1: + shape = self.get_proxy(ascend_op.Shape, (arg_sym,)) + axis = self.get_const_proxy(0, torch.int32) + indice = self.get_const_proxy(int(idx), torch.int32) + replace_proxy = self.get_proxy( + ascend_op.GatherV2, (shape, indice, axis)) + 
else: + replace_proxy = arg_sym break assert replace_proxy is not None return replace_proxy @@ -141,25 +166,8 @@ def generate_not_num(elem): x_names.append(elem) return - # string form of NodeProxy - if isinstance(elem, str) and 'Proxy' in elem: - # special case handling '()' in NodeProxy string - # '[]' will not mixed with expression calculation priority - elem = elem.replace('(', '[') - elem = elem.replace(')', ']') - elif not isinstance(elem, torch.SymInt): - raise RuntimeError("Not num objects only include SymInt or NodeProxy!") - # case for NodeProxy string or SymInt - elem_str = process_sym_name(elem) - elem_str = elem_str.replace('+', ' + ') - elem_str = elem_str.replace('-', ' - ') - elem_str = elem_str.replace('*', ' * ') - elem_str = elem_str.replace('//', ' // ') - elem_str = elem_str.replace('(', ' ( ') - elem_str = elem_str.replace(')', ' ) ') - elems = elem_str.split(' ') - elems = [e for e in elems if e != ''] + elems = self.preprocess_expression(elem) # prepare for expression calculation if len(elems) > 1: @@ -258,7 +266,9 @@ def generate_not_num(elem): generate_digits_op(dims) # concat all ops - return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) + pack = self.get_proxy(ascend_op.Pack, (x_names, 0)) + const_shape = self.get_const_proxy([len(shape)], torch.int32) + return self.get_proxy(ascend_op.Reshape, (pack, const_shape)) def get_shape_proxy(self, shape, dtype=torch.int32): def symint_to_str(shape): @@ -267,12 +277,19 @@ def symint_to_str(shape): if isinstance(dim, torch.SymInt): # split expression elements to compare SymInt string dim_str = dim.node.str() - elems = dim_str.strip().split(' ') + elems = self.preprocess_expression(dim_str) # replace SymInt in expression using sympy function for elem in elems: if 's' in elem: - dim_str = str(sympy.simplify(dim_str).subs(elem, str(self.sym_to_inputs[elem]))) + # cover sym_to_inputs for higher priority to sym_in_args + sym_keys = {**self.sym_to_inputs, **self.sym_in_args} + replace_str = str(sym_keys[elem]).replace(' ', '') + + # '[]' will not mixed with expression calculation priority + replace_str = replace_str.replace('(', '[') + replace_str = replace_str.replace(')', ']') + dim_str = dim_str.replace(elem, replace_str) result_shape.append(dim_str) else: result_shape.append(dim) @@ -294,7 +311,7 @@ def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): format = "ND" if format is None else format if target_shape is None: - shape = [len(param)] if isinstance(param, list) else [] + shape = [len(param)] if isinstance(param, list) and len(param) > 1 else [] else: shape = target_shape param = param if isinstance(param, list) else [param] @@ -492,7 +509,6 @@ def slice(self, x, dim=0, start=None, end=None, step=1): assert dim == -1 or dim >= 0 and dim < len(x_shape) offset = [0] * len(x_shape) offset[dim] = start - # import pdb; pdb.set_trace() offset = self.get_shape_proxy(offset) size = self.get_shape_proxy(y_shape) @@ -515,8 +531,7 @@ def NewEmptyStrided(self, x, size, stride, dtype=torch.float32, layout=torch.str @register_conversion(aten.empty) def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format, pin_memory=False): - shape_op = self.get_proxy( - ascend_op.Const, (size, torch.int32, [len(size)])) + shape_op = self.get_shape_proxy(size) return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, memory_format)) 
@register_conversion(aten.empty_like.default) @@ -565,12 +580,14 @@ def select(self, x, dim, index): @register_conversion(_operator.add) def inadd(self, x, y): + out_dtype = fx_traceback.get_current_meta()['val'].dtype if not isinstance(x, torch.fx.proxy.Proxy): assert isinstance(x, int) x = self.get_proxy(ascend_op.Const, (x, torch.int32)) if not isinstance(y, torch.fx.proxy.Proxy): assert isinstance(y, int) y = self.get_proxy(ascend_op.Const, (y, torch.int32)) + x, y = self.promote_dtype(x, y, target_dtype=out_dtype) return self.get_proxy(ascend_op.AddV2, (x, y)) @register_conversion([aten.view.default, aten._unsafe_view, aten._unsafe_view.default]) @@ -582,12 +599,7 @@ def view(self, x, size): size_tmp = [s for s in size] + [1] size = immutable_list(size_tmp) numel = result_val.numel() - neg = False - for i in shape: - if not isinstance(i, torch.SymInt): - if i == -1: - neg = True - break + neg = neg_in_shape(shape) if neg: prod = 1 for i in shape: @@ -609,7 +621,7 @@ def view(self, x, size): raise RuntimeError( "cannot handle with both negative and symint!") shape = real_shape - elif not_all_num_shape(shape): + elif not neg_in_shape(size) and not_all_num_shape(shape): shape = size shape = self.get_shape_proxy(shape) if x.node.meta["val"].dtype == torch.complex64: @@ -1018,7 +1030,7 @@ def index_put_default(self, x, indices, values): return self.masked_fill(x, index, values) reshape_shape = index_shape + [1] * \ (x_shape_size - index_shape_size) - reshape_op = self.get_const_proxy(reshape_shape, torch.int32) + reshape_op = self.get_shape_proxy(reshape_shape) index = self.get_proxy(ascend_op.Reshape, (index, reshape_op)) return self.masked_fill(x, index, values) @@ -1313,15 +1325,15 @@ def convolution(self, input, weight, bias, stride, padding, @register_conversion(_operator.mul) def inmul(self, x, y): - assert (not isinstance(y, torch.fx.proxy.Proxy)) - y = self.get_const_proxy(y, torch.int32) + if not isinstance(y, torch.fx.proxy.Proxy): + y = self.get_const_proxy(y, torch.int32) return self.get_proxy(ascend_op.Mul, (x, y)) @register_conversion(torch.ops.aten.sym_size) def symsize(self, x, dim): dim = [dim] if not isinstance(dim, list) else dim shape = self.get_proxy(ascend_op.Shape, (x,)) - axis = self.get_const_proxy(0, torch.int32, target_shape=[1]) + axis = self.get_const_proxy(0, torch.int32) indices = self.get_const_proxy(dim, torch.int32) return self.get_proxy(ascend_op.GatherV2, (shape, indices, axis)) @@ -1374,7 +1386,7 @@ def mean(self, x, dims=[], keepdim=False): @register_conversion(torch.ops.aten.cumsum.default) def cumsum(self, x, dim, dtype=None): - dim_const = self.get_const_proxy(dim, torch.int32, target_shape=[1]) + dim_const = self.get_const_proxy(dim, torch.int32) return self.get_proxy(ascend_op.Cumsum, (x, dim_const)) @register_conversion(torch.ops.aten._log_softmax.default) @@ -1447,7 +1459,7 @@ def exp(self, a): def embedding(self, weight, indices, padding_idx=-1): # TODO! 
consider situation for padding_idx # during training stage - axis = self.get_const_proxy(0, torch.int32, target_shape=[1]) + axis = self.get_const_proxy(0, torch.int32) return self.get_proxy(ascend_op.GatherV2, (weight, indices, axis)) @register_conversion(torch.ops.aten.gather) @@ -1628,9 +1640,16 @@ def Repeat(self, x, repeats): @register_conversion([torch.ops.aten.ge.Scalar, torch.ops.aten.ge.Tensor]) def Ge(self, x, y): + x_dtype = x.node.meta['val'].dtype if not isinstance(y, torch.fx.proxy.Proxy): - dtype = x.node.meta['val'].dtype - y = self.get_const_proxy(y, dtype) + y = self.get_const_proxy(y, x_dtype) + else: + if not isinstance(y.node.meta['val'], torch.SymInt): + y_dtype = y.node.meta['val'].dtype + else: + y_dtype = torch.int32 + if x_dtype != y_dtype: + y = self.get_proxy(ascend_op.Cast, (y, get_ascend_dtype(x_dtype))) return self.get_proxy(ascend_op.GreaterEqual, (x, y)) @register_conversion(torch.ops.aten.logical_or.default) @@ -1689,7 +1708,7 @@ def lightllm_rotary_emb(self, x, cos, sin): seq_len = x_shape[0] dim = x_shape[2] - cos_sin_shape = self.get_const_proxy([seq_len, 1, dim // 2], torch.int32) + cos_sin_shape = self.get_shape_proxy([seq_len, 1, dim // 2]) cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape)) sin = self.get_proxy(ascend_op.Reshape, (sin, cos_sin_shape)) @@ -1710,7 +1729,7 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): q_shape = list(q.node.meta['val'].shape) seq_len = q_shape[1] shape = [seq_len, seq_len] - shape = self.get_proxy(ascend_op.Const, (shape, torch.int32, [len(shape)])) + shape = self.get_shape_proxy(shape) mask = self.get_proxy(ascend_op.Empty, (shape, torch.bool)) mask = self.get_proxy(ascend_op.OnesLike, (mask,)) mask = self.get_proxy(ascend_op.Tril, (mask,)) @@ -1748,8 +1767,20 @@ def select_scatter(self, x, src, dim, index): @register_conversion(torch.ops.lightllm.copy_with_offset.default) def copy_with_offset(self, x, src, start_dim, end_dim): - dims = [x for x in range(start_dim, end_dim)] - dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1]) + if isinstance(start_dim, int) and isinstance(end_dim, int): + dims = [x for x in range(start_dim, end_dim)] + dims = self.get_const_proxy(dims, torch.int32) + else: + step = self.get_const_proxy(1, torch.int32) + if isinstance(start_dim, int): + start_dim = self.get_const_proxy(start_dim, torch.int32) + if isinstance(end_dim, int): + end_dim = self.get_const_proxy(end_dim, torch.int32) + dims = self.get_proxy(ascend_op.Range, (start_dim, end_dim, step)) + # dims = self.arange(end_dim, start_dim) + dims = self.get_proxy(ascend_op.Unsqueeze, (dims, [-1])) + x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(torch.float))) + src = self.get_proxy(ascend_op.Cast, (src, get_ascend_dtype(torch.float))) return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) @register_conversion(torch.ops.lightllm.flash_attention_inference.default) From 88985410e543cce08de7212dc51903e8ddd7ef76 Mon Sep 17 00:00:00 2001 From: chenchiyu Date: Sat, 11 May 2024 06:56:11 +0000 Subject: [PATCH 17/42] Pass 1st dynamic graph model. 
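This patch introduces SequenceAt and SplitToSequence so that aten.split over a symbolic dimension can stay a single sequence-producing node, with each operator.getitem lowered to an explicit element lookup. Roughly the eager behaviour those ops have to reproduce (plain PyTorch, for illustration only):

    import torch

    x = torch.arange(12).reshape(6, 2)
    chunks = torch.split(x, 3, dim=0)   # aten.split -> tuple of chunks
    first = chunks[0]                   # operator.getitem -> one element

    # With static shapes the converter can keep emitting SplitD with a
    # fixed num_split; once the split dimension is symbolic the chunk
    # count is only known at run time, so the graph carries a sequence
    # handle (SplitToSequence) and every getitem becomes a SequenceAt.
    assert first.shape == (3, 2)
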
--- dicp/dicp/vendor/AscendGraph/ascend_op.py | 41 ++++++++++++++++- .../dicp/vendor/AscendGraph/codegen/ascend.py | 17 +++++++ dicp/dicp/vendor/AscendGraph/conversion.py | 46 ++++++++++++++----- 3 files changed, 91 insertions(+), 13 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 201187101..59a39b1f8 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -567,6 +567,27 @@ def __init__(self): super().__init__("CastToCpu") +class SequenceAt(Operator): + def __init__(self): + super().__init__("SequenceAt") + + def infer_result(self, x, idx=None): + x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x) + if isinstance(x, (List, Tuple)): + return x[idx] + out_dtype = x_dtype + if x_dtype == torch.complex64: # for complex64 + out_shape = list(x_shape) + if idx == 0 or idx == 1: + out_dtype = torch.float32 + out_shape.append(1) + else: + out_shape = [x_shape[idx]] if idx is not None else list(x_shape) + return torch.empty( + out_shape, dtype=out_dtype, memory_format=get_memory_format(x) + ) + + class Identity(Operator): def __init__(self): super().__init__("Identity") @@ -750,7 +771,6 @@ def __init__(self, x): def infer_result(self, x): return common_unary_op_infer(x) - class SplitD(Operator): def __init__(self): super().__init__("SplitD") @@ -770,6 +790,25 @@ def infer_result(self, x, split_dim, num_split, y, from_view_complex=False): memory_format=get_memory_format(x), ) +class SplitToSequence(Operator): + def __init__(self): + super().__init__("SplitToSequence") + + def infer_result(self, x, split_dim, split_size, y, from_view_complex=False): + assert from_view_complex == True, ( + self.__class__.__name__ + + ": currently available only in op view_as_complex!" 
+ ) + x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x) + split_dim = (split_dim + x_dim) % x_dim + out_shape = list(x_shape) + del out_shape[-1] + return torch.empty( + out_shape, + dtype=torch.complex64 if from_view_complex else x_dtype, + memory_format=get_memory_format(x), + ) + class Slice(Operator): def __init__(self): diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index be7bf49e0..d69d8c5ec 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -877,6 +877,14 @@ def Identity(name, input, index=None): op.set_input("x", input) return op.to_node() + @staticmethod + def SequenceAt(name, input, index=None): + op = OP(name, "SequenceAt") + assert index is not None + op.set_input("handle", input) + op.set_input("index", index) + return op.to_node() + @staticmethod def IdentityInp(name, input, dst=None): op = OP(name, "Identity") @@ -1339,6 +1347,15 @@ def SplitD(name, x, dim, num_split, y, from_view_complex): split_op.set_dynamic_output("y", y) return split_op.to_node() + @staticmethod + def SplitToSequence(name, x, dim, split_size, y, from_view_complex): + split_op = OP(name, "SplitToSequence") + split_op.set_input("x", x) + split_op.set_input("split", split_size) + split_op.set_attr_int("axis", dim) + # split_op.set_dynamic_output("y", y) + return split_op.to_node() + @staticmethod def Pack(name, x, axis): x_name = [] diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 3f77681c3..e4e728aca 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -266,9 +266,9 @@ def generate_not_num(elem): generate_digits_op(dims) # concat all ops - pack = self.get_proxy(ascend_op.Pack, (x_names, 0)) - const_shape = self.get_const_proxy([len(shape)], torch.int32) - return self.get_proxy(ascend_op.Reshape, (pack, const_shape)) + if len(x_names) > 1: + return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) + return self.get_proxy(ascend_op.Unsqueeze, (x_names[0], [0])) def get_shape_proxy(self, shape, dtype=torch.int32): def symint_to_str(shape): @@ -410,7 +410,9 @@ def add(self, x, y, alpha: Optional[Number] = 1): y = self.get_const_proxy(y, out_dtype) else: y = self.mul(y, alpha) if alpha != 1 else y - x, y = self.promote_dtype(x, y, target_dtype=out_dtype) + # x, y = self.promote_dtype(x, y, target_dtype=out_dtype) + x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(out_dtype))) + y = self.get_proxy(ascend_op.Cast, (y, get_ascend_dtype(out_dtype))) return self.get_proxy(ascend_op.AddV2, (x, y), {}) @register_conversion(torch.ops.aten.add.Scalar) @@ -491,8 +493,9 @@ def split(self, x, split_size, dim=0): if dim < 0: dim += len(shape) assert shape[dim] > 0 - num_split = int((shape[dim] + split_size - 1) / split_size) - return self.get_proxy(ascend_op.SplitD, (x, dim, num_split, num_split), splitD_kw) + # num_split = int((shape[dim] + split_size - 1) / split_size) + split_size = self.get_const_proxy(split_size, torch.int32) + return self.get_proxy(ascend_op.SplitToSequence, (x, dim, split_size, split_size), splitD_kw) @register_conversion(aten.slice.Tensor) def slice(self, x, dim=0, start=None, end=None, step=1): @@ -580,14 +583,12 @@ def select(self, x, dim, index): @register_conversion(_operator.add) def inadd(self, x, y): - out_dtype = fx_traceback.get_current_meta()['val'].dtype if not isinstance(x, torch.fx.proxy.Proxy): assert isinstance(x, int) x = 
self.get_proxy(ascend_op.Const, (x, torch.int32)) if not isinstance(y, torch.fx.proxy.Proxy): assert isinstance(y, int) y = self.get_proxy(ascend_op.Const, (y, torch.int32)) - x, y = self.promote_dtype(x, y, target_dtype=out_dtype) return self.get_proxy(ascend_op.AddV2, (x, y)) @register_conversion([aten.view.default, aten._unsafe_view, aten._unsafe_view.default]) @@ -1544,7 +1545,13 @@ def sigmoid(self, x): @register_conversion(operator.getitem) def identity(self, x, idx): - return self.get_proxy(ascend_op.Identity, (x, idx)) + # common case + if not 'split_to_sequence' in str(x): + return self.get_proxy(ascend_op.Identity, (x, idx)) + # split_to_sequence return handle sequence + else: + idx = self.get_const_proxy(idx, torch.int32) + return self.get_proxy(ascend_op.SequenceAt, (x, idx)) @register_conversion(torch.ops.aten.full_like) def fulllike(self, x, value, dtype=torch.float32, layout=torch.strided, @@ -1734,6 +1741,7 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): mask = self.get_proxy(ascend_op.OnesLike, (mask,)) mask = self.get_proxy(ascend_op.Tril, (mask,)) mask = self.get_proxy(ascend_op.LogicalNot, (mask,)) + q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16))) fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim)) return fa @@ -1784,23 +1792,37 @@ def copy_with_offset(self, x, src, start_dim, end_dim): return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) @register_conversion(torch.ops.lightllm.flash_attention_inference.default) - def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): + def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len): q_shape = list(q.node.meta['val'].shape) batch, head, dim = q_shape[0], q_shape[1], q_shape[2] + + # head_num not change for common cases + if isinstance(head, torch.SymInt): + head = head.node.hint + if isinstance(dim, torch.SymInt): + dim = dim.node.hint + k_shape = list(all_k.node.meta['val'].shape) kvhead = k_shape[1] + # the same for kvhead + if isinstance(kvhead, torch.SymInt): + kvhead = kvhead.node.hint + res = [] compute_batch = 1 select_axis = self.get_const_proxy(0, torch.int32) for i in range(batch): - current_len = current_len[i] + current_len = current_lens[i] select_index = self.get_const_proxy(i, torch.int32) xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis)) kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32) - kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) + imax_const = self.get_const_proxy(i * max_len, torch.int32) + end_proxy = self.get_proxy(ascend_op.Add, (current_len, imax_const)) + # kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) + kv_end_index = self.get_shape_proxy([end_proxy, kvhead, dim], torch.int32) kv_seq_len = current_len kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim]) From 8971b85597d51472f5299c23165c420be6148805 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 14 May 2024 07:21:58 +0000 Subject: [PATCH 18/42] Pass both two graph model for 7B dynamic shape version. 
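Shape inference for SplitToSequence is now delegated to torch.split instead of re-deriving chunk counts by hand. The uneven tail chunk is the case the old ceil-division num_split covered; a quick reminder of the eager semantics (plain PyTorch):

    import torch

    x = torch.arange(10)
    parts = torch.split(x, 4)           # chunk size 4 -> sizes 4, 4, 2
    assert [p.numel() for p in parts] == [4, 4, 2]
    # matches the old num_split = ceil(dim_size / split_size) = 3
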
--- dicp/dicp/vendor/AscendGraph/ascend_op.py | 17 ++--------------- dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 3 +-- dicp/dicp/vendor/AscendGraph/conversion.py | 5 ++--- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 59a39b1f8..64659e64b 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -794,21 +794,8 @@ class SplitToSequence(Operator): def __init__(self): super().__init__("SplitToSequence") - def infer_result(self, x, split_dim, split_size, y, from_view_complex=False): - assert from_view_complex == True, ( - self.__class__.__name__ - + ": currently available only in op view_as_complex!" - ) - x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x) - split_dim = (split_dim + x_dim) % x_dim - out_shape = list(x_shape) - del out_shape[-1] - return torch.empty( - out_shape, - dtype=torch.complex64 if from_view_complex else x_dtype, - memory_format=get_memory_format(x), - ) - + def infer_result(self, x, split_dim, split_size): + torch.split(x, split_size, split_dim) class Slice(Operator): def __init__(self): diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index d69d8c5ec..9bc0f04dc 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -1348,12 +1348,11 @@ def SplitD(name, x, dim, num_split, y, from_view_complex): return split_op.to_node() @staticmethod - def SplitToSequence(name, x, dim, split_size, y, from_view_complex): + def SplitToSequence(name, x, dim, split_size): split_op = OP(name, "SplitToSequence") split_op.set_input("x", x) split_op.set_input("split", split_size) split_op.set_attr_int("axis", dim) - # split_op.set_dynamic_output("y", y) return split_op.to_node() @staticmethod diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index e4e728aca..84426c464 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -493,9 +493,8 @@ def split(self, x, split_size, dim=0): if dim < 0: dim += len(shape) assert shape[dim] > 0 - # num_split = int((shape[dim] + split_size - 1) / split_size) - split_size = self.get_const_proxy(split_size, torch.int32) - return self.get_proxy(ascend_op.SplitToSequence, (x, dim, split_size, split_size), splitD_kw) + split_size = self.get_proxy(ascend_op.Squeeze, (split_size, [0])) + return self.get_proxy(ascend_op.SplitToSequence, (x, dim, split_size)) @register_conversion(aten.slice.Tensor) def slice(self, x, dim=0, start=None, end=None, step=1): From 08e73f51c1dd6193ea37f41c8cce4d752c46c6bf Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Tue, 14 May 2024 09:13:32 +0000 Subject: [PATCH 19/42] Fix ci case incre_flash_attention. 
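In the incre_flash_attention CI case, current_lens holds plain Python ints, so the end index of each request's KV slice is now assembled from const proxies and an Add node rather than folded into a single constant. The index arithmetic the conversion encodes, written out in eager PyTorch for orientation:

    import torch

    max_len, kvhead, dim = 8, 2, 4
    all_k = torch.randn(3 * max_len, kvhead, dim)   # 3 requests packed along dim 0
    i, current_len = 1, 5
    start = i * max_len                 # kv_start_index constant
    end = start + current_len           # Add(curlen_const, imax_const) in the graph
    k_i = all_k[start:end]              # request i's keys sliced out of the packed cache
    assert k_i.shape == (current_len, kvhead, dim)
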
--- dicp/dicp/vendor/AscendGraph/conversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 84426c464..2b4505a29 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1819,8 +1819,8 @@ def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len): kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32) imax_const = self.get_const_proxy(i * max_len, torch.int32) - end_proxy = self.get_proxy(ascend_op.Add, (current_len, imax_const)) - # kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) + curlen_const = self.get_const_proxy(current_len, torch.int32) + end_proxy = self.get_proxy(ascend_op.Add, (curlen_const, imax_const)) kv_end_index = self.get_shape_proxy([end_proxy, kvhead, dim], torch.int32) kv_seq_len = current_len From 53f728dbb3fe5dcc123dbbd40b553bb90bec1555 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 15 May 2024 08:25:53 +0000 Subject: [PATCH 20/42] Change split execution path for both shape mode. --- dicp/dicp/vendor/AscendGraph/ascend_op.py | 8 ++++++++ dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 7 +++++++ dicp/dicp/vendor/AscendGraph/conversion.py | 13 +++++++++++-- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 64659e64b..c1c85cab0 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -217,6 +217,14 @@ def infer_result(self, x): return common_unary_op_infer(x) +class Triu(Operator): + def __init__(self): + super().__init__("Triu") + + def infer_result(self, x, diag): + return common_unary_op_infer(x) + + class Sqrt(Operator): def __init__(self): super().__init__("Sqrt") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 9bc0f04dc..2e9fb5d3a 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -789,6 +789,13 @@ def Rsqrt(name, x): op.set_input("x", x) return op.to_node() + @staticmethod + def Triu(name, x, diag): + op = OP(name, "Triu") + op.set_input("x", x) + op.set_attr_int("diagonal", diag) + return op.to_node() + @staticmethod def Conv2D(name, input, weight, stride, padding, dilation, groups, format, bias): diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 2b4505a29..50551baa4 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -466,6 +466,10 @@ def rsqrt(self, x): cond_op = self.get_proxy(ascend_op.Less, (x, zero_op)) return self.get_proxy(ascend_op.Select, (cond_op, nan_op, rsqrt_op)) + @register_conversion(aten.triu) + def triu(self, x, diag): + return self.get_proxy(ascend_op.Triu, (x, diag)) + @register_conversion(_operator.ge) def inge(self, x, y): if not isinstance(y, torch.fx.proxy.Proxy): @@ -493,8 +497,13 @@ def split(self, x, split_size, dim=0): if dim < 0: dim += len(shape) assert shape[dim] > 0 - split_size = self.get_proxy(ascend_op.Squeeze, (split_size, [0])) - return self.get_proxy(ascend_op.SplitToSequence, (x, dim, split_size)) + # dynamic feature for lightllm llama 7B + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + split_size = self.get_proxy(ascend_op.Squeeze, (split_size, [0])) + return 
self.get_proxy(ascend_op.SplitToSequence, (x, dim, split_size)) + + num_split = int((shape[dim] + split_size - 1) / split_size) + return self.get_proxy(ascend_op.SplitD, (x, dim, num_split, num_split), splitD_kw) @register_conversion(aten.slice.Tensor) def slice(self, x, dim=0, start=None, end=None, step=1): From 3e36e2ee0e1151489b8786fdd3e8e44ab9c2d2a7 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 15 May 2024 08:54:24 +0000 Subject: [PATCH 21/42] Add execution path for copy_with_offset. --- dicp/dicp/vendor/AscendGraph/conversion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 50551baa4..75287fe2d 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1783,6 +1783,12 @@ def select_scatter(self, x, src, dim, index): @register_conversion(torch.ops.lightllm.copy_with_offset.default) def copy_with_offset(self, x, src, start_dim, end_dim): + # dynamic feature for lightllm llama 7B + if len(self.sym_in_args) == 0 and len(self.sym_to_inputs) == 0: + dims = [x for x in range(start_dim, end_dim)] + dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1]) + return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) + if isinstance(start_dim, int) and isinstance(end_dim, int): dims = [x for x in range(start_dim, end_dim)] dims = self.get_const_proxy(dims, torch.int32) @@ -1793,7 +1799,6 @@ def copy_with_offset(self, x, src, start_dim, end_dim): if isinstance(end_dim, int): end_dim = self.get_const_proxy(end_dim, torch.int32) dims = self.get_proxy(ascend_op.Range, (start_dim, end_dim, step)) - # dims = self.arange(end_dim, start_dim) dims = self.get_proxy(ascend_op.Unsqueeze, (dims, [-1])) x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(torch.float))) src = self.get_proxy(ascend_op.Cast, (src, get_ascend_dtype(torch.float))) From 6c40ad4b8635d8f64d4a9386dc5a059031311f73 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 15 May 2024 12:19:31 +0000 Subject: [PATCH 22/42] Merge copy_with_offset shape path mode. --- dicp/dicp/vendor/AscendGraph/conversion.py | 39 +++++++++++----------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 75287fe2d..769e8ee8f 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -410,9 +410,7 @@ def add(self, x, y, alpha: Optional[Number] = 1): y = self.get_const_proxy(y, out_dtype) else: y = self.mul(y, alpha) if alpha != 1 else y - # x, y = self.promote_dtype(x, y, target_dtype=out_dtype) - x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(out_dtype))) - y = self.get_proxy(ascend_op.Cast, (y, get_ascend_dtype(out_dtype))) + x, y = self.promote_dtype(x, y, target_dtype=out_dtype) return self.get_proxy(ascend_op.AddV2, (x, y), {}) @register_conversion(torch.ops.aten.add.Scalar) @@ -1348,14 +1346,17 @@ def symsize(self, x, dim): @register_conversion(torch.ops.aten.mm.default) def mm(self, x, y): + out_dtype = fx_traceback.get_current_meta()['val'].dtype + # TODO! 
MatMul not support fp32 input # for higher precision in some cases if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: x = self.get_proxy(ascend_op.Unsqueeze, (x, [0])) y = self.get_proxy(ascend_op.Unsqueeze, (y, [0])) - mm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False)) - return self.get_proxy(ascend_op.Squeeze, (mm, [0])) - out_dtype = fx_traceback.get_current_meta()['val'].dtype + bmm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False)) + cast = self.get_proxy(ascend_op.Cast, (bmm, get_ascend_dtype(out_dtype))) + return self.get_proxy(ascend_op.Squeeze, (cast, [0])) + trans_x = False trans_y = False if isinstance(x.node.target, ascend_op.Permute) and x.node.args[1] == [1, 0]: @@ -1783,25 +1784,23 @@ def select_scatter(self, x, src, dim, index): @register_conversion(torch.ops.lightllm.copy_with_offset.default) def copy_with_offset(self, x, src, start_dim, end_dim): - # dynamic feature for lightllm llama 7B - if len(self.sym_in_args) == 0 and len(self.sym_to_inputs) == 0: + if isinstance(start_dim, int) and isinstance(end_dim, int): dims = [x for x in range(start_dim, end_dim)] dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1]) return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) - if isinstance(start_dim, int) and isinstance(end_dim, int): - dims = [x for x in range(start_dim, end_dim)] - dims = self.get_const_proxy(dims, torch.int32) - else: - step = self.get_const_proxy(1, torch.int32) - if isinstance(start_dim, int): - start_dim = self.get_const_proxy(start_dim, torch.int32) - if isinstance(end_dim, int): - end_dim = self.get_const_proxy(end_dim, torch.int32) - dims = self.get_proxy(ascend_op.Range, (start_dim, end_dim, step)) + step = self.get_const_proxy(1, torch.int32) + if isinstance(start_dim, int): + start_dim = self.get_const_proxy(start_dim, torch.int32) + if isinstance(end_dim, int): + end_dim = self.get_const_proxy(end_dim, torch.int32) + dims = self.get_proxy(ascend_op.Range, (start_dim, end_dim, step)) dims = self.get_proxy(ascend_op.Unsqueeze, (dims, [-1])) - x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(torch.float))) - src = self.get_proxy(ascend_op.Cast, (src, get_ascend_dtype(torch.float))) + + x_dtype = x.node.meta['val'].dtype + src_dtype = src.node.meta['val'].dtype + if x_dtype != src_dtype: + src = self.get_proxy(ascend_op.Cast, (src, get_ascend_dtype(src_dtype))) return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) @register_conversion(torch.ops.lightllm.flash_attention_inference.default) From 0ce746fc20950989417c6e38e90abf384c76b7b3 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 16 May 2024 01:39:35 +0000 Subject: [PATCH 23/42] Add const proxy for int split_size. 
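The diff below lets the dynamic split path accept either form of split_size: a Python int is materialised as a const proxy, while a proxy coming from the graph (typically a 1-element tensor from a shape lookup) is squeezed down to a scalar first. In eager terms, the squeeze step is simply:

    import torch

    split_from_shape = torch.tensor([4])            # 1-element tensor, shape [1]
    scalar = torch.squeeze(split_from_shape, 0)     # scalar, shape [] -- the Squeeze(x, [0]) case
    assert scalar.dim() == 0 and int(scalar) == 4
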
--- dicp/dicp/vendor/AscendGraph/conversion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 769e8ee8f..11db29043 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -497,7 +497,10 @@ def split(self, x, split_size, dim=0): assert shape[dim] > 0 # dynamic feature for lightllm llama 7B if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - split_size = self.get_proxy(ascend_op.Squeeze, (split_size, [0])) + if isinstance(split_size, int): + split_size = self.get_const_proxy(split_size, torch.int32) + else: + split_size = self.get_proxy(ascend_op.Squeeze, (split_size, [0])) return self.get_proxy(ascend_op.SplitToSequence, (x, dim, split_size)) num_split = int((shape[dim] + split_size - 1) / split_size) From e5c06abe565e8a4c4eb57345f4a3141220bce801 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 16 May 2024 07:02:47 +0000 Subject: [PATCH 24/42] Move some common functions into util. --- dicp/dicp/dynamo_bridge/utils.py | 32 ++++++++++++++++++++ dicp/dicp/vendor/AscendGraph/conversion.py | 34 ++-------------------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py index 4116cbbe5..9d8f34189 100644 --- a/dicp/dicp/dynamo_bridge/utils.py +++ b/dicp/dicp/dynamo_bridge/utils.py @@ -38,6 +38,38 @@ def process_sym_name(st): return str(st) +def preprocess_expression(expr): + elem_str = process_sym_name(expr) + elem_str = elem_str.replace('+', ' + ') + elem_str = elem_str.replace('-', ' - ') + elem_str = elem_str.replace('*', ' * ') + elem_str = elem_str.replace('//', ' // ') + elem_str = elem_str.replace('(', ' ( ') + elem_str = elem_str.replace(')', ' ) ') + elems = elem_str.split(' ') + elems = [e for e in elems if e != ''] + return elems + + +def find_root_num(set_num, num): + while set_num[num] != num: + num = set_num[num] + return num + + +def merge_disjoint_set(set_num, idx_a, idx_b): + root_a = find_root_num(set_num, idx_a) + root_b = find_root_num(set_num, idx_b) + # an example for (s5 / 8) - (s5 / 16) + # num: 0 1 2 3 + # step1 - > set_num: 0 1 2 3 + # step2 - > set_num: 0 0 2 2 + # step3 - > set_num: 0 0 0 0 + + # return merged set from root_b to root_a + return [root_a if find_root_num(set_num, s) == root_b else s for s in set_num] + + def save_cpu_gm(gm: torch.fx.GraphModule, folder: str): Path(folder).mkdir(exist_ok=True) cpu_gm = copy_gm_to_cpu(gm) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 11db29043..5e257567c 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -16,6 +16,7 @@ from torch._subclasses import FakeTensor import dicp.vendor.AscendGraph.ascend_op as ascend_op from dicp.dynamo_bridge.utils import symint_in_shape, neg_in_shape, not_all_num_shape, process_sym_name +from dicp.dynamo_bridge.utils import preprocess_expression, find_root_num, merge_disjoint_set from dicp.vendor.AscendGraph.codegen.utils import ( get_ascend_dtype ) @@ -71,18 +72,6 @@ class AtenToAscendTransformer(SingleOpTransformer): def __init__(self, gm): super().__init__(gm, conversions) - def preprocess_expression(self, expr): - elem_str = process_sym_name(expr) - elem_str = elem_str.replace('+', ' + ') - elem_str = elem_str.replace('-', ' - ') - elem_str = elem_str.replace('*', ' * ') - elem_str = elem_str.replace('//', ' // ') - 
elem_str = elem_str.replace('(', ' ( ') - elem_str = elem_str.replace(')', ' ) ') - elems = elem_str.split(' ') - elems = [e for e in elems if e != ''] - return elems - def process_dynamic_shape(self, shape): x_names = [] @@ -90,23 +79,6 @@ def generate_digits_op(shapes): const_op = self.get_const_proxy(shapes, torch.int32) x_names.append(const_op) - def find_root_num(set_num, num): - while set_num[num] != num: - num = set_num[num] - return num - - def merge_disjoint_set(set_num, idx_a, idx_b): - root_a = find_root_num(set_num, idx_a) - root_b = find_root_num(set_num, idx_b) - # an example for (s5 / 8) - (s5 / 16) - # num: 0 1 2 3 - # step1 - > set_num: 0 1 2 3 - # step2 - > set_num: 0 0 2 2 - # step3 - > set_num: 0 0 0 0 - - # return merged set from root_b to root_a - return [root_a if find_root_num(set_num, s) == root_b else s for s in set_num] - def replace_elem_proxy(elem_str): # exit if already a proxy if isinstance(elem_str, torch.fx.proxy.Proxy): @@ -167,7 +139,7 @@ def generate_not_num(elem): return # case for NodeProxy string or SymInt - elems = self.preprocess_expression(elem) + elems = preprocess_expression(elem) # prepare for expression calculation if len(elems) > 1: @@ -277,7 +249,7 @@ def symint_to_str(shape): if isinstance(dim, torch.SymInt): # split expression elements to compare SymInt string dim_str = dim.node.str() - elems = self.preprocess_expression(dim_str) + elems = preprocess_expression(dim_str) # replace SymInt in expression using sympy function for elem in elems: From 9f0472aba22065a421397d5caca866440eef4116 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 16 May 2024 10:36:27 +0000 Subject: [PATCH 25/42] Add path for flash_attention to pass head, kvhead and dim in. --- dicp/dicp/vendor/AscendGraph/conversion.py | 22 ++++++++-------------- dicp/dicp/vendor/AscendGraph/ext_ops.py | 12 +++++++----- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 5e257567c..66a65bd09 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -483,7 +483,6 @@ def slice(self, x, dim=0, start=None, end=None, step=1): # TODO(tangzhiyi): miss step parameter x_shape = list(x.node.meta['val'].shape) y_shape = list(fx_traceback.get_current_meta()['val'].shape) - # y_shape = fx_traceback.get_current_meta()['val'].shape dim = int(dim) if not isinstance(start, torch.fx.proxy.Proxy): start = int(start) if start is not None else 0 @@ -1779,22 +1778,17 @@ def copy_with_offset(self, x, src, start_dim, end_dim): return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) @register_conversion(torch.ops.lightllm.flash_attention_inference.default) - def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len): + def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhead=None, head=None, dim=None): q_shape = list(q.node.meta['val'].shape) - batch, head, dim = q_shape[0], q_shape[1], q_shape[2] - - # head_num not change for common cases - if isinstance(head, torch.SymInt): - head = head.node.hint - if isinstance(dim, torch.SymInt): - dim = dim.node.hint + batch = q_shape[0] + if head is None: + head = q_shape[1] + if dim is None: + dim = q_shape[2] k_shape = list(all_k.node.meta['val'].shape) - kvhead = k_shape[1] - - # the same for kvhead - if isinstance(kvhead, torch.SymInt): - kvhead = kvhead.node.hint + if kvhead is None: + kvhead = k_shape[1] res = [] compute_batch = 1 diff --git 
a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index 324d2a9b2..1f2e41260 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -88,21 +88,23 @@ def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim @torch._custom_op.impl.custom_op('lightllm::flash_attention_inference') -def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int) -> Tensor: +def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int) -> Tensor: ... @flash_attention_inference.impl_abstract() -def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int): +def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int): return torch.empty_like(q) @flash_attention_inference.impl(['cpu', 'cuda']) -def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len): +def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len, kvhead=None, head=None, dim=None): # q: batch, head, dim batch = q.shape[0] - head = q.shape[1] - dim = q.shape[2] + if head is None: + head = q.shape[1] + if dim is None: + dim = q.shape[2] res = [] compute_batch = 1 for i in range(batch): From 62957f7e89792d56d53c4a7e0e03458e191fe856 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 16 May 2024 11:30:06 +0000 Subject: [PATCH 26/42] Cancel path split for slice start proxy form. --- dicp/dicp/vendor/AscendGraph/conversion.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 66a65bd09..822ae910c 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -484,10 +484,9 @@ def slice(self, x, dim=0, start=None, end=None, step=1): x_shape = list(x.node.meta['val'].shape) y_shape = list(fx_traceback.get_current_meta()['val'].shape) dim = int(dim) - if not isinstance(start, torch.fx.proxy.Proxy): - start = int(start) if start is not None else 0 - start = start if start >= 0 else x_shape[dim] + start - assert start is None or start >= 0 and start < x_shape[dim] + start = int(start) if start is not None else 0 + start = start if start >= 0 else x_shape[dim] + start + assert start is None or start >= 0 and start < x_shape[dim] assert dim == -1 or dim >= 0 and dim < len(x_shape) offset = [0] * len(x_shape) From 709c009b8e215e8ea5355a0153cd1a0f721a7a23 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 17 May 2024 03:15:48 +0000 Subject: [PATCH 27/42] Add sequenceAt & triu test unit case. 
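For reference, a standalone eager-mode sketch of the two behaviours the new unit tests exercise (illustrative only, independent of the test harness added below):

    import operator
    import torch

    a = torch.randn(3, 5)
    split = torch.ops.aten.split.Tensor(a, 1)   # sequence of single-row slices
    item = operator.getitem(split, 0)           # what the sequence_at test indexes
    assert torch.equal(item, a[0:1])

    m = torch.ones(3, 5)
    assert torch.equal(torch.ops.aten.triu(m, 1), torch.triu(m, diagonal=1))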
--- dicp/test/ascend_scripts/ops/dynamic.ini | 1 + dicp/test/ascend_scripts/ops/static.ini | 1 + dicp/test/op/test_sequence_at.py | 45 ++++++++++++++++++++++++ dicp/test/op/test_triu.py | 40 +++++++++++++++++++++ 4 files changed, 87 insertions(+) create mode 100644 dicp/test/op/test_sequence_at.py create mode 100644 dicp/test/op/test_triu.py diff --git a/dicp/test/ascend_scripts/ops/dynamic.ini b/dicp/test/ascend_scripts/ops/dynamic.ini index 9e0c79387..d1c1b91a0 100644 --- a/dicp/test/ascend_scripts/ops/dynamic.ini +++ b/dicp/test/ascend_scripts/ops/dynamic.ini @@ -1,4 +1,5 @@ [pytest] testpaths = ../../op python_files = + test_sequence_at.py \ No newline at end of file diff --git a/dicp/test/ascend_scripts/ops/static.ini b/dicp/test/ascend_scripts/ops/static.ini index c73cb4f71..5077e5f0e 100644 --- a/dicp/test/ascend_scripts/ops/static.ini +++ b/dicp/test/ascend_scripts/ops/static.ini @@ -71,6 +71,7 @@ python_files = test_repeat_interleave.py ; test_transpose.py test_tril.py + test_triu.py test_unsqueeze.py test_view_as_complex.py test_view_as_real.py diff --git a/dicp/test/op/test_sequence_at.py b/dicp/test/op/test_sequence_at.py new file mode 100644 index 000000000..73535ee8f --- /dev/null +++ b/dicp/test/op/test_sequence_at.py @@ -0,0 +1,45 @@ +import pytest +import operator +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a, b): + split = torch.ops.aten.split.Tensor(a, 1) + res = operator.getitem(split, b) + return res + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + +torch._dynamo.config.dynamic_shapes = True +torch._dynamo.config.assume_static_by_default = True + +class TestSequenceAt(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("dim", [0]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_operator_sequence_at(self, sizes, dim, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + + in_tensor = torch.randn(size, dtype=dtype) + dicp_input = in_tensor.to(device) + + output = model(in_tensor, dim) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input, dim) + + assert torch.allclose(output, dicp_output.cpu(), equal_nan=True) diff --git a/dicp/test/op/test_triu.py b/dicp/test/op/test_triu.py new file mode 100644 index 000000000..91b0f1f81 --- /dev/null +++ b/dicp/test/op/test_triu.py @@ -0,0 +1,40 @@ +import pytest +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.triu(a, 1) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestTriu(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((3, 5), (3, 5)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_triu(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.ones(size, dtype=dtype) + + dicp_input1 = 
input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), equal_nan=True) From 7b013bc880a0f027f7e729e845648be13ebaa814 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 17 May 2024 06:17:16 +0000 Subject: [PATCH 28/42] Return several code logic back to original design, and fix flash_incre_attention unit test. --- dicp/dicp/vendor/AscendGraph/conversion.py | 18 +++++++++++------- dicp/dicp/vendor/AscendGraph/ext_ops.py | 6 +++--- dicp/test/op/test_lightllm_incre_attention.py | 2 +- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 822ae910c..97c1c532d 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1209,8 +1209,12 @@ def expand(self, x, shape): # Cast needed only when x_dtype is int64 if x.node.meta['val'].dtype == torch.int64: x = self.get_proxy(ascend_op.Cast, (x, "INT32")) - shape = self.get_shape_proxy(shape) - return self.get_proxy(ascend_op.Expand, (x, shape)) + + if isinstance(shape, list) and not_all_num_shape(shape): + shape = self.get_shape_proxy(shape) + return self.get_proxy(ascend_op.Expand, (x, shape)) + else: + return self.get_proxy(ascend_op.ExpandD, (x, shape)) @register_conversion(torch.ops.aten.slice_backward.default) def slice_backward(self, grad, input_shape, dim, start, end, step): @@ -1369,7 +1373,7 @@ def mean(self, x, dims=[], keepdim=False): @register_conversion(torch.ops.aten.cumsum.default) def cumsum(self, x, dim, dtype=None): - dim_const = self.get_const_proxy(dim, torch.int32) + dim_const = self.get_const_proxy(dim, torch.int32, target_shape=[1]) return self.get_proxy(ascend_op.Cumsum, (x, dim_const)) @register_conversion(torch.ops.aten._log_softmax.default) @@ -1777,16 +1781,16 @@ def copy_with_offset(self, x, src, start_dim, end_dim): return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) @register_conversion(torch.ops.lightllm.flash_attention_inference.default) - def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhead=None, head=None, dim=None): + def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhead=-1, head=-1, dim=-1): q_shape = list(q.node.meta['val'].shape) batch = q_shape[0] - if head is None: + if head < 0: head = q_shape[1] - if dim is None: + if dim < 0: dim = q_shape[2] k_shape = list(all_k.node.meta['val'].shape) - if kvhead is None: + if kvhead < 0: kvhead = k_shape[1] res = [] diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index 1f2e41260..d2dfcc6ed 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -98,12 +98,12 @@ def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: @flash_attention_inference.impl(['cpu', 'cuda']) -def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len, kvhead=None, head=None, dim=None): +def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len, kvhead=-1, head=-1, dim=-1): # q: batch, head, dim batch = q.shape[0] - if head is None: + if head < 0: head = q.shape[1] - if dim is None: + if dim < 0: dim = q.shape[2] res = [] compute_batch = 1 diff --git a/dicp/test/op/test_lightllm_incre_attention.py b/dicp/test/op/test_lightllm_incre_attention.py index 
f2c35ca1b..a8bd1d4da 100644 --- a/dicp/test/op/test_lightllm_incre_attention.py +++ b/dicp/test/op/test_lightllm_incre_attention.py @@ -14,7 +14,7 @@ class OpModule(torch.nn.Module): def forward(self, q, k, v, int_index_list, max_seq_length): - res = torch.ops.lightllm.flash_attention_inference.default(q, k, v, int_index_list, max_seq_length) + res = torch.ops.lightllm.flash_attention_inference.default(q, k, v, int_index_list, max_seq_length, -1, -1, -1) return res From 2a1b6d9077912de1dfc9710b4a955fbdf8c2f6c7 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 17 May 2024 08:11:29 +0000 Subject: [PATCH 29/42] Modify the logic of split implementation. --- dicp/dicp/vendor/AscendGraph/conversion.py | 9 +---- dicp/test/op/test_sequence_at.py | 45 ---------------------- 2 files changed, 1 insertion(+), 53 deletions(-) delete mode 100644 dicp/test/op/test_sequence_at.py diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 97c1c532d..d3e71df5c 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -467,15 +467,8 @@ def split(self, x, split_size, dim=0): if dim < 0: dim += len(shape) assert shape[dim] > 0 - # dynamic feature for lightllm llama 7B - if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - if isinstance(split_size, int): - split_size = self.get_const_proxy(split_size, torch.int32) - else: - split_size = self.get_proxy(ascend_op.Squeeze, (split_size, [0])) - return self.get_proxy(ascend_op.SplitToSequence, (x, dim, split_size)) - num_split = int((shape[dim] + split_size - 1) / split_size) + num_split = len(fx_traceback.get_current_meta()['val']) return self.get_proxy(ascend_op.SplitD, (x, dim, num_split, num_split), splitD_kw) @register_conversion(aten.slice.Tensor) diff --git a/dicp/test/op/test_sequence_at.py b/dicp/test/op/test_sequence_at.py deleted file mode 100644 index 73535ee8f..000000000 --- a/dicp/test/op/test_sequence_at.py +++ /dev/null @@ -1,45 +0,0 @@ -import pytest -import operator -from ..common.utils import ( - torch, - dynamo, - parse_args, - compile_model, - get_device, - Size, - update_dynamo_config, -) - - -class OpModule(torch.nn.Module): - def forward(self, a, b): - split = torch.ops.aten.split.Tensor(a, 1) - res = operator.getitem(split, b) - return res - - -model = OpModule() -args = parse_args() -compiled_model = compile_model(model, args.backend, args.dynamic) - -torch._dynamo.config.dynamic_shapes = True -torch._dynamo.config.assume_static_by_default = True - -class TestSequenceAt(): - @pytest.mark.parametrize("dtype", [torch.float32]) - @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) - @pytest.mark.parametrize("dim", [0]) - @pytest.mark.parametrize("compiled_model", compiled_model) - def test_operator_sequence_at(self, sizes, dim, dtype, compiled_model): - device = get_device() - size = sizes.dynamic if compiled_model.dynamic else sizes.static - - in_tensor = torch.randn(size, dtype=dtype) - dicp_input = in_tensor.to(device) - - output = model(in_tensor, dim) - dynamo.reset() - update_dynamo_config(compiled_model.dynamic) - dicp_output = compiled_model.model(dicp_input, dim) - - assert torch.allclose(output, dicp_output.cpu(), equal_nan=True) From 10b1603c974d798f4a809b78f8c2ff1b22be1f9c Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 17 May 2024 08:18:08 +0000 Subject: [PATCH 30/42] Add split dynamic case. 
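The split lowering this dynamic case exercises now takes num_split from the length of the traced result list; in eager terms that length is just ceil(shape[dim] / split_size). A standalone illustration of that invariant:

    import math
    import torch

    x = torch.randn(7, 3)
    split_size = 2
    chunks = torch.ops.aten.split.Tensor(x, split_size)
    assert len(chunks) == math.ceil(x.shape[0] / split_size) == 4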
--- dicp/test/ascend_scripts/ops/dynamic.ini | 2 +- dicp/test/op/test_split_dynamic.py | 45 ++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 dicp/test/op/test_split_dynamic.py diff --git a/dicp/test/ascend_scripts/ops/dynamic.ini b/dicp/test/ascend_scripts/ops/dynamic.ini index d1c1b91a0..50aa2a9cd 100644 --- a/dicp/test/ascend_scripts/ops/dynamic.ini +++ b/dicp/test/ascend_scripts/ops/dynamic.ini @@ -1,5 +1,5 @@ [pytest] testpaths = ../../op python_files = - test_sequence_at.py + test_split_dynamic.py \ No newline at end of file diff --git a/dicp/test/op/test_split_dynamic.py b/dicp/test/op/test_split_dynamic.py new file mode 100644 index 000000000..73535ee8f --- /dev/null +++ b/dicp/test/op/test_split_dynamic.py @@ -0,0 +1,45 @@ +import pytest +import operator +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a, b): + split = torch.ops.aten.split.Tensor(a, 1) + res = operator.getitem(split, b) + return res + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + +torch._dynamo.config.dynamic_shapes = True +torch._dynamo.config.assume_static_by_default = True + +class TestSequenceAt(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("dim", [0]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_operator_sequence_at(self, sizes, dim, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + + in_tensor = torch.randn(size, dtype=dtype) + dicp_input = in_tensor.to(device) + + output = model(in_tensor, dim) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input, dim) + + assert torch.allclose(output, dicp_output.cpu(), equal_nan=True) From d5689c03f8530ad9108eb2c0b86f8549544c8ce9 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 17 May 2024 09:00:27 +0000 Subject: [PATCH 31/42] Remove identity additional logic, wrap convert into promote_dtype. 
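The Greater/GreaterEqual lowerings now route their input casts through promote_dtype instead of hand-written Cast calls. The behaviour they need to match is PyTorch's usual type promotion for comparisons; a small eager-mode reference (illustrative, not DICP code):

    import torch

    x = torch.tensor([1, 2, 3], dtype=torch.int32)
    y = torch.tensor([1.5], dtype=torch.float32)
    print(torch.ops.aten.ge.Tensor(x, y))   # tensor([False,  True,  True])
    print(torch.result_type(x, y))          # torch.float32, the dtype both sides are compared in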
--- dicp/dicp/vendor/AscendGraph/conversion.py | 23 +++++++--------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index d3e71df5c..16d9e1699 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1524,13 +1524,7 @@ def sigmoid(self, x): @register_conversion(operator.getitem) def identity(self, x, idx): - # common case - if not 'split_to_sequence' in str(x): - return self.get_proxy(ascend_op.Identity, (x, idx)) - # split_to_sequence return handle sequence - else: - idx = self.get_const_proxy(idx, torch.int32) - return self.get_proxy(ascend_op.SequenceAt, (x, idx)) + return self.get_proxy(ascend_op.Identity, (x, idx)) @register_conversion(torch.ops.aten.full_like) def fulllike(self, x, value, dtype=torch.float32, layout=torch.strided, @@ -1562,9 +1556,8 @@ def RandLike(self, x, dtype=torch.float32, layout=torch.strided, def GtScalar(self, x, y): dtype = get_ascend_dtype(x.node.meta['val'].dtype) scalar_op = self.get_const_proxy(float(y), torch.float32) - if dtype != torch.float32: - cast_op = self.get_proxy(ascend_op.Cast, (scalar_op, dtype)) - return self.get_proxy(ascend_op.Greater, (x, cast_op)) + x, scalar_op = self.promote_dtype(x, scalar_op, target_dtype=dtype) + return self.get_proxy(ascend_op.Greater, (x, scalar_op)) @register_conversion(torch.ops.aten.addcmul.default) def AddCMul(self, a, b, c, value=1): @@ -1630,12 +1623,10 @@ def Ge(self, x, y): if not isinstance(y, torch.fx.proxy.Proxy): y = self.get_const_proxy(y, x_dtype) else: - if not isinstance(y.node.meta['val'], torch.SymInt): - y_dtype = y.node.meta['val'].dtype - else: - y_dtype = torch.int32 - if x_dtype != y_dtype: - y = self.get_proxy(ascend_op.Cast, (y, get_ascend_dtype(x_dtype))) + if isinstance(y.node.meta['val'], torch.SymInt): + y = self.get_shape_proxy([y]) + y = self.get_proxy(ascend_op.Squeeze, (y, [0])) + x, y = self.promote_dtype(x, y, target_dtype=x_dtype) return self.get_proxy(ascend_op.GreaterEqual, (x, y)) @register_conversion(torch.ops.aten.logical_or.default) From 4e499b970b38a7eb1050a18e7705f1f7f2ebdd10 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 23 May 2024 07:23:47 +0000 Subject: [PATCH 32/42] Pass ge unit test. --- dicp/dicp/vendor/AscendGraph/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 16d9e1699..e240e2498 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1625,7 +1625,7 @@ def Ge(self, x, y): else: if isinstance(y.node.meta['val'], torch.SymInt): y = self.get_shape_proxy([y]) - y = self.get_proxy(ascend_op.Squeeze, (y, [0])) + y = self.get_proxy(ascend_op.Squeeze, (y, [0])) x, y = self.promote_dtype(x, y, target_dtype=x_dtype) return self.get_proxy(ascend_op.GreaterEqual, (x, y)) From 381131927bda1951bd336cb9c4a490c930ac254f Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 23 May 2024 07:39:38 +0000 Subject: [PATCH 33/42] Modify logic of lt dtype, and prompt_attention fp16 conversion. 
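For the lt part of this change, the eager behaviour being lowered to Less is a plain element-wise comparison, with the Scalar overload acting like a broadcast Tensor compare; a minimal reference (illustrative only):

    import torch

    x = torch.tensor([0.5, 1.5, 2.5])
    print(torch.ops.aten.lt.Scalar(x, 2))                   # tensor([ True,  True, False])
    print(torch.ops.aten.lt.Tensor(x, torch.tensor(2.0)))   # same result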
--- dicp/dicp/vendor/AscendGraph/conversion.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index e240e2498..686f1a39f 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -653,18 +653,9 @@ def ne(self, a, b): @register_conversion([aten.lt.Scalar, aten.lt.Tensor]) def lt(self, x, y): - x_shape = list(x.node.meta['val'].shape) - y_shape = [] if not isinstance( - y, torch.fx.proxy.Proxy) else list(y.node.meta['val'].shape) - out = list(fx_traceback.get_current_meta()['val'].shape) - out_shape = self.get_shape_proxy(out) - x, y = self.binary_cmp_cast_input(x, y) - dynamic_shape = not_all_num_shape(x_shape) or not_all_num_shape( - y_shape) or not_all_num_shape(out) - if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)): - x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape)) - if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)): - y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape)) + if not isinstance(y, torch.fx.proxy.Proxy): + x_dtype = x.node.meta['val'].dtype + y = self.get_const_proxy(y, x_dtype) return self.get_proxy(ascend_op.Less, (x, y)) @register_conversion(aten.masked_fill.Scalar) @@ -1711,7 +1702,8 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): mask = self.get_proxy(ascend_op.OnesLike, (mask,)) mask = self.get_proxy(ascend_op.Tril, (mask,)) mask = self.get_proxy(ascend_op.LogicalNot, (mask,)) - q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16))) + if q.node.meta['val'].dtype != torch.float16: + q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16))) fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim)) return fa From a62933513810173650cb85996566a8caa588b23f Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 23 May 2024 08:24:08 +0000 Subject: [PATCH 34/42] Add promote_dtype priority logic. 
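A minimal standalone sketch of the priority rule this patch introduces; PRIORITY and promoted_dtype are illustrative names only, while the real promote_dtype walks fx proxies and inserts Ascend Cast nodes:

    import torch

    PRIORITY = {torch.bool: 0, torch.int32: 1, torch.int64: 2,
                torch.float16: 3, torch.float32: 4, torch.float64: 5}

    def promoted_dtype(target_dtype, *arg_dtypes):
        # the result dtype is the highest-priority dtype seen among the
        # arguments and the requested target
        candidates = [d for d in arg_dtypes if d is not None] + [target_dtype]
        return max(candidates, key=PRIORITY.__getitem__)

    assert promoted_dtype(torch.float16, torch.float32, torch.int32) == torch.float32
    assert promoted_dtype(torch.int64, torch.bool) == torch.int64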
--- dicp/dicp/vendor/AscendGraph/conversion.py | 39 ++++++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 686f1a39f..0dd61933e 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -297,8 +297,32 @@ def get_const_proxy(self, param, dtype, format=None, target_shape=None): return param def promote_dtype(self, *args, target_dtype): + priority = {torch.bool:0, + torch.int32:1, + torch.int64:2, + torch.float16:3, + torch.float32:4, + torch.float64:5, + None:-1} result = [] + + # align maximum arg priority + max_prio = -1 + for arg in args: + if isinstance(arg, torch.fx.proxy.Proxy): + cur_prio = priority[try_to_get_dtype(arg)] + if cur_prio > max_prio: + max_prio = cur_prio + + # align priority between arg max and target_dtype ascend_dtype = get_ascend_dtype(target_dtype) + cur_prio = priority[target_dtype] + if cur_prio > max_prio: + max_prio = cur_prio + assert max_prio > -1 + target_dtype = list(priority.keys())[max_prio] + + # core of dtype conversion for arg in args: if isinstance(arg, torch.fx.proxy.Proxy): current_dtype = try_to_get_dtype(arg) @@ -653,9 +677,18 @@ def ne(self, a, b): @register_conversion([aten.lt.Scalar, aten.lt.Tensor]) def lt(self, x, y): - if not isinstance(y, torch.fx.proxy.Proxy): - x_dtype = x.node.meta['val'].dtype - y = self.get_const_proxy(y, x_dtype) + x_shape = list(x.node.meta['val'].shape) + y_shape = [] if not isinstance( + y, torch.fx.proxy.Proxy) else list(y.node.meta['val'].shape) + out = list(fx_traceback.get_current_meta()['val'].shape) + out_shape = self.get_shape_proxy(out) + x, y = self.binary_cmp_cast_input(x, y) + dynamic_shape = not_all_num_shape(x_shape) or not_all_num_shape( + y_shape) or not_all_num_shape(out) + if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)): + x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape)) + if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)): + y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape)) return self.get_proxy(ascend_op.Less, (x, y)) @register_conversion(aten.masked_fill.Scalar) From 93414886b1162ec7bfb3ad0330be8e93d04046ec Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 23 May 2024 08:35:25 +0000 Subject: [PATCH 35/42] Fix promote_dtype bug. --- dicp/dicp/vendor/AscendGraph/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 0dd61933e..60a948740 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -315,12 +315,12 @@ def promote_dtype(self, *args, target_dtype): max_prio = cur_prio # align priority between arg max and target_dtype - ascend_dtype = get_ascend_dtype(target_dtype) cur_prio = priority[target_dtype] if cur_prio > max_prio: max_prio = cur_prio assert max_prio > -1 target_dtype = list(priority.keys())[max_prio] + ascend_dtype = get_ascend_dtype(target_dtype) # core of dtype conversion for arg in args: From c84fbdd4c69cdfde6d88925dffefda5edb6a91f5 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 23 May 2024 11:03:00 +0000 Subject: [PATCH 36/42] Cast back fa to float32 if dtype not consistent. 
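The pattern being made explicit here is a half-precision round trip: cast the float32 query down for the fp16 attention kernel, then hand the caller a tensor in the original dtype. A small eager-mode sketch, with a multiplication standing in for the kernel:

    import torch

    q = torch.randn(4, 8, dtype=torch.float32)
    out_fp16 = q.to(torch.float16) * 2.0    # stand-in for the fp16 attention kernel
    out = out_fp16.to(q.dtype)              # cast back to the query's dtype
    assert out.dtype == torch.float32
    print((out - q * 2.0).abs().max())      # only a small fp16 rounding error remains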
--- dicp/dicp/vendor/AscendGraph/conversion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 60a948740..6fcd5e7b4 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1728,6 +1728,7 @@ def lightllm_rms_norm(self, x, weight, eps): @register_conversion(torch.ops.lightllm.prompt_attention_inference.default) def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): q_shape = list(q.node.meta['val'].shape) + q_dtype = q.node.meta['val'].dtype seq_len = q_shape[1] shape = [seq_len, seq_len] shape = self.get_shape_proxy(shape) @@ -1735,9 +1736,13 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): mask = self.get_proxy(ascend_op.OnesLike, (mask,)) mask = self.get_proxy(ascend_op.Tril, (mask,)) mask = self.get_proxy(ascend_op.LogicalNot, (mask,)) - if q.node.meta['val'].dtype != torch.float16: + if q_dtype != torch.float16: q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16))) fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim)) + + # cast back fa to float32 + if q_dtype != torch.float16: + fa = self.get_proxy(ascend_op.Cast, (fa, get_ascend_dtype(torch.float32))) return fa def incre_flash_attention(self, q, k, v, head_num, kv_head_num, dim): From 151f4b2eed97d84dc7ce251ff53e571d99ad0da1 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Thu, 23 May 2024 11:10:22 +0000 Subject: [PATCH 37/42] Change to return q_dtype tensor. --- dicp/dicp/vendor/AscendGraph/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 6fcd5e7b4..9cdd98a5d 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1742,7 +1742,7 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): # cast back fa to float32 if q_dtype != torch.float16: - fa = self.get_proxy(ascend_op.Cast, (fa, get_ascend_dtype(torch.float32))) + return self.get_proxy(ascend_op.Cast, (fa, get_ascend_dtype(q_dtype))) return fa def incre_flash_attention(self, q, k, v, head_num, kv_head_num, dim): From bdecf8cee16202abf8006ba5037e275c2d366f1d Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 24 May 2024 05:22:18 +0000 Subject: [PATCH 38/42] Improve promote_dtype logic. 
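The scalar handling added to try_to_get_dtype maps plain Python scalars onto tensor dtypes; a standalone sketch of just that mapping (the real function also inspects fx proxies). Note the bool check has to come before the int check because bool is a subclass of int in Python:

    import torch

    def scalar_dtype(x):
        if isinstance(x, bool):      # must precede the int check
            return torch.bool
        if isinstance(x, int):
            return torch.int32
        if isinstance(x, float):
            return torch.float32
        return None

    assert scalar_dtype(True) == torch.bool
    assert scalar_dtype(3) == torch.int32
    assert scalar_dtype(2.5) == torch.float32
    assert scalar_dtype("not a scalar") is None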
--- dicp/dicp/vendor/AscendGraph/conversion.py | 30 +++++++++++++--------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 9cdd98a5d..9c8445c7d 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -48,6 +48,14 @@ def try_to_get_dtype(x): return x.node.meta['val'].dtype else: return None + + # handle with basic scalar type + if isinstance(x, bool): + return torch.bool + elif isinstance(x, int): + return torch.int32 + elif isinstance(x, float): + return torch.float32 return None @@ -309,10 +317,9 @@ def promote_dtype(self, *args, target_dtype): # align maximum arg priority max_prio = -1 for arg in args: - if isinstance(arg, torch.fx.proxy.Proxy): - cur_prio = priority[try_to_get_dtype(arg)] - if cur_prio > max_prio: - max_prio = cur_prio + cur_prio = priority[try_to_get_dtype(arg)] + if cur_prio > max_prio: + max_prio = cur_prio # align priority between arg max and target_dtype cur_prio = priority[target_dtype] @@ -333,6 +340,9 @@ def promote_dtype(self, *args, target_dtype): # 1. unable to get tensor dtype # 2. current_dtype != target_dtype result.append(self.get_proxy(ascend_op.Cast, (arg, ascend_dtype))) + elif try_to_get_dtype(arg) is not None: + # handle with scalar case + result.append(self.get_const_proxy(arg, target_dtype)) else: raise RuntimeError("Not implemented") return tuple(result) if len(result) > 1 else result[0] @@ -1579,7 +1589,6 @@ def RandLike(self, x, dtype=torch.float32, layout=torch.strided, @register_conversion(torch.ops.aten.gt.Scalar) def GtScalar(self, x, y): dtype = get_ascend_dtype(x.node.meta['val'].dtype) - scalar_op = self.get_const_proxy(float(y), torch.float32) x, scalar_op = self.promote_dtype(x, scalar_op, target_dtype=dtype) return self.get_proxy(ascend_op.Greater, (x, scalar_op)) @@ -1644,13 +1653,10 @@ def Repeat(self, x, repeats): @register_conversion([torch.ops.aten.ge.Scalar, torch.ops.aten.ge.Tensor]) def Ge(self, x, y): x_dtype = x.node.meta['val'].dtype - if not isinstance(y, torch.fx.proxy.Proxy): - y = self.get_const_proxy(y, x_dtype) - else: - if isinstance(y.node.meta['val'], torch.SymInt): - y = self.get_shape_proxy([y]) - y = self.get_proxy(ascend_op.Squeeze, (y, [0])) - x, y = self.promote_dtype(x, y, target_dtype=x_dtype) + if isinstance(y, torch.fx.proxy.Proxy) and isinstance(y.node.meta['val'], torch.SymInt): + y = self.get_shape_proxy([y]) + y = self.get_proxy(ascend_op.Squeeze, (y, [0])) + x, y = self.promote_dtype(x, y, target_dtype=x_dtype) return self.get_proxy(ascend_op.GreaterEqual, (x, y)) @register_conversion(torch.ops.aten.logical_or.default) From 69be14cd5dcda914bfda014ae46c84838ae30d6f Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 24 May 2024 06:18:34 +0000 Subject: [PATCH 39/42] Add const proxy logic for promote_dtype. 
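Const proxies in this converter carry their positional args laid out roughly as (value, dtype, shape), which is why the dtype can be read back from node.args[1], as the added assert relies on; a plain-tuple sketch of that layout with illustrative values:

    import torch

    const_args = ([2], torch.int32, [1])    # (value, dtype, shape)
    assert len(const_args) > 1
    assert const_args[1] is torch.int32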
--- dicp/dicp/vendor/AscendGraph/conversion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 9c8445c7d..9d0f47fb9 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -46,6 +46,10 @@ def try_to_get_dtype(x): if isinstance(x, torch.fx.proxy.Proxy): if hasattr(x.node, "meta") and "val" in x.node.meta.keys(): return x.node.meta['val'].dtype + elif 'const' in str(x): + # handle with const proxy dtype + assert len(x.node.args) > 1 + return x.node.args[1] else: return None From 052bf1e7f4a86766e87f8f1d9b29a2d39a38eb94 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 24 May 2024 08:47:55 +0000 Subject: [PATCH 40/42] Fix flash_attention declaration. --- dicp/dicp/vendor/AscendGraph/ext_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index d2dfcc6ed..baa3eefc9 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -88,12 +88,12 @@ def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim @torch._custom_op.impl.custom_op('lightllm::flash_attention_inference') -def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int) -> Tensor: +def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int) -> Tensor: ... @flash_attention_inference.impl_abstract() -def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int): +def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int): return torch.empty_like(q) From 06d0bcc0f18866a1ed2c70f6766f2115c2ba874c Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Fri, 24 May 2024 09:39:27 +0000 Subject: [PATCH 41/42] Remove Symint & Proxy from 7B static path. 
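On the static path every term of the end index is a plain Python int, so the slice bound can be folded at conversion time instead of emitting Const/Add nodes; a sketch with made-up sizes:

    i, max_len, current_len, kvhead, dim = 1, 128, 40, 32, 128   # illustrative values only
    kv_end_index = [i * max_len + current_len, kvhead, dim]
    print(kv_end_index)                                          # [168, 32, 128]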
--- dicp/dicp/vendor/AscendGraph/conversion.py | 14 ++++++++++---- dicp/dicp/vendor/AscendGraph/ext_ops.py | 4 ++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 9d0f47fb9..26cc994b0 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1827,10 +1827,16 @@ def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhe xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis)) kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32) - imax_const = self.get_const_proxy(i * max_len, torch.int32) - curlen_const = self.get_const_proxy(current_len, torch.int32) - end_proxy = self.get_proxy(ascend_op.Add, (curlen_const, imax_const)) - kv_end_index = self.get_shape_proxy([end_proxy, kvhead, dim], torch.int32) + + # split path for dynamic & static 7B + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + imax_const = self.get_const_proxy(i * max_len, torch.int32) + curlen_const = self.get_const_proxy(current_len, torch.int32) + end_proxy = self.get_proxy(ascend_op.Add, (curlen_const, imax_const)) + kv_end_index = self.get_shape_proxy([end_proxy, kvhead, dim], torch.int32) + else: + kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) + kv_seq_len = current_len kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim]) diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index baa3eefc9..d2dfcc6ed 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -88,12 +88,12 @@ def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim @torch._custom_op.impl.custom_op('lightllm::flash_attention_inference') -def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int) -> Tensor: +def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int) -> Tensor: ... @flash_attention_inference.impl_abstract() -def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int): +def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int, kvhead: int, head: int, dim: int): return torch.empty_like(q) From 5ed78ccca5800b01216f86456f3985724902fa60 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Mon, 27 May 2024 05:56:28 +0000 Subject: [PATCH 42/42] Change const judge method. --- dicp/dicp/vendor/AscendGraph/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 26cc994b0..ce593bf84 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -46,7 +46,7 @@ def try_to_get_dtype(x): if isinstance(x, torch.fx.proxy.Proxy): if hasattr(x.node, "meta") and "val" in x.node.meta.keys(): return x.node.meta['val'].dtype - elif 'const' in str(x): + elif isinstance(x.node.target, ascend_op.Const): # handle with const proxy dtype assert len(x.node.args) > 1 return x.node.args[1]
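The switch from the substring test to isinstance(x.node.target, ascend_op.Const) avoids false positives on node names that merely happen to contain "const"; a small standalone illustration with stand-in classes (not the real fx or ascend_op types):

    import torch

    class Const:                        # stand-in for ascend_op.Const
        def __init__(self, value, dtype, shape):
            self.args = (value, dtype, shape)

    class Reconstruct:                  # unrelated op whose repr still contains "const"
        pass

    node_a, node_b = Const([2], torch.int32, [1]), Reconstruct()

    # substring check: both hit, because "Reconstruct" contains "const"
    print('const' in str(node_a).lower(), 'const' in str(node_b).lower())   # True True
    # structural check: only the real Const matches
    print(isinstance(node_a, Const), isinstance(node_b, Const))             # True False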