[dicp]Support stable-diffusion in tops #568

Merged
89 changes: 11 additions & 78 deletions dicp/dicp/vendor/TopsGraph/conversion.py
@@ -239,8 +239,8 @@ def Index(self, *args, **kwargs):
start_indices = []
for index in new_indices:
in_shape = index.node.meta["val"].shape
offset = len(broadcast_shape) - len(in_shape)
broadcast_dims = [i + offset for i in range(len(in_shape))]
rank_diff = len(broadcast_shape) - len(in_shape)
broadcast_dims = [i + rank_diff for i in range(len(in_shape))]
start_indices.append(self.get_proxy(tops_op.Expand, (index, tuple(broadcast_shape), broadcast_dims)))
start_indices = self.get_proxy(tops_op.Stack, (start_indices, -1))
return self.get_proxy(tops_op.XlaGather, (operand, start_indices, offset_dims, collapsed_slice_dims,
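
For reference, a minimal sketch (hypothetical shapes, not taken from the PR) of what the renamed rank_diff expresses: each dimension of an index tensor is mapped onto the trailing dimensions of the broadcast shape before the Expand.

broadcast_shape = (4, 3, 5)      # target shape the index is expanded to
in_shape = (3, 5)                # shape of one index tensor
rank_diff = len(broadcast_shape) - len(in_shape)                 # 1
broadcast_dims = [i + rank_diff for i in range(len(in_shape))]   # [1, 2]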
@@ -308,7 +308,7 @@ def ReduceMean(self, a, dim=None, keepdim=False, **kwargs):
in_shape = a.node.meta["val"].shape
if dim is None:
dim = list(range(len(in_shape)))
return self.get_proxy(tops_op.ReduceMean, (a, dim))
return self.get_proxy(tops_op.ReduceMean, (a, dim, keepdim))
dim = [(item + len(in_shape)) if item < 0 else item for item in dim]
return self.get_proxy(tops_op.ReduceMean, (a, dim, keepdim))
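
The fix above threads keepdim through the dim=None branch as well. A small sketch of the shape behavior the lowering should reproduce (assumed example, plain PyTorch):

import torch
x = torch.randn(2, 3, 4)
dim = list(range(x.dim()))                            # dim=None means: reduce every axis
print(torch.mean(x, dim=dim, keepdim=True).shape)     # torch.Size([1, 1, 1])
print(torch.mean(x, dim=dim, keepdim=False).shape)    # torch.Size([])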

@@ -371,8 +371,8 @@ def Adaptive_avg_pool2d(self, *args, **kwargs):
def Adaptive_avg_pool2d_backward(self, grad_output, inputs):
out_shape = fx_traceback.get_current_meta()["val"].shape
grad_output_shape = grad_output.node.meta["val"].shape
offset = len(out_shape) - len(grad_output_shape)
broadcast_dims = [i + offset for i in range(len(grad_output_shape))]
rank_diff = len(out_shape) - len(grad_output_shape)
broadcast_dims = [i + rank_diff for i in range(len(grad_output_shape))]
expand = self.get_proxy(tops_op.Expand, (grad_output, out_shape, broadcast_dims))
value = out_shape[2] * out_shape[3]
scalar = self.get_proxy(tops_op.Scalar, (value, ))
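
The visible lines suggest the backward pass broadcasts grad_output back to the input shape and scales it by the pooled window size H * W. A rough sanity check of that math in eager PyTorch (assumed (1, 1) output case):

import torch
x = torch.randn(1, 2, 4, 4, requires_grad=True)
y = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
y.backward(torch.ones_like(y))
# every input position receives grad_output / (H * W)
print(torch.allclose(x.grad, torch.full_like(x, 1.0 / (4 * 4))))   # True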
@@ -446,8 +446,8 @@ def NewEmptyStrided(self, *args, **kwargs):
def Expand(self, *args, **kwargs):
in_shape = args[0].node.meta["val"].shape
out_shape = fx_traceback.get_current_meta()["val"].shape
offset = len(out_shape) - len(in_shape)
broadcast_dims = [i + offset for i in range(len(in_shape))]
rank_diff = len(out_shape) - len(in_shape)
broadcast_dims = [i + rank_diff for i in range(len(in_shape))]
return self.get_proxy(tops_op.Expand, (*args, broadcast_dims), kwargs)

@register_conversion(aten.stack)
@@ -602,84 +602,17 @@ def VarMean(self, x, dims, *args, correction=1, keepdim=False):
div1 = self.get_proxy(tops_op.Div, (sum_dim, samples))
return self.get_proxy(tops_op.MakeTuple, (div1, mean1))

@register_conversion(aten.addmm)
def Addmm(self, x, mat1, mat2):
dot = self.get_proxy(tops_op.Dot, (mat1, mat2))
return self.get_proxy(tops_op.Add, (x, dot))

# Patterns
tops_patterns = PatternMatcherPass()
aten_patterns_cls_list = []
register_aten_patterns = functools.partial(
register_backend_patterns, aten_patterns_cls_list)
tops_patterns_cls_list = []
register_tops_patterns = functools.partial(
register_backend_patterns, tops_patterns_cls_list)


@register_aten_patterns
class ReplacePatternAddmm(BackendPatternBase):
@staticmethod
def pattern(a, b, c):
return torch.ops.aten.addmm.default(a, b, c)

@staticmethod
def replacement(a, b, c):
return torch.ops.aten.add.Tensor(a, torch.ops.aten.mm(b, c))


# %var: [#users=2] = call_function[target=torch.ops.aten.var.correction]
# (args = (%convolution_4, [0, 2, 3]), kwargs = {correction: 0, keepdim: True})
@register_aten_patterns
class ReplacePatternVar(BackendPatternBase):
@staticmethod
def pattern(a, b):
return torch.ops.aten.var.correction(a, b, correction=0, keepdim=True)

@staticmethod
def replacement(inputs, dims):
keepdim = True
correction = 0
denom = 64
denom = denom - correction
mean1 = torch.ops.aten.mean.dim(inputs, dims, keepdim)
diffs = torch.ops.aten.square.default(
torch.ops.aten.sub.Tensor(inputs, mean1))
sum_results = torch.ops.aten.sum.dim_IntList(diffs, dims, keepdim)
x_var = torch.ops.aten.div.Tensor(sum_results, denom)
return x_var


@register_aten_patterns
class ReplacePatternT(BackendPatternBase):
@staticmethod
def pattern(a):
return torch.ops.aten.t.default(a)

@staticmethod
def replacement(inputs):
return torch.ops.aten.transpose(inputs, 0, 1)


@register_aten_patterns
class ReplacePatternRsub(BackendPatternBase):
@staticmethod
def pattern(a, b):
return torch.ops.aten.rsub.Scalar(a, b)

@staticmethod
def replacement(a, b):
return torch.ops.aten.sub.Scalar(b, a)


@register_aten_patterns
class ReplacePatternSiLU(BackendPatternBase):
# silu(x) = x / (1+exp(-x)) = x*sigmoid(x)
@staticmethod
def pattern(a):
return torch.ops.aten.silu.default(a)

@staticmethod
def replacement(a):
return torch.ops.aten.mul.default(a, torch.ops.aten.sigmoid.default(a))


if is_torch_210:
Dot = torch.fx.wrap(tops_op.Dot.get_singleton())
DotGeneral = torch.fx.wrap(tops_op.DotGeneral.get_singleton())
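
The deleted Addmm conversion and the aten-level rewrite patterns (addmm, var, t, rsub, silu) all encoded simple algebraic identities; for example, the addmm and silu rewrites relied on the equivalences checked below (illustrative shapes, not from the PR):

import torch
a, b, c = torch.randn(4), torch.randn(2, 3), torch.randn(3, 4)
print(torch.allclose(torch.addmm(a, b, c), a + b @ c))                     # True
x = torch.randn(8)
print(torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x)))   # True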
9 changes: 2 additions & 7 deletions dicp/dicp/vendor/TopsGraph/opset_transform.py
@@ -6,7 +6,7 @@

if is_torch_210:
from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
from dicp.vendor.TopsGraph.conversion import tops_patterns, aten_patterns_cls_list, tops_patterns_cls_list
from dicp.vendor.TopsGraph.conversion import tops_patterns, tops_patterns_cls_list


class HandleInplaceCopyPass():
@@ -20,7 +20,7 @@ def transform(self, gm: torch.fx.GraphModule):
inplace_dict = {}
for node in reversed(nodes):
if node.op not in ["placeholder", "output"] and not isinstance(node.target, str):
if node.target.name() == "Copy_":
if hasattr(node.target, "name") and node.target.name() == "Copy_":
if node.args[0].op == "placeholder" and node.args[0].name not in inplace_dict.values():
inplace_outputs.append(node.args[1])
inplace_dict[node.args[1].name] = node.args[0].name
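
The added hasattr check guards against fx node targets that are plain Python callables rather than OpOverload objects and therefore have no name() method. A small illustration (operator.getitem is just one plausible such target):

import operator
print(hasattr(operator.getitem, "name"))   # False -- the old code would raise AttributeError here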
@@ -35,11 +35,6 @@ def topsgraph_opset_transform(
gm: torch.fx.GraphModule,
):

# 1aten to Naten
if is_torch_210:
gm = BackendPatternMatcherTransformer(
tops_patterns, aten_patterns_cls_list).transform(gm)

# 1aten to Ntops
gm = AtenToTopsTransformer(gm).transform()

9 changes: 1 addition & 8 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
@@ -799,9 +799,8 @@ def __call__(self, x, idx):


class MakeTuple(Operator):
def __init__(self, a, b):
def __init__(self, *args):
super().__init__("MakeTuple")
self.torch_op = torch.empty_like

def __call__(self, *args):
return (arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args)
@@ -816,9 +815,3 @@ def __call__(self, operand, indices, offset_dims, collapsed_slice_dims,
start_index_map, index_vector_dim, slice_size, out_shape):
with operand.meta['val'].fake_mode:
return aten.empty(out_shape, device=operand.meta["val"].device)


# TODO check if we need this wrap
@torch.fx.wrap
def ret_tuples(a, b) -> Tuple[torch.Tensor, torch.Tensor]:
return a, b
2 changes: 2 additions & 0 deletions dicp/scripts/ci/tops/ci_tops_test_env.sh
@@ -1,7 +1,9 @@
#!/usr/bin/env bash

LLAMA_MODEL_DIR=$1
STABLE_DIFFUSION_MODEL_DIR=$2

export DIPU_MOCK_CUDA=True
export LLAMA_MODEL_DIR=$1
export LLAMA_FINETUNE_DIR=$2
export STABLE_DIFFUSION_MODEL_DIR=$3
Binary file not shown.
1 change: 1 addition & 0 deletions dicp/test/model/stable_diffusion/topsgraph_output.txt

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions dicp/test/model/test_stable_diffusion.py
@@ -0,0 +1,41 @@
import pytest
import torch
import os
import torch._dynamo as dynamo
from ..common import utils
import torch_dipu
from diffusers import StableDiffusionPipeline
dynamo.config.cache_size_limit = 128
utils.update_dynamo_config(False)
device = utils.get_device()
torch_dipu.dipu.set_device(device)
models_dir = os.environ.get("STABLE_DIFFUSION_MODEL_DIR")
assert models_dir is not None


class TestStableDiffusion():
@pytest.mark.parametrize("model_path", [f"{models_dir}/stable-diffusion-2"])
@pytest.mark.parametrize("num_inference_steps", [50])
def test_inference(
self,
model_path: str,
backend: str,
dynamic: bool,
num_inference_steps: int
):
prompt = "A photo of an astronaut riding a horse on mars."
utils.update_dynamo_config(dynamic=dynamic)
torch_dipu.dipu.set_device(device)

dicp_pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)
dicp_pipe.text_encoder = torch.compile(dicp_pipe.text_encoder, backend=backend)
dicp_pipe.unet = torch.compile(dicp_pipe.unet, backend=backend)
dicp_image = dicp_pipe(prompt, num_inference_steps=num_inference_steps).images[0]
if backend == "ascendgraph":
standard_output = torch.load("stable_diffusion/ascendgraph_output.pt")
elif backend == "topsgraph":
standard_output = torch.load("stable_diffusion/topsgraph_output.pt")
else:
raise ValueError("backend should be in (ascendgraph, topsgraph)")
dicp_output = torch.tensor(list(dicp_image.getdata()))
assert torch.allclose(dicp_output, standard_output, equal_nan=True)
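
For context, a plausible way the reference tensor loaded above could be (re)generated, assuming the eager, uncompiled pipeline is taken as the golden baseline (this script is not part of the PR; the path and steps are assumptions):

import os
import torch
from diffusers import StableDiffusionPipeline

model_path = os.path.join(os.environ["STABLE_DIFFUSION_MODEL_DIR"], "stable-diffusion-2")
prompt = "A photo of an astronaut riding a horse on mars."
ref_pipe = StableDiffusionPipeline.from_pretrained(model_path).to("cuda")   # eager baseline, no torch.compile
ref_image = ref_pipe(prompt, num_inference_steps=50).images[0]
torch.save(torch.tensor(list(ref_image.getdata())), "stable_diffusion/topsgraph_output.pt")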