[dicp][ascend]Support ascend stable_diffusion. (#623)
* Support ascend stable_diffusion.

* Comment a replacement pattern.

* Fix unit test.

* Further fix config.

* Change stable diffusion model dir config.

* Add setup module for ci process.

* Sync with tops unit test for stable_diffusion.

* Fix stable_diffusion model dir assignment.

* Fix bug brought by new changes.

* Fix review comments.

* Add comments for the unsafe_index conversion to make the state machine more readable.

* Change split number calc logic.

* Move llama hf transformers import after torch_dipu.
pdx1989 authored Jan 22, 2024
1 parent 8b1add6 commit cae8d42
Showing 7 changed files with 269 additions and 18 deletions.
50 changes: 50 additions & 0 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -133,6 +133,16 @@ def infer_result(self, x1, x2, adj_x1=False, adj_x2=False):
)


class LayerNorm(Operator):
def __init__(self):
super().__init__("LayerNorm")


class GroupNorm(Operator):
def __init__(self):
super().__init__("GroupNorm")


class Sub(Operator):
def __init__(self):
super().__init__("Sub")
@@ -228,6 +238,11 @@ def infer_result(self, x, base=-1.0, scale=1.0, shift=0.0):
return common_unary_op_infer(x)


class Gelu(Operator):
def __init__(self):
super().__init__("Gelu")


class Swish(Operator):
def __init__(self):
super().__init__("Swish")
@@ -237,6 +252,9 @@ class Transpose(Operator):
def __init__(self):
super().__init__("Transpose")

def infer_result(self, x, axes=None):
return common_unary_op_infer(x)


class SoftmaxV2(Operator):
def __init__(self):
@@ -470,6 +488,11 @@ def infer_result(self, x1, x2):
return common_binary_op_infer(x1, x2, torch.bool)


class ArgMax(Operator):
def __init__(self):
super().__init__("ArgMax")


class Equal(Operator):
def __init__(self):
super().__init__("Equal")
@@ -574,6 +597,23 @@ def infer_result(
)


class GatherNd(Operator):
def __init__(self):
super().__init__("GatherNd")

def infer_result(self, x, index, orig_index):
x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index)
idx_shape = list(idx_shape)

# Assume there are no None entries in orig_index; the broadcast index shape
# replaces the leading len(orig_index) dims of x_shape.
len_idx_shape = len(orig_index)
assert len_idx_shape > 0
bcast_index_shape = list(orig_index[0].shape)
x_shape = bcast_index_shape + list(x_shape[len_idx_shape:])
return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
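
A minimal standalone sketch of the shape rule used above, with illustrative shapes that are not taken from this commit: the inferred output shape is the broadcast index shape followed by the dims of x that the indices do not cover.

import torch

x = torch.empty(4, 5, 6)
# two index tensors already broadcast to a common [2, 3] shape
orig_index = [torch.empty(2, 3, dtype=torch.int64),
              torch.empty(2, 3, dtype=torch.int64)]
out_shape = list(orig_index[0].shape) + list(x.shape[len(orig_index):])
assert out_shape == [2, 3, 6]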


class GatherV2(Operator):
def __init__(self):
super().__init__("GatherV2")
@@ -725,6 +765,16 @@ def infer_result(self, x, offset, size):
return res


class Cos(Operator):
def __init__(self):
super().__init__("Cos")


class Sin(Operator):
def __init__(self):
super().__init__("Sin")


class ConcatD(Operator):
def __init__(self):
super().__init__("ConcatD")
63 changes: 60 additions & 3 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -697,6 +697,27 @@ def gen_args(op_var, args_dict, args):
args_str.append(args[i])
return src_code, args_str

@staticmethod
def LayerNorm(name, x, begin_dim, weight, bias, eps):
op = OP(name, "LayerNorm")
op.set_input("x", x)
op.set_input("gamma", weight)
op.set_input("beta", bias)
op.set_attr_int("begin_norm_axis", begin_dim)
op.set_attr_int("begin_params_axis", begin_dim)
op.set_attr_float("epsilon", eps)
return op.to_node()

@staticmethod
def GroupNorm(name, x, weight, bias, N, C, HxW, group, eps):
op = OP(name, "GroupNorm")
op.set_input("x", x)
op.set_input("gamma", weight)
op.set_input("beta", bias)
op.set_attr_int("num_groups", group)
op.set_attr_float("eps", eps)
return op.to_node()
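
A minimal standalone sketch of the begin_norm_axis arithmetic assumed for the LayerNorm emitter above (the shapes and the rank-minus-normalized-dims rule are an assumption about the conversion, not taken from this commit): for an input of rank R normalized over its last K dims, begin_norm_axis and begin_params_axis are both R - K.

# e.g. a [B, S, H] activation normalized over the hidden dim only
input_shape = [2, 77, 768]
normalized_shape = [768]
begin_dim = len(input_shape) - len(normalized_shape)
assert begin_dim == 2  # passed as both begin_norm_axis and begin_params_axis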

@staticmethod
def Mul(name, x, y):
op = OP(name, "Mul")
@@ -755,6 +776,12 @@ def Relu(name, x):
op.set_input("x", x)
return op.to_node()

@staticmethod
def Gelu(name, x):
op = OP(name, "Gelu")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Swish(name, x, scale):
silu_op = OP(name, "Swish")
@@ -1176,6 +1203,13 @@ def Less(name, x1, x2):
cond_op.set_input("x2", x2)
return cond_op.to_node()

@staticmethod
def ArgMax(name, x, dim):
cond_op = OP(name, "ArgMaxV2")
cond_op.set_input("x", x)
cond_op.set_input("dimension", dim)
return cond_op.to_node()

@staticmethod
def ret_tuple(name, in1, in2):
op = OP(name, "IdentityN")
@@ -1340,11 +1374,15 @@ def SplitD(name, x, dim, num_split, y, from_view_complex):

@staticmethod
def Pack(name, x, axis):
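# Collect the names of the non-None inputs; None entries (presumably
# placeholders for absent optional inputs) are skipped.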
x = [elem.name for elem in x]
x_name = []
for elem in x:
if elem is not None:
x_name.append(elem.name)

op = OP(name, "Pack")
op.set_dynamic_input("x", len(x), x)
op.set_dynamic_input("x", len(x_name), x_name)
op.set_attr_int("axis", axis)
op.set_attr_int("N", len(x))
op.set_attr_int("N", len(x_name))
return op.to_node()

@staticmethod
@@ -1366,6 +1404,18 @@ def ConcatD(name, x, dim):
op.set_attr_int("concat_dim", dim)
return op.to_node()

@staticmethod
def Cos(name, x):
op = OP(name, "Cos")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Sin(name, x):
op = OP(name, "Sin")
op.set_input("x", x)
return op.to_node()

@staticmethod
def Reshape(name, x, shape, ori_op=None, params_passed=None):
op = OP(name, "Reshape")
@@ -1381,6 +1431,13 @@ def GatherV2(name, x, indices, axis):
gather_op.set_input("axis", axis)
return gather_op.to_node()

@staticmethod
def GatherNd(name, x, indices, orig_indices):
gather_op = OP(name, "GatherNd")
gather_op.set_input("x", x)
gather_op.set_input("indices", indices)
return gather_op.to_node()

@staticmethod
def Pad(name, x, paddings):
pad_op = OP(name, "Pad")