[dicp][tops] Support stable-diffusion in tops (#568)
yao-fengchen authored Jan 14, 2024
1 parent 6b76f2e commit 934f32e
Showing 7 changed files with 58 additions and 93 deletions.
89 changes: 11 additions & 78 deletions dicp/dicp/vendor/TopsGraph/conversion.py
@@ -239,8 +239,8 @@ def Index(self, *args, **kwargs):
start_indices = []
for index in new_indices:
in_shape = index.node.meta["val"].shape
offset = len(broadcast_shape) - len(in_shape)
broadcast_dims = [i + offset for i in range(len(in_shape))]
rank_diff = len(broadcast_shape) - len(in_shape)
broadcast_dims = [i + rank_diff for i in range(len(in_shape))]
start_indices.append(self.get_proxy(tops_op.Expand, (index, tuple(broadcast_shape), broadcast_dims)))
start_indices = self.get_proxy(tops_op.Stack, (start_indices, -1))
return self.get_proxy(tops_op.XlaGather, (operand, start_indices, offset_dims, collapsed_slice_dims,
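
The rename makes the intent clearer: rank_diff is how many leading dimensions the broadcast target has beyond the index, so index dimension i maps to target dimension i + rank_diff. A minimal plain-PyTorch sketch of that mapping with stand-in shapes (not the tops_op.Expand call itself):

import torch

# Stand-in shapes, only to illustrate the dimension mapping used above.
broadcast_shape = (2, 3, 4)
index = torch.arange(4)                                          # in_shape == (4,)
in_shape = index.shape

rank_diff = len(broadcast_shape) - len(in_shape)                 # 3 - 1 = 2
broadcast_dims = [i + rank_diff for i in range(len(in_shape))]   # [2]

# Plain-PyTorch equivalent: the index's single dimension lands on
# dimension 2 of the broadcast target.
expanded = index.reshape((1,) * rank_diff + tuple(in_shape)).expand(broadcast_shape)
assert expanded.shape == torch.Size(broadcast_shape)
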
@@ -308,7 +308,7 @@ def ReduceMean(self, a, dim=None, keepdim=False, **kwargs):
in_shape = a.node.meta["val"].shape
if dim is None:
dim = list(range(len(in_shape)))
return self.get_proxy(tops_op.ReduceMean, (a, dim))
return self.get_proxy(tops_op.ReduceMean, (a, dim, keepdim))
dim = [(item + len(in_shape)) if item < 0 else item for item in dim]
return self.get_proxy(tops_op.ReduceMean, (a, dim, keepdim))
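
The fix forwards keepdim on the dim is None path as well, where the reduction runs over every axis. The aten-level behaviour the lowering has to preserve, in plain PyTorch:

import torch

x = torch.randn(2, 3, 4)
dims = list(range(x.dim()))   # dim=None means "reduce over every axis"

# keepdim=True must keep rank 3 with singleton axes rather than a 0-d tensor.
assert torch.mean(x, dim=dims, keepdim=True).shape == (1, 1, 1)
assert torch.mean(x, dim=dims, keepdim=False).shape == ()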

@@ -371,8 +371,8 @@ def Adaptive_avg_pool2d(self, *args, **kwargs):
def Adaptive_avg_pool2d_backward(self, grad_output, inputs):
out_shape = fx_traceback.get_current_meta()["val"].shape
grad_output_shape = grad_output.node.meta["val"].shape
offset = len(out_shape) - len(grad_output_shape)
broadcast_dims = [i + offset for i in range(len(grad_output_shape))]
rank_diff = len(out_shape) - len(grad_output_shape)
broadcast_dims = [i + rank_diff for i in range(len(grad_output_shape))]
expand = self.get_proxy(tops_op.Expand, (grad_output, out_shape, broadcast_dims))
value = out_shape[2] * out_shape[3]
scalar = self.get_proxy(tops_op.Scalar, (value, ))
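
Here value is the input-plane size H * W, and the gradient is presumably expanded and then divided by it. A plain-PyTorch sanity check of that expand-and-divide rule, assuming the 1x1-output case this decomposition implies:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5, 5, requires_grad=True)
F.adaptive_avg_pool2d(x, (1, 1)).sum().backward()

# Each input position receives grad / (H * W) when the output is 1x1.
h, w = x.shape[2], x.shape[3]
manual = torch.ones(2, 3, 1, 1).expand_as(x) / (h * w)
assert torch.allclose(x.grad, manual)
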
@@ -446,8 +446,8 @@ def NewEmptyStrided(self, *args, **kwargs):
def Expand(self, *args, **kwargs):
in_shape = args[0].node.meta["val"].shape
out_shape = fx_traceback.get_current_meta()["val"].shape
offset = len(out_shape) - len(in_shape)
broadcast_dims = [i + offset for i in range(len(in_shape))]
rank_diff = len(out_shape) - len(in_shape)
broadcast_dims = [i + rank_diff for i in range(len(in_shape))]
return self.get_proxy(tops_op.Expand, (*args, broadcast_dims), kwargs)

@register_conversion(aten.stack)
@@ -602,84 +602,17 @@ def VarMean(self, x, dims, *args, correction=1, keepdim=False):
div1 = self.get_proxy(tops_op.Div, (sum_dim, samples))
return self.get_proxy(tops_op.MakeTuple, (div1, mean1))
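
The conversion ends by packing the pair through tops_op.MakeTuple, mirroring aten.var_mean, which returns the variance and the mean. A quick eager-mode check of that contract (plain PyTorch, not the tops lowering):

import torch

x = torch.randn(3, 5)
var, mean = torch.var_mean(x, [1], correction=1, keepdim=False)
assert torch.allclose(mean, x.mean(dim=1))
assert torch.allclose(var, x.var(dim=1, correction=1))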

@register_conversion(aten.addmm)
def Addmm(self, x, mat1, mat2):
dot = self.get_proxy(tops_op.Dot, (mat1, mat2))
return self.get_proxy(tops_op.Add, (x, dot))

# Patterns
tops_patterns = PatternMatcherPass()
aten_patterns_cls_list = []
register_aten_patterns = functools.partial(
register_backend_patterns, aten_patterns_cls_list)
tops_patterns_cls_list = []
register_tops_patterns = functools.partial(
register_backend_patterns, tops_patterns_cls_list)


@register_aten_patterns
class ReplacePatternAddmm(BackendPatternBase):
@staticmethod
def pattern(a, b, c):
return torch.ops.aten.addmm.default(a, b, c)

@staticmethod
def replacement(a, b, c):
return torch.ops.aten.add.Tensor(a, torch.ops.aten.mm(b, c))
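
Both the removed Addmm conversion above and this pattern rely on the same decomposition: addmm(a, b, c) is the bias plus the matrix product. A one-line eager check:

import torch

a, b, c = torch.randn(4, 5), torch.randn(4, 3), torch.randn(3, 5)
assert torch.allclose(torch.addmm(a, b, c), a + b @ c)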


# %var: [#users=2] = call_function[target=torch.ops.aten.var.correction]
# (args = (%convolution_4, [0, 2, 3]), kwargs = {correction: 0, keepdim: True})
@register_aten_patterns
class ReplacePatternVar(BackendPatternBase):
@staticmethod
def pattern(a, b):
return torch.ops.aten.var.correction(a, b, correction=0, keepdim=True)

@staticmethod
def replacement(inputs, dims):
keepdim = True
correction = 0
denom = 64
denom = denom - correction
mean1 = torch.ops.aten.mean.dim(inputs, dims, keepdim)
diffs = torch.ops.aten.square.default(
torch.ops.aten.sub.Tensor(inputs, mean1))
sum_results = torch.ops.aten.sum.dim_IntList(diffs, dims, keepdim)
x_var = torch.ops.aten.div.Tensor(sum_results, denom)
return x_var
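
The replacement expands var(correction=0, keepdim=True) into mean, squared deviation, sum, and divide. Note that denom is hard-coded to 64, so the rewrite only matches graphs where 64 elements are reduced; the sketch below checks the same decomposition with the denominator derived from the reduced dimensions instead (illustrative only, not the registered pattern):

import torch

x = torch.randn(4, 8, 2)
dims = [0, 2]

mean1 = torch.mean(x, dims, True)
diffs = torch.square(x - mean1)
denom = x.shape[0] * x.shape[2]          # number of reduced elements (correction=0)
x_var = torch.sum(diffs, dims, True) / denom

assert torch.allclose(x_var, torch.var(x, dims, correction=0, keepdim=True))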


@register_aten_patterns
class ReplacePatternT(BackendPatternBase):
@staticmethod
def pattern(a):
return torch.ops.aten.t.default(a)

@staticmethod
def replacement(inputs):
return torch.ops.aten.transpose(inputs, 0, 1)
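
The rewrite relies on aten.t being equivalent to a (0, 1) transpose for 2-D inputs; a one-line eager check:

import torch

m = torch.randn(2, 3)
assert torch.equal(torch.t(m), m.transpose(0, 1))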


@register_aten_patterns
class ReplacePatternRsub(BackendPatternBase):
@staticmethod
def pattern(a, b):
return torch.ops.aten.rsub.Scalar(a, b)

@staticmethod
def replacement(a, b):
return torch.ops.aten.sub.Scalar(b, a)
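
For context, aten.rsub.Scalar(a, b) computes b - a (scalar minus tensor), which is what the swapped-argument subtraction above expresses. A small eager check of the rsub semantics:

import torch

a = torch.randn(3)
assert torch.allclose(torch.ops.aten.rsub.Scalar(a, 2.0), 2.0 - a)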


@register_aten_patterns
class ReplacePatternSiLU(BackendPatternBase):
# silu(x) = x / (1+exp(-x)) = x*sigmoid(x)
@staticmethod
def pattern(a):
return torch.ops.aten.silu.default(a)

@staticmethod
def replacement(a):
return torch.ops.aten.mul.default(a, torch.ops.aten.sigmoid.default(a))
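
The identity in the comment can be sanity-checked directly in eager mode:

import torch

x = torch.randn(16)
# silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), matching the comment above.
assert torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))
assert torch.allclose(torch.sigmoid(x), 1 / (1 + torch.exp(-x)))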


if is_torch_210:
Dot = torch.fx.wrap(tops_op.Dot.get_singleton())
DotGeneral = torch.fx.wrap(tops_op.DotGeneral.get_singleton())
9 changes: 2 additions & 7 deletions dicp/dicp/vendor/TopsGraph/opset_transform.py
@@ -6,7 +6,7 @@

if is_torch_210:
from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
from dicp.vendor.TopsGraph.conversion import tops_patterns, aten_patterns_cls_list, tops_patterns_cls_list
from dicp.vendor.TopsGraph.conversion import tops_patterns, tops_patterns_cls_list


class HandleInplaceCopyPass():
@@ -20,7 +20,7 @@ def transform(self, gm: torch.fx.GraphModule):
inplace_dict = {}
for node in reversed(nodes):
if node.op not in ["placeholder", "output"] and not isinstance(node.target, str):
if node.target.name() == "Copy_":
if hasattr(node.target, "name") and node.target.name() == "Copy_":
if node.args[0].op == "placeholder" and node.args[0].name not in inplace_dict.values():
inplace_outputs.append(node.args[1])
inplace_dict[node.args[1].name] = node.args[0].name
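
The added hasattr guard matters because FX node targets are not always operator objects that expose a name() method; plain callables (besides the string targets already filtered out above) would otherwise raise AttributeError. A stand-in sketch of the guard, not the dicp operator classes:

class Copy_:
    # Stand-in for a backend operator object that exposes name().
    def name(self):
        return "Copy_"


def getitem(obj, idx):
    # A plain callable target with no name() method.
    return obj[idx]


for target in (Copy_(), getitem, "output"):
    # Without hasattr(target, "name"), the plain function would raise
    # AttributeError before the comparison ever ran.
    if hasattr(target, "name") and target.name() == "Copy_":
        print("in-place copy target found")
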
@@ -35,11 +35,6 @@ def topsgraph_opset_transform(
gm: torch.fx.GraphModule,
):

# 1aten to Naten
if is_torch_210:
gm = BackendPatternMatcherTransformer(
tops_patterns, aten_patterns_cls_list).transform(gm)

# 1aten to Ntops
gm = AtenToTopsTransformer(gm).transform()

9 changes: 1 addition & 8 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
@@ -799,9 +799,8 @@ def __call__(self, x, idx):


class MakeTuple(Operator):
def __init__(self, a, b):
def __init__(self, *args):
super().__init__("MakeTuple")
self.torch_op = torch.empty_like

def __call__(self, *args):
return (arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args)
@@ -816,9 +815,3 @@ def __call__(self, operand, indices, offset_dims, collapsed_slice_dims,
start_index_map, index_vector_dim, slice_size, out_shape):
with operand.meta['val'].fake_mode:
return aten.empty(out_shape, device=operand.meta["val"].device)


# TODO check if we need this wrap
@torch.fx.wrap
def ret_tuples(a, b) -> Tuple[torch.Tensor, torch.Tensor]:
return a, b
2 changes: 2 additions & 0 deletions dicp/scripts/ci/tops/ci_tops_test_env.sh
@@ -1,7 +1,9 @@
#!/usr/bin/env bash

LLAMA_MODEL_DIR=$1
STABLE_DIFFUSION_MODEL_DIR=$2

export DIPU_MOCK_CUDA=True
export LLAMA_MODEL_DIR=$1
export LLAMA_FINETUNE_DIR=$2
export STABLE_DIFFUSION_MODEL_DIR=$3
Binary file not shown.
1 change: 1 addition & 0 deletions dicp/test/model/stable_diffusion/topsgraph_output.txt

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions dicp/test/model/test_stable_diffusion.py
@@ -0,0 +1,41 @@
import pytest
import torch
import os
import torch._dynamo as dynamo
from ..common import utils
import torch_dipu
from diffusers import StableDiffusionPipeline
dynamo.config.cache_size_limit = 128
utils.update_dynamo_config(False)
device = utils.get_device()
torch_dipu.dipu.set_device(device)
models_dir = os.environ.get("STABLE_DIFFUSION_MODEL_DIR")
assert models_dir is not None


class TestStableDiffusion():
    @pytest.mark.parametrize("model_path", [f"{models_dir}/stable-diffusion-2"])
    @pytest.mark.parametrize("num_inference_steps", [50])
    def test_inference(
        self,
        model_path: str,
        backend: str,
        dynamic: bool,
        num_inference_steps: int
    ):
        prompt = "A photo of an astronaut riding a horse on mars."
        utils.update_dynamo_config(dynamic=dynamic)
        torch_dipu.dipu.set_device(device)

        dicp_pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)
        dicp_pipe.text_encoder = torch.compile(dicp_pipe.text_encoder, backend=backend)
        dicp_pipe.unet = torch.compile(dicp_pipe.unet, backend=backend)
        dicp_image = dicp_pipe(prompt, num_inference_steps=num_inference_steps).images[0]
        if backend == "ascendgraph":
            standard_output = torch.load("stable_diffusion/ascendgraph_output.pt")
        elif backend == "topsgraph":
            standard_output = torch.load("stable_diffusion/topsgraph_output.pt")
        else:
            raise ValueError("backend should be in (ascendgraph, topsgraph)")
        dicp_output = torch.tensor(list(dicp_image.getdata()))
        assert torch.allclose(dicp_output, standard_output, equal_nan=True)
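
dicp_image above is a PIL image, so getdata() flattens it into per-pixel RGB tuples and the comparison tensor has shape (H*W, 3). A standalone sketch of that conversion with a tiny stand-in image (plain PIL/torch, independent of the pipeline):

import torch
from PIL import Image

img = Image.new("RGB", (4, 2), color=(255, 0, 0))    # 4x2 stand-in image
pixels = torch.tensor(list(img.getdata()))            # shape (H*W, 3) == (8, 3)
assert pixels.shape == (8, 3)
assert torch.equal(pixels[0], torch.tensor([255, 0, 0]))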
