Skip to content

Commit

Permalink
fix some minor issue
Browse files Browse the repository at this point in the history
  • Loading branch information
lmxyy committed Mar 5, 2024
1 parent bcd91f2 commit 042ae44
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 7 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ from distrifuser.utils import DistriConfig

distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
pipeline = DistriSDXLPipeline.from_pretrained(
distri_config=distri_config, pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"
distri_config=distri_config,
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
variant="fp16",
use_safetensors=True,
)
pipeline.prepare()

pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
image = pipeline(
Expand Down
2 changes: 1 addition & 1 deletion distrifuser/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.0beta0"
__version__ = "0.0.0beta1"
2 changes: 2 additions & 0 deletions distrifuser/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def __init__(self, pipeline: StableDiffusionXLPipeline, module_config: DistriCon

self.static_inputs = None

self.prepare()

@staticmethod
def from_pretrained(distri_config: DistriConfig, **kwargs):
device = distri_config.device
Expand Down
1 change: 0 additions & 1 deletion scripts/generate_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def main():
use_safetensors=True,
scheduler=scheduler,
)
pipeline.prepare()
pipeline.set_progress_bar_config(disable=distri_config.rank != 0, position=1, leave=False)

if args.output_root is None:
Expand Down
1 change: 0 additions & 1 deletion scripts/run_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def main():
use_safetensors=True,
scheduler=scheduler,
)
pipeline.prepare()

if args.mode == "generation":
assert args.output_path is not None
Expand Down
6 changes: 4 additions & 2 deletions scripts/sdxl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
pipeline = DistriSDXLPipeline.from_pretrained(
distri_config=distri_config, pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"
distri_config=distri_config,
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
variant="fp16",
use_safetensors=True,
)
pipeline.prepare()

pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
image = pipeline(
Expand Down

0 comments on commit 042ae44

Please sign in to comment.