[dicp][ascend] Optimization for dynamic shape code logic. #791

Merged
Changes from all commits
46 commits
39d38b0
Refine code structure of dynamic shape handling.
pdx1989 Apr 23, 2024
5a97b9b
Adjust symint_to_args relationship code logic.
pdx1989 Apr 24, 2024
e6f4e43
Remove redundant code.
pdx1989 Apr 24, 2024
d2333fc
Enable 70B get_qkv stage dynamic shape.
pdx1989 Apr 24, 2024
54167a9
Fix complex size append.
pdx1989 Apr 25, 2024
3dec435
Change load_and_run in/out shape assignment.
pdx1989 Apr 25, 2024
2666a3f
Refine variable replacement in in/out shape structure.
pdx1989 Apr 25, 2024
de1dc08
Merge branch 'daoxin/lightllm_dynamic_shape' into daoxin/dynamic_shap…
pdx1989 Apr 25, 2024
0389a64
Fix merge bugs.
pdx1989 Apr 25, 2024
cfee2df
Merge branch 'main' into daoxin/dynamic_shape_optimization
pdx1989 Apr 25, 2024
5a2fd6a
Change one comment and variable name.
pdx1989 Apr 25, 2024
03ba1a4
Fix an array assignment change.
pdx1989 Apr 25, 2024
71f6c61
Code refinement including:
pdx1989 Apr 26, 2024
25c5c56
Get clear idea for expand Cast situation.
pdx1989 Apr 26, 2024
9025019
Apply some idea from Gpt AI.
pdx1989 Apr 26, 2024
62a6b36
Revert "Apply some idea from Gpt AI."
pdx1989 Apr 26, 2024
2f6bd52
Remove dead use, replace const proxy.
pdx1989 Apr 28, 2024
5629a80
Merge branch 'daoxin/lightllm_dynamic_support' into daoxin/dynamic_sh…
pdx1989 Apr 30, 2024
9df01a5
Merge branch 'main' into daoxin/dynamic_shape_optimization
pdx1989 Apr 30, 2024
92521eb
Support 7B dynamic shape version.
pdx1989 May 10, 2024
8898541
Pass 1st dynamic graph model.
CyCle1024 May 11, 2024
8971b85
Pass both two graph model for 7B dynamic shape version.
pdx1989 May 14, 2024
08e73f5
Fix ci case incre_flash_attention.
pdx1989 May 14, 2024
53f728d
Change split execution path for both shape mode.
pdx1989 May 15, 2024
3e36e2e
Add execution path for copy_with_offset.
pdx1989 May 15, 2024
6c40ad4
Merge copy_with_offset shape path mode.
pdx1989 May 15, 2024
0ce746f
Add const proxy for int split_size.
pdx1989 May 16, 2024
e5c06ab
Move some common functions into util.
pdx1989 May 16, 2024
9f0472a
Add path for flash_attention to pass head, kvhead and dim in.
pdx1989 May 16, 2024
62957f7
Cancel path split for slice start proxy form.
pdx1989 May 16, 2024
709c009
Add sequenceAt & triu test unit case.
pdx1989 May 17, 2024
7b013bc
Return several code logic back to original design, and fix flash_incr…
pdx1989 May 17, 2024
2a1b6d9
Modify the logic of split implementation.
pdx1989 May 17, 2024
10b1603
Add split dynamic case.
pdx1989 May 17, 2024
d5689c0
Remove identity additional logic, wrap convert into promote_dtype.
pdx1989 May 17, 2024
4e499b9
Pass ge unit test.
pdx1989 May 23, 2024
3811319
Modify logic of lt dtype, and prompt_attention fp16 conversion.
pdx1989 May 23, 2024
a629335
Add promote_dtype priority logic.
pdx1989 May 23, 2024
9341488
Fix promote_dtype bug.
pdx1989 May 23, 2024
c84fbdd
Cast back fa to float32 if dtype not consistent.
pdx1989 May 23, 2024
151f4b2
Change to return q_dtype tensor.
pdx1989 May 23, 2024
bdecf8c
Improve promote_dtype logic.
pdx1989 May 24, 2024
69be14c
Add const proxy logic for promote_dtype.
pdx1989 May 24, 2024
052bf1e
Fix flash_attention declaration.
pdx1989 May 24, 2024
06d0bcc
Remove Symint & Proxy from 7B static path.
pdx1989 May 24, 2024
5ed78cc
Change const judge method.
pdx1989 May 27, 2024
56 changes: 56 additions & 0 deletions dicp/dicp/dynamo_bridge/utils.py
@@ -7,13 +7,69 @@
from torch.fx.node import Argument, Target


def not_all_num_shape(shape):
    for elem in shape:
        if not isinstance(elem, int):
            return True
    return False


def symint_in_shape(shape):
    for elem in shape:
        if isinstance(elem, torch.SymInt):
            return True
    return False


def neg_in_shape(shape):
    for elem in shape:
        if isinstance(elem, int):
            if elem < 0:
                return True
    return False


def process_sym_name(st):
    # Dynamic shape feature: return the symbolic expression as a string.
    # node.str() keeps the symbolic form of a torch.SymInt (e.g. "s0")
    # instead of falling back to its concrete value.
    if isinstance(st, torch.SymInt):
        return st.node.str()
    return str(st)


def preprocess_expression(expr):
    elem_str = process_sym_name(expr)
    elem_str = elem_str.replace('+', ' + ')
    elem_str = elem_str.replace('-', ' - ')
    elem_str = elem_str.replace('*', ' * ')
    elem_str = elem_str.replace('//', ' // ')
    elem_str = elem_str.replace('(', ' ( ')
    elem_str = elem_str.replace(')', ' ) ')
    elems = elem_str.split(' ')
    elems = [e for e in elems if e != '']
    return elems
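
# Illustrative example (hypothetical values): a symbolic dim expression such as
# "(s5//8) - (s5//16)" is split into operand/operator tokens:
#   preprocess_expression("(s5//8) - (s5//16)")
#   -> ['(', 's5', '//', '8', ')', '-', '(', 's5', '//', '16', ')']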


def find_root_num(set_num, num):
    while set_num[num] != num:
        num = set_num[num]
    return num


def merge_disjoint_set(set_num, idx_a, idx_b):
    root_a = find_root_num(set_num, idx_a)
    root_b = find_root_num(set_num, idx_b)
    # An example for (s5 / 8) - (s5 / 16):
    # num:               0 1 2 3
    # step1 -> set_num:  0 1 2 3
    # step2 -> set_num:  0 0 2 2
    # step3 -> set_num:  0 0 0 0

    # Return the merged set: every index rooted at root_b is relabeled to root_a.
    return [root_a if find_root_num(set_num, s) == root_b else s for s in set_num]
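
# Hypothetical usage sketch of the two union-find helpers above: indices that
# end up sharing a root are treated as the same symbolic dimension group.
#   set_num = list(range(4))                      # every index is its own root
#   set_num = merge_disjoint_set(set_num, 0, 1)   # -> [0, 0, 2, 3]
#   set_num = merge_disjoint_set(set_num, 2, 3)   # -> [0, 0, 2, 2]
#   set_num = merge_disjoint_set(set_num, 0, 2)   # -> [0, 0, 0, 0]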


def save_cpu_gm(gm: torch.fx.GraphModule, folder: str):
    Path(folder).mkdir(exist_ok=True)
    cpu_gm = copy_gm_to_cpu(gm)
36 changes: 35 additions & 1 deletion dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -217,6 +217,14 @@ def infer_result(self, x):
        return common_unary_op_infer(x)


class Triu(Operator):
    def __init__(self):
        super().__init__("Triu")

    def infer_result(self, x, diag):
        return common_unary_op_infer(x)


class Sqrt(Operator):
    def __init__(self):
        super().__init__("Sqrt")
@@ -567,6 +575,27 @@ def __init__(self):
super().__init__("CastToCpu")


class SequenceAt(Operator):
def __init__(self):
super().__init__("SequenceAt")

def infer_result(self, x, idx=None):
x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
if isinstance(x, (List, Tuple)):
return x[idx]
out_dtype = x_dtype
if x_dtype == torch.complex64: # for complex64
out_shape = list(x_shape)
if idx == 0 or idx == 1:
out_dtype = torch.float32
out_shape.append(1)
else:
out_shape = [x_shape[idx]] if idx is not None else list(x_shape)
return torch.empty(
out_shape, dtype=out_dtype, memory_format=get_memory_format(x)
)
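
    # Illustrative (hypothetical) shapes for infer_result above: for a complex64
    # input of shape [4, 8], idx 0 or 1 yields a float32 fake tensor of shape
    # [4, 8, 1] (real or imaginary part); for a float16 input of the same shape,
    # idx 1 yields a fake tensor of shape [8].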


class Identity(Operator):
    def __init__(self):
        super().__init__("Identity")
@@ -750,7 +779,6 @@ def __init__(self, x):
    def infer_result(self, x):
        return common_unary_op_infer(x)


class SplitD(Operator):
    def __init__(self):
        super().__init__("SplitD")
@@ -770,6 +798,12 @@ def infer_result(self, x, split_dim, num_split, y, from_view_complex=False):
            memory_format=get_memory_format(x),
        )

class SplitToSequence(Operator):
    def __init__(self):
        super().__init__("SplitToSequence")

    def infer_result(self, x, split_dim, split_size):
        return torch.split(x, split_size, split_dim)

class Slice(Operator):
    def __init__(self):
121 changes: 53 additions & 68 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -6,7 +6,7 @@
from torch.fx.node import Node
from torch.utils._pytree import tree_map_only
from torch._inductor.utils import IndentedBuffer
from dicp.dynamo_bridge.utils import symint_in_shape
from dicp.dynamo_bridge.utils import symint_in_shape, process_sym_name
from dicp.vendor.AscendGraph.codegen.utils import (
    get_ascend_dtype,
    get_cpp_dtype,
@@ -219,58 +219,11 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2):
        )
        return self.import_code.getvalue()

    def process_sym_name(self, st):
        # dynamic shape feature
        if st.isdigit():
            return st
        elif '+' in st:
            sp = st.split('+')
            if len(sp) > 2:
                sp = [sp[0], '+'.join(sp[1:])]
            assert (len(sp) == 2)
            sp = [elem.strip() for elem in sp]
            if sp[0].isdigit():
                (sp[1], sp[0]) = (sp[0], sp[1])
            if sp[0] in self.sym_in_args:
                arg, idx = self.sym_in_args[sp[0]]
                return "{}.shape[{}]".format(arg, idx) + '+' + sp[1]
            if sp[0] in self.sym_to_inputs.keys():
                return self.sym_to_inputs[sp[0]] + '+' + sp[1]
            else:
                return self.process_sym_name(sp[0]) + '+' + sp[1]
        elif '-' in st:
            sp = st.split('-')
            if len(sp) > 2:
                sp = [sp[0], '-'.join(sp[1:])]
            assert (len(sp) == 2)
            sp = [elem.strip() for elem in sp]
            if sp[0] in self.sym_in_args:
                arg, idx = self.sym_in_args[sp[0]]
                return "{}.shape[{}]".format(arg, idx) + '-' + sp[1]
            if sp[0] in self.sym_to_inputs.keys():
                return self.sym_to_inputs[sp[0]] + '-' + sp[1]
            else:
                return self.process_sym_name(sp[0]) + '-' + sp[1]
        elif '*' in st:
            sp = st.split('*')
            if len(sp) > 2:
                sp = [sp[0], '*'.join(sp[1:])]
            assert (len(sp) == 2)
            sp = [elem.strip() for elem in sp]
            if sp[0].isdigit():
                (sp[1], sp[0]) = (sp[0], sp[1])
            if sp[0] in self.sym_in_args:
                arg, idx = self.sym_in_args[sp[0]]
                return "{}.shape[{}]".format(arg, idx) + '*' + sp[1]
            if sp[0] in self.sym_to_inputs.keys():
                return self.sym_to_inputs[sp[0]] + '*' + sp[1]
            else:
                return self.process_sym_name(sp[0]) + '*' + sp[1]
        else:
            if st in self.sym_in_args:
                arg, idx = self.sym_in_args[st]
                return "{}.shape[{}]".format(arg, idx)
            return self.sym_to_inputs[st]

    def operator_in_str(self, st):
        for op in ['+', '-', '*', '/']:
            if op in st:
                return True
        return False

    def gen_call_func(self):
        # TODO check scalar input
@@ -283,16 +236,23 @@ def gen_call_func(self):
        args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args]
        call_body.writeline(f"({','.join(args)}) = args")

        # assign SymInt-to-input-arg relationships
        if len(self.sym_in_args) > 0:
            for key in self.sym_in_args.keys():
                if not key.isdigit() and not self.operator_in_str(key):
                    call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]")
        if len(self.sym_to_inputs) > 0:
            for key in self.sym_to_inputs.keys():
                if not key.isdigit() and not self.operator_in_str(key):
                    call_body.writeline(f"{key} = {self.sym_to_inputs[key]}")

        # generate input dims
        if len(self.dynamic_inputs) > 0:
            dim_len = 0
            for shape in self.actual_shape:
                dim_len += len(shape)
            dims = 'dims = {'
            for idx, elem in enumerate(self.actual_shape):
                if len(elem) == 0:
                    continue
                elem = [self.process_sym_name(dim) for dim in elem]
                elem = [process_sym_name(dim) for dim in elem]
                dims += str(self.dynamic_index[idx]) + \
                    ":[" + ','.join(map(str, elem)) + '],'
            dims = dims[:-1] + '}'
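
            # Illustrative (hypothetical) generated line, for dynamic inputs at graph
            # indices 0 and 2 whose actual shapes are [s0, 4096] and [s0, 32]:
            #     dims = {0:[s0,4096],2:[s0,32]}
            # The symbol names are valid in the generated code because assignments
            # such as "s0 = arg0_1.shape[0]" were emitted earlier.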
@@ -315,20 +275,20 @@
            shape = list(elem.shape)
            if len(shape) == 0:
                raise RuntimeError("Error handling empty output_shape")
            shape = [self.process_sym_name(str(dim)) for dim in shape]
            shape = [process_sym_name(dim) for dim in shape]
            shape_str += "[" + ','.join(map(str, shape)) + "],"

        # process output_shape with modified args
        for elem in self.assign_args:
            shape = list(self.input_args[elem[1]].meta['val'].shape)
            if len(shape) == 0:
                raise RuntimeError("Error handling empty output_shape")
            shape = [self.process_sym_name(str(dim)) for dim in shape]
            shape = [process_sym_name(dim) for dim in shape]
            shape_str += "[" + ','.join(map(str, shape)) + "],"
            stride = list(self.input_args[elem[1]].meta['val'].stride())
            if len(stride) == 0:
                raise RuntimeError("Error handling empty output_stride")
            stride = [self.process_sym_name(str(dim)) for dim in stride]
            stride = [process_sym_name(dim) for dim in stride]
            extra_stride_str += '[' + ','.join(map(str, stride)) + '],'
            extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ','
        shape_str = shape_str[:-1] + ''']'''
@@ -342,21 +302,23 @@
        for elem in self.output_args:
            if hasattr(elem, 'meta'):
                elem = elem.meta['val']
            if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool):
                out_strides.append('[1]')
                out_storage_offsets.append('0')
                continue
            if elem.dim() == 0:  # temporary solution for sum.default(a) whose result is a scalar(no dim no stride)

            # temporary solution for sum.default(a) whose result is scalar or with no dim no stride
            if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool) or elem.dim() == 0:
                out_strides.append('[1]')
                out_storage_offsets.append('0')
                continue
            stride = list(elem.stride())
            stride = [self.process_sym_name(str(dim)) for dim in stride]
            out_strides.append(str(stride))
            stride = [process_sym_name(dim) for dim in stride]
            out_strides.append('[' + ','.join(map(str, stride)) + ']')
            out_storage_offsets.append(elem.storage_offset())
        call_body.writeline(f'out_stride = {out_strides}')
        call_body.writeline(f'''out_stride = [{','.join(out_strides)}]''')
        call_body.writeline(f'out_storage_offset = {out_storage_offsets}')
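
        # Illustrative (hypothetical) generated lines for two outputs with strides
        # [4096, 1] and [s0, 1] and zero storage offsets:
        #     out_stride = [[4096,1],[s0,1]]
        #     out_storage_offset = [0, 0]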

        # In precision debug mode, build a "modified" list recording which
        # input args are plain integers.
        if precision_check and self.aten_graph is not None:
            call_body.writeline(f"modified = [idx for idx in range(len(args)) if isinstance(args[idx], int)]")

        call_body.splice("""
            import torch_dipu
            dipu_device_str = torch_dipu.dipu.device.__diputype__
@@ -827,6 +789,13 @@ def Rsqrt(name, x):
op.set_input("x", x)
return op.to_node()

@staticmethod
def Triu(name, x, diag):
op = OP(name, "Triu")
op.set_input("x", x)
op.set_attr_int("diagonal", diag)
return op.to_node()

    @staticmethod
    def Conv2D(name, input, weight, stride, padding,
               dilation, groups, format, bias):
@@ -915,6 +884,14 @@ def Identity(name, input, index=None):
op.set_input("x", input)
return op.to_node()

@staticmethod
def SequenceAt(name, input, index=None):
op = OP(name, "SequenceAt")
assert index is not None
op.set_input("handle", input)
op.set_input("index", index)
return op.to_node()

    @staticmethod
    def IdentityInp(name, input, dst=None):
        op = OP(name, "Identity")
@@ -1377,6 +1354,14 @@ def SplitD(name, x, dim, num_split, y, from_view_complex):
split_op.set_dynamic_output("y", y)
return split_op.to_node()

@staticmethod
def SplitToSequence(name, x, dim, split_size):
split_op = OP(name, "SplitToSequence")
split_op.set_input("x", x)
split_op.set_input("split", split_size)
split_op.set_attr_int("axis", dim)
return split_op.to_node()

    @staticmethod
    def Pack(name, x, axis):
        x_name = []