[dicp][ascend] Support ascend stable_diffusion. #623

Merged

Changes from 13 commits
50 changes: 50 additions & 0 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -133,6 +133,16 @@ def infer_result(self, x1, x2, adj_x1=False, adj_x2=False):
)


class LayerNorm(Operator):
def __init__(self):
super().__init__("LayerNorm")


class GroupNorm(Operator):
def __init__(self):
super().__init__("GroupNorm")


class Sub(Operator):
def __init__(self):
super().__init__("Sub")
@@ -228,6 +238,11 @@ def infer_result(self, x, base=-1.0, scale=1.0, shift=0.0):
return common_unary_op_infer(x)


class Gelu(Operator):
def __init__(self):
super().__init__("Gelu")


class Swish(Operator):
def __init__(self):
super().__init__("Swish")
@@ -237,6 +252,9 @@ class Transpose(Operator):
def __init__(self):
super().__init__("Transpose")

def infer_result(self, x, axes=None):
return common_unary_op_infer(x)


class SoftmaxV2(Operator):
def __init__(self):
@@ -470,6 +488,11 @@ def infer_result(self, x1, x2):
return common_binary_op_infer(x1, x2, torch.bool)


class ArgMax(Operator):
def __init__(self):
super().__init__("ArgMax")


class Equal(Operator):
def __init__(self):
super().__init__("Equal")
@@ -574,6 +597,23 @@ def infer_result(
)


class GatherNd(Operator):
def __init__(self):
super().__init__("GatherNd")

def infer_result(self, x, index, orig_index):
x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index)
idx_shape = list(idx_shape)

# All index tensors are assumed to be non-None; their (broadcast) shape replaces the leading len(orig_index) dims of x_shape
len_idx_shape = len(orig_index)
assert len_idx_shape > 0
bcast_index_shape = list(orig_index[0].shape)
x_shape = bcast_index_shape + list(x_shape[len_idx_shape:])
return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))


class GatherV2(Operator):
def __init__(self):
super().__init__("GatherV2")
@@ -725,6 +765,16 @@ def infer_result(self, x, offset, size):
return res


class Cos(Operator):
def __init__(self):
super().__init__("Cos")


class Sin(Operator):
def __init__(self):
super().__init__("Sin")


class ConcatD(Operator):
def __init__(self):
super().__init__("ConcatD")
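Note on the new GatherNd.infer_result above: the rule it encodes is that the broadcast shape of the index tensors replaces the first len(orig_index) dimensions of x (infer_result reads orig_index[0].shape, which matches the broadcast shape because the conversion inserts BroadcastTo ops when the index shapes differ). A minimal standalone check of that rule against plain PyTorch advanced indexing, with made-up shapes for illustration only:

import torch

x = torch.randn(4, 5, 6)
orig_index = [torch.tensor([0, 1, 3]), torch.tensor([2, 2, 4])]   # two index tensors

# broadcast shape of the index tensors replaces the first len(orig_index) dims of x
bcast_index_shape = list(torch.broadcast_shapes(*(i.shape for i in orig_index)))
expected = bcast_index_shape + list(x.shape[len(orig_index):])    # [3, 6]

assert list(x[orig_index[0], orig_index[1]].shape) == expected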
63 changes: 60 additions & 3 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -697,6 +697,27 @@ def gen_args(op_var, args_dict, args):
args_str.append(args[i])
return src_code, args_str

@staticmethod
def LayerNorm(name, x, begin_dim, weight, bias, eps):
op = OP(name, "LayerNorm")
op.set_input("x", x)
op.set_input("gamma", weight)
op.set_input("beta", bias)
op.set_attr_int("begin_norm_axis", begin_dim)
op.set_attr_int("begin_params_axis", begin_dim)
op.set_attr_float("epsilon", eps)
return op.to_node()

@staticmethod
def GroupNorm(name, x, weight, bias, N, C, HxW, group, eps):
op = OP(name, "GroupNorm")
op.set_input("x", x)
op.set_input("gamma", weight)
op.set_input("beta", bias)
op.set_attr_int("num_groups", group)
op.set_attr_float("eps", eps)
return op.to_node()

@staticmethod
def Mul(name, x, y):
op = OP(name, "Mul")
@@ -755,6 +776,12 @@ def Relu(name, x):
op.set_input("x", x)
return op.to_node()

@staticmethod
def Gelu(name, x):
op = OP(name, "Gelu")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Swish(name, x, scale):
silu_op = OP(name, "Swish")
@@ -1176,6 +1203,13 @@ def Less(name, x1, x2):
cond_op.set_input("x2", x2)
return cond_op.to_node()

@staticmethod
def ArgMax(name, x, dim):
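# aten.argmax is lowered to ArgMaxV2; the reduction dim arrives as a Const tensor input (see argmax in conversion.py)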
cond_op = OP(name, "ArgMaxV2")
cond_op.set_input("x", x)
cond_op.set_input("dimension", dim)
return cond_op.to_node()

@staticmethod
def ret_tuple(name, in1, in2):
op = OP(name, "IdentityN")
@@ -1340,11 +1374,15 @@ def SplitD(name, x, dim, num_split, y, from_view_complex):

@staticmethod
def Pack(name, x, axis):
x = [elem.name for elem in x]
x_name = []
for elem in x:
if elem is not None:
x_name.append(elem.name)

op = OP(name, "Pack")
op.set_dynamic_input("x", len(x), x)
op.set_dynamic_input("x", len(x_name), x_name)
op.set_attr_int("axis", axis)
op.set_attr_int("N", len(x))
op.set_attr_int("N", len(x_name))
return op.to_node()

@staticmethod
@@ -1366,6 +1404,18 @@ def ConcatD(name, x, dim):
op.set_attr_int("concat_dim", dim)
return op.to_node()

@staticmethod
def Cos(name, x):
op = OP(name, "Cos")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Sin(name, x):
op = OP(name, "Sin")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Reshape(name, x, shape, ori_op=None, params_passed=None):
op = OP(name, "Reshape")
@@ -1381,6 +1431,13 @@ def GatherV2(name, x, indices, axis):
gather_op.set_input("axis", axis)
return gather_op.to_node()

@staticmethod
def GatherNd(name, x, indices, orig_indices):
gather_op = OP(name, "GatherNd")
gather_op.set_input("x", x)
gather_op.set_input("indices", indices)
return gather_op.to_node()

@staticmethod
def Pad(name, x, paddings):
pad_op = OP(name, "Pad")
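For reference, the Pack + GatherNd pair emitted for aten.index / aten._unsafe_index behaves like stacking the non-None index tensors along a new last axis and selecting one slice of x per row of that stack. A small sketch, using made-up tensors, that checks this reading against torch advanced indexing:

import torch

x = torch.randn(4, 5, 6)
index = [torch.tensor([0, 1, 3]), torch.tensor([2, 2, 4])]

indices = torch.stack([i for i in index if i is not None], dim=-1)   # Pack with axis=-1, None entries skipped
gathered = torch.stack([x[tuple(row)] for row in indices])           # GatherNd: one slice of x per row of indices
assert torch.equal(gathered, x[index[0], index[1]])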
143 changes: 142 additions & 1 deletion dicp/dicp/vendor/AscendGraph/conversion.py
@@ -10,6 +10,7 @@
)
import numpy as np
import torch.fx.traceback as fx_traceback
from torch.fx.immutable_collections import immutable_list
from torch._subclasses import FakeTensor
import dicp.vendor.AscendGraph.ascend_op as ascend_op
from dicp.vendor.AscendGraph.codegen.utils import (
@@ -176,7 +177,7 @@ def promote_dtype(self, *args, target_dtype):
def mul_scalar(self, x, y):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
# Muls support bfloat16, int32, int16, float16, float32, complex32, complex64.
if out_dtype not in [torch.float, torch.float16, torch.int32]:
if out_dtype not in [torch.float, torch.float32, torch.float16, torch.int32]:
y_op = self.get_const_proxy(y, out_dtype)
return self.get_proxy(ascend_op.Mul, (x, y_op))
return self.get_proxy(ascend_op.Muls, (x, y))
@@ -264,6 +265,11 @@ def le(self, a, b):
a, b = self.binary_cmp_cast_input(a, b)
return self.get_proxy(ascend_op.LessEqual, (a, b), {})

@register_conversion(aten.argmax.default)
def argmax(self, x, dim):
dim = self.get_proxy(ascend_op.Const, ([dim], torch.int32, []))
return self.get_proxy(ascend_op.ArgMax, (x, dim))

@register_conversion(aten.view_as_real)
def view_as_real(self, x):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
@@ -307,6 +313,16 @@ def div(self, x, y):
y_op = self.get_const_proxy(y, out_dtype)
return self.get_proxy(ascend_op.Div, (x, y_op), {})

@register_conversion(aten.split.Tensor)
def split(self, x, split_size, dim=0):
splitD_kw = { "from_view_complex": False }
shape = list(x.node.meta['val'].shape)
if dim < 0:
dim += len(shape)
assert shape[dim] > 0
num_split = int(shape[dim] / split_size)
return self.get_proxy(ascend_op.SplitD, (x, dim, num_split, num_split), splitD_kw)

@register_conversion(aten.slice.Tensor)
def slice(self, x, dim=0, start=None, end=None, step=1):
# TODO(tangzhiyi): miss step parameter
@@ -542,6 +558,14 @@ def nll_loss_backward(self, grad_output, x, target, weight, reduction, ignore_in
weight, total_weight,
reduction_str, ignore_index))

@register_conversion(torch.ops.aten.sin.default)
def sin(self, x):
return self.get_proxy(ascend_op.Sin, (x,))

@register_conversion(torch.ops.aten.cos.default)
def cos(self, x):
return self.get_proxy(ascend_op.Cos, (x,))

@register_conversion(torch.ops.aten.cat.default)
def cat(self, x, dim=0):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
@@ -624,12 +648,125 @@ def index_base(self, x, dim, index):

@register_conversion(torch.ops.aten.index.Tensor)
def index(self, x, index):
if isinstance(index, list):
return self.unsafe_index(x, index)
return self.index_base(x, 0, index)

@register_conversion(torch.ops.aten._unsafe_index.Tensor)
def unsafe_index(self, x, index):
if isinstance(index, list):
if len(index) == 1:
index = index[0]
return self.index_base(x, 0, index)
else:
bcast_shape = []
first_not_none = 0
not_none_len = 0
status = 0

# compute first_not_none and not_none_len: the expected pattern is leading None entries (possibly none), a contiguous run of index tensors, then trailing None entries
for i, elem in enumerate(index):
if status == 0:
if elem is None:
status = 1
else:
break
elif status == 1:
if elem is not None:
status = 2
first_not_none = i
elif status == 2:
if elem is not None:
not_none_len = i - first_not_none + 1
else:
status = 3
elif status == 3:
# index tensors separated by interleaved None entries are not supported yet
assert elem is None
index_tmp = [e for e in index]

# insert transpose op
if status > 0:
x_shape = list(x.node.meta['val'].shape)
perm = [num for num in range(len(x_shape))]
for i in range(not_none_len):
index_tmp[i] = index_tmp[first_not_none + i]
index_tmp[first_not_none + i] = None
perm[i] = first_not_none + i
perm[first_not_none + i] = i
perm = self.get_proxy(ascend_op.Const, (perm, torch.int32, [len(perm)]))
x = self.get_proxy(ascend_op.Transpose, (x, perm))

# get broadcast shape
bcast_flag = False
for elem in index_tmp:
if elem is not None:
shape = list(elem.node.meta['val'].shape)
bcast_shape.append(shape)
bcast_shape = list(torch.broadcast_shapes(*bcast_shape))

for elem in index_tmp:
if elem is not None:
shape = list(elem.node.meta['val'].shape)
if self.shape_prod(shape) != self.shape_prod(bcast_shape) or len(shape) != len(bcast_shape):
bcast_flag = True

# insert broadcast op
if bcast_flag:
bcast_shape = self.get_proxy(ascend_op.Const, (bcast_shape, torch.int32, [len(bcast_shape)]))
for i, elem in enumerate(index_tmp):
if elem is not None:
index_tmp[i] = self.get_proxy(ascend_op.BroadcastTo, (elem, bcast_shape))

# core gather calc
if status > 0:
index_tmp = index_tmp[:not_none_len]
index = immutable_list(index_tmp)
indices = self.get_proxy(ascend_op.Pack, (index, -1))
gather = self.get_proxy(ascend_op.GatherNd, (x, indices, index_tmp))
if status > 0:
return self.get_proxy(ascend_op.Transpose, (gather, perm))
return gather
return self.index_base(x, 0, index)

@register_conversion(torch.ops.aten.index_select.default)
def index_arg3_(self, x, dim, index):
return self.index_base(x, dim, index)

@register_conversion(torch.ops.aten.native_layer_norm.default)
def native_layer_norm(self, x, shape, weight, bias, eps):
input_shape = x.node.meta['val'].shape
input_ndim = len(input_shape)
normalized_ndim = len(shape)
axis = input_ndim - normalized_ndim
M = 1
for idx in range(axis):
M *= input_shape[idx]
N = 1
for idx in range(axis, input_ndim):
N *= input_shape[idx]

weight_numel = weight.node.meta['val'].numel()
bias_numel = bias.node.meta['val'].numel()
assert weight_numel == N and bias_numel == N
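# begin_dim below is the first normalized dimension: the loop finds where the product of leading dims reaches M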

numels = 1
begin_dim = 0
for idx in range(len(input_shape)):
numels *= input_shape[idx]
if numels == M:
begin_dim = idx + 1
weight_dims = list(input_shape[idx + 1:])
break
weight_dims = self.get_shape_proxy(weight_dims)
weight = self.get_proxy(ascend_op.Reshape, (weight, weight_dims))

return self.get_proxy(ascend_op.LayerNorm, (x, begin_dim, weight, bias, eps))

@register_conversion(torch.ops.aten.native_group_norm.default)
def native_group_norm(self, x, weight, bias, N, C, HxW, group, eps):
return self.get_proxy(ascend_op.GroupNorm, (x, weight, bias, N, C, HxW, group, eps))

@register_conversion(torch.ops.aten._native_batch_norm_legit_functional.default)
def _native_batch_norm_legit_functional(self, x, weight, bias, running_mean, running_var,
train, momentum, eps):
@@ -1041,6 +1178,10 @@ def neg(self, a):
def relu(self, a):
return self.get_proxy(ascend_op.Relu, (a,))

@register_conversion(torch.ops.aten.gelu)
def gelu(self, a):
return self.get_proxy(ascend_op.Gelu, (a,))

@register_conversion(torch.ops.aten.silu)
def silu(self, a):
return self.get_proxy(ascend_op.Swish, (a, 1.0))
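The leading-None path of the new unsafe_index conversion (status > 0) moves the indexed dimensions to the front with a Transpose, gathers, and then transposes the result back. A quick standalone sanity check of the simplest such case (a single leading None, where the dim swap undoes itself) against plain tensor indexing, with illustrative shapes only:

import torch

x = torch.randn(4, 5, 6)
idx = torch.tensor([0, 2, 3])
index = [None, idx]                           # one leading None, as aten._unsafe_index can produce

first_not_none = 1
x_t = x.transpose(0, first_not_none)          # move the indexed dim to the front
out = x_t[idx].transpose(0, first_not_none)   # gather, then undo the permutation
assert torch.equal(out, x[:, idx])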