upgrade diffusers version to 0.31.0 #25

Open · wants to merge 1 commit into base: main
3 changes: 2 additions & 1 deletion .gitignore
@@ -2,4 +2,5 @@
 .DS_Store
 build
 dist
-*.egg-info
+*.egg-info
+*__pycache__*
2 changes: 1 addition & 1 deletion distrifuser/models/distri_sdxl_unet_pp.py
@@ -1,7 +1,7 @@
 import torch
 from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import Attention
-from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
 from torch import distributed as dist, nn
 
 from .base_model import BaseModel
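
Note on the import change above (it recurs in two more files below): newer diffusers releases moved UNet2DConditionOutput into the diffusers.models.unets subpackage, and the old module path eventually stopped resolving. A hedged compatibility shim, in case the code ever needs to run against both layouts:

    # hedged shim: prefer the new diffusers.models.unets path, falling back
    # to the legacy location used by older diffusers releases
    try:
        from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
    except ImportError:
        from diffusers.models.unet_2d_condition import UNet2DConditionOutput
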
8 changes: 6 additions & 2 deletions distrifuser/models/distri_sdxl_unet_tp.py
@@ -2,7 +2,7 @@
 from diffusers import UNet2DConditionModel
 from diffusers.models.attention import Attention, FeedForward
 from diffusers.models.resnet import ResnetBlock2D
-from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
 from torch import distributed as dist, nn
 
 from distrifuser.modules.base_module import BaseModule
@@ -157,7 +157,11 @@ def forward(
         if self.buffer_list is None:
             self.buffer_list = [torch.empty_like(output) for _ in range(2)]
         dist.all_gather(
-            self.buffer_list, output.contiguous(), group=distri_config.split_group(), async_op=False
+            self.buffer_list, output.contiguous(),
+            group=distri_config.split_group,
+            # original code
+            # group=distri_config.split_group(),
+            async_op=False
         )
         torch.cat(self.buffer_list, dim=0, out=self.output_buffer)
         output = self.output_buffer
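
Besides the import move, this hunk switches distri_config.split_group from a method call to an attribute access, keeping the original call as a comment. If the surrounding code ever has to tolerate both shapes of that config API, a small helper can normalize them (resolve_group is a name invented here, not part of distrifuser):

    def resolve_group(split_group):
        # hypothetical helper: accept either a process-group attribute or a
        # zero-argument accessor, depending on the DistriConfig version in use
        return split_group() if callable(split_group) else split_group

It would be called as group=resolve_group(distri_config.split_group) in the dist.all_gather above.
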
2 changes: 1 addition & 1 deletion distrifuser/models/naive_patch_sdxl.py
@@ -1,6 +1,6 @@
 import torch
 from diffusers import UNet2DConditionModel
-from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
 from torch import distributed as dist
 
 from .base_model import BaseModel
28 changes: 22 additions & 6 deletions distrifuser/modules/pp/attn.py
@@ -64,9 +64,13 @@ def forward(
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        # Handle scale parameter based on PEFT backend
+        if USE_PEFT_BACKEND:
+            query = attn.to_q(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+            query = query * scale
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
@@ -117,8 +121,14 @@ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
 
         batch_size, sequence_length, _ = hidden_states.shape
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        # args = () if USE_PEFT_BACKEND else (scale,)
+        # query = attn.to_q(hidden_states, *args)
+
+        if USE_PEFT_BACKEND:
+            query = attn.to_q(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+            query = query * scale
 
         encoder_hidden_states = hidden_states
 
@@ -156,7 +166,13 @@ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        # hidden_states = attn.to_out[0](hidden_states, *args)
+
+        if USE_PEFT_BACKEND:
+            hidden_states = attn.to_out[0](hidden_states)
+        else:
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = hidden_states * scale
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
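
One caveat on the three rewritten hunks in this file: the deleted *args pattern forwarded scale into diffusers' pre-PEFT LoRA-compatible layers, which, as far as I recall, applied scale only to the low-rank delta rather than to the whole projection. Multiplying the full output by scale is therefore only equivalent when scale is 1.0 (the default) or no LoRA adapter is loaded. A simplified sketch of the old semantics, reconstructed from memory for comparison:

    from torch import nn

    class LoRACompatibleLinearSketch(nn.Linear):
        # hedged reconstruction of pre-PEFT diffusers behavior: `scale`
        # weights only the LoRA delta, never the base projection
        def __init__(self, in_features, out_features, lora_layer=None, **kwargs):
            super().__init__(in_features, out_features, **kwargs)
            self.lora_layer = lora_layer  # optional low-rank adapter module

        def forward(self, x, scale=1.0):
            out = super().forward(x)  # base projection, unscaled
            if self.lora_layer is not None:
                out = out + scale * self.lora_layer(x)  # scale applies here only
            return out
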
6 changes: 4 additions & 2 deletions distrifuser/modules/tp/resnet.py
@@ -1,5 +1,6 @@
 import torch.cuda
-from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D, USE_PEFT_BACKEND
+from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from diffusers.utils import USE_PEFT_BACKEND
 from torch import distributed as dist
 from torch import nn
 from torch.nn import functional as F
@@ -192,7 +193,8 @@ def forward(
 
         if module.conv_shortcut is not None:
             input_tensor = (
-                module.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
+                # module.conv_shortcut(input_tensor, scale)
+                module.conv_shortcut(input_tensor) * scale if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
             )
 
         output_tensor = (input_tensor + hidden_states) / module.output_scale_factor
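
The import fix above reflects that USE_PEFT_BACKEND is no longer re-exported from diffusers.models.resnet; diffusers.utils is its stable home. The same scale caveat as in attn.py applies to the conv_shortcut hunk: the old two-argument call scaled a LoRA delta, not the whole convolution output. A hedged import fallback for code that must span both layouts:

    # hedged shim: diffusers.utils is the canonical source of USE_PEFT_BACKEND;
    # some older releases also re-exported it from diffusers.models.resnet
    try:
        from diffusers.utils import USE_PEFT_BACKEND
    except ImportError:
        from diffusers.models.resnet import USE_PEFT_BACKEND
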
5 changes: 4 additions & 1 deletion distrifuser/pipelines.py
@@ -37,7 +37,10 @@ def from_pretrained(distri_config: DistriConfig, **kwargs):
             raise ValueError(f"Unknown parallelism: {distri_config.parallelism}")
 
         pipeline = StableDiffusionXLPipeline.from_pretrained(
-            pretrained_model_name_or_path, torch_dtype=torch_dtype, unet=unet, **kwargs
+            pretrained_model_name_or_path,
+            torch_dtype=torch_dtype,
+            unet=unet,
+            **kwargs
         ).to(device)
         return DistriSDXLPipeline(pipeline, distri_config)
 
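
The from_pretrained call above is only re-wrapped onto one keyword per line; behavior is unchanged. For context, a hedged usage sketch of this entry point; the model id, sizes, prompt, and seed are illustrative, and it assumes a torchrun launch so that torch.distributed is initialized:

    import torch

    from distrifuser.pipelines import DistriSDXLPipeline
    from distrifuser.utils import DistriConfig

    # illustrative configuration; DistriConfig takes further options not shown
    distri_config = DistriConfig(height=1024, width=1024)
    pipeline = DistriSDXLPipeline.from_pretrained(
        distri_config=distri_config,
        pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
        variant="fp16",
        use_safetensors=True,
    )
    image = pipeline(
        prompt="an astronaut riding a horse on the moon",
        generator=torch.Generator(device="cuda").manual_seed(0),
    ).images[0]
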
2 changes: 1 addition & 1 deletion setup.py
@@ -12,7 +12,7 @@
     author="Muyang Li, Tianle Cai, Jiaxin Cao, Qinsheng Zhang, Han Cai, Junjie Bai, Yangqing Jia, Ming-Yu Liu, Kai Li and Song Han",
     author_email="muyangli@mit.edu",
     packages=find_packages(),
-    install_requires=["torch>=2.2", "diffusers==0.24.0", "transformers", "tqdm"],
+    install_requires=["torch>=2.2", "diffusers>=0.31.0", "transformers", "tqdm"],
     url="https://github.com/mit-han-lab/distrifuser",
     description="DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models",
     long_description=long_description,
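
Relaxing the strict diffusers==0.24.0 pin to diffusers>=0.31.0 lets installed environments drift ahead to releases with further API breaks. A hedged runtime guard that makes the new floor explicit at import time, assuming the packaging library is available (it usually arrives transitively with the Hugging Face stack):

    # hedged sketch: fail fast when the installed diffusers predates the
    # import layout and APIs this commit targets
    import diffusers
    from packaging.version import Version

    if Version(diffusers.__version__) < Version("0.31.0"):
        raise RuntimeError(
            f"distrifuser requires diffusers>=0.31.0, found {diffusers.__version__}"
        )
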