[dicp][ascend]Support ascend stable_diffusion. #623

Merged
50 changes: 50 additions & 0 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -133,6 +133,16 @@ def infer_result(self, x1, x2, adj_x1=False, adj_x2=False):
        )


class LayerNorm(Operator):
    def __init__(self):
        super().__init__("LayerNorm")


class GroupNorm(Operator):
    def __init__(self):
        super().__init__("GroupNorm")


class Sub(Operator):
    def __init__(self):
        super().__init__("Sub")
@@ -228,6 +238,11 @@ def infer_result(self, x, base=-1.0, scale=1.0, shift=0.0):
        return common_unary_op_infer(x)


class Gelu(Operator):
    def __init__(self):
        super().__init__("Gelu")


class Swish(Operator):
    def __init__(self):
        super().__init__("Swish")
@@ -237,6 +252,9 @@ class Transpose(Operator):
    def __init__(self):
        super().__init__("Transpose")

    def infer_result(self, x, axes=None):
        return common_unary_op_infer(x)
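
    # Note: this meta result keeps x's shape and dtype unchanged; the axes
    # permutation is not applied at this level.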


class SoftmaxV2(Operator):
    def __init__(self):
@@ -470,6 +488,11 @@ def infer_result(self, x1, x2):
        return common_binary_op_infer(x1, x2, torch.bool)


class ArgMax(Operator):
    def __init__(self):
        super().__init__("ArgMax")


class Equal(Operator):
    def __init__(self):
        super().__init__("Equal")
@@ -574,6 +597,23 @@ def infer_result(
        )


class GatherNd(Operator):
    def __init__(self):
        super().__init__("GatherNd")

    def infer_result(self, x, index, orig_index):
        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
        idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index)
        idx_shape = list(idx_shape)

        # Assume there are no None entries in orig_index; the leading
        # len(orig_index) dims of x_shape are replaced by the broadcast index shape.
        len_idx_shape = len(orig_index)
        assert len_idx_shape > 0
        bcast_index_shape = list(orig_index[0].shape)
        x_shape = bcast_index_shape + list(x_shape[len_idx_shape:])
        return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
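
    # Shape example (illustration only, assuming the index tensors in orig_index
    # have already been broadcast to a common shape): for x of shape (4, 5, 6)
    # and orig_index = [i, j] with i and j of shape (3,), the leading
    # len(orig_index) = 2 dims of x are replaced by that index shape, giving an
    # inferred output of shape (3, 6).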


class GatherV2(Operator):
    def __init__(self):
        super().__init__("GatherV2")
@@ -725,6 +765,16 @@ def infer_result(self, x, offset, size):
        return res


class Cos(Operator):
    def __init__(self):
        super().__init__("Cos")


class Sin(Operator):
    def __init__(self):
        super().__init__("Sin")


class ConcatD(Operator):
    def __init__(self):
        super().__init__("ConcatD")
63 changes: 60 additions & 3 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -697,6 +697,27 @@ def gen_args(op_var, args_dict, args):
                args_str.append(args[i])
        return src_code, args_str

    @staticmethod
    def LayerNorm(name, x, begin_dim, weight, bias, eps):
        op = OP(name, "LayerNorm")
        op.set_input("x", x)
        op.set_input("gamma", weight)
        op.set_input("beta", bias)
        op.set_attr_int("begin_norm_axis", begin_dim)
        op.set_attr_int("begin_params_axis", begin_dim)
        op.set_attr_float("epsilon", eps)
        return op.to_node()
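
    # Mapping note (an assumption about the intended lowering, not verified
    # here): setting begin_norm_axis == begin_params_axis == begin_dim makes both
    # the normalization and the gamma/beta parameters cover dims [begin_dim:],
    # which matches torch layer_norm normalizing over the trailing
    # len(normalized_shape) dims, i.e. begin_dim = x.dim() - len(normalized_shape).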

    @staticmethod
    def GroupNorm(name, x, weight, bias, N, C, HxW, group, eps):
        op = OP(name, "GroupNorm")
        op.set_input("x", x)
        op.set_input("gamma", weight)
        op.set_input("beta", bias)
        op.set_attr_int("num_groups", group)
        op.set_attr_float("eps", eps)
        return op.to_node()
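
    # Note: N, C and HxW appear to mirror the aten.native_group_norm argument
    # list but are not forwarded here; only num_groups and eps are set on the
    # Ascend GroupNorm op, with x/gamma/beta wired as inputs.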

    @staticmethod
    def Mul(name, x, y):
        op = OP(name, "Mul")
@@ -755,6 +776,12 @@ def Relu(name, x):
op.set_input("x", x)
return op.to_node()

@staticmethod
def Gelu(name, x):
op = OP(name, "Gelu")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Swish(name, x, scale):
silu_op = OP(name, "Swish")
@@ -1176,6 +1203,13 @@ def Less(name, x1, x2):
cond_op.set_input("x2", x2)
return cond_op.to_node()

@staticmethod
def ArgMax(name, x, dim):
cond_op = OP(name, "ArgMaxV2")
cond_op.set_input("x", x)
cond_op.set_input("dimension", dim)
return cond_op.to_node()
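
    # Note: dim is wired as the "dimension" input of the Ascend ArgMaxV2 op
    # rather than being set as an attribute.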

    @staticmethod
    def ret_tuple(name, in1, in2):
        op = OP(name, "IdentityN")
@@ -1340,11 +1374,15 @@ def SplitD(name, x, dim, num_split, y, from_view_complex):

    @staticmethod
    def Pack(name, x, axis):
-        x = [elem.name for elem in x]
+        x_name = []
+        for elem in x:
+            if elem is not None:
+                x_name.append(elem.name)

        op = OP(name, "Pack")
-        op.set_dynamic_input("x", len(x), x)
+        op.set_dynamic_input("x", len(x_name), x_name)
        op.set_attr_int("axis", axis)
-        op.set_attr_int("N", len(x))
+        op.set_attr_int("N", len(x_name))
        return op.to_node()
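
    # The None filtering above is presumably there because the incoming element
    # list can contain None placeholders (e.g. optional index entries); skipping
    # them keeps the dynamic-input count and the "N" attribute in sync with the
    # inputs that are actually wired into Pack.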

    @staticmethod
@@ -1366,6 +1404,18 @@ def ConcatD(name, x, dim):
op.set_attr_int("concat_dim", dim)
return op.to_node()

@staticmethod
def Cos(name, x):
op = OP(name, "Cos")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Sin(name, x):
op = OP(name, "Sin")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Reshape(name, x, shape, ori_op=None, params_passed=None):
op = OP(name, "Reshape")
@@ -1381,6 +1431,13 @@ def GatherV2(name, x, indices, axis):
gather_op.set_input("axis", axis)
return gather_op.to_node()

@staticmethod
def GatherNd(name, x, indices, orig_indices):
gather_op = OP(name, "GatherNd")
gather_op.set_input("x", x)
gather_op.set_input("indices", indices)
return gather_op.to_node()
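
    # Note: orig_indices is not wired to the Ascend op here; it appears to be
    # used only for shape inference in ascend_op.GatherNd.infer_result.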

    @staticmethod
    def Pad(name, x, paddings):
        pad_op = OP(name, "Pad")