Possibility to override dtype and default num_inference_steps through env var (#443)

* Possibility to override dtype through env var
* Specify default num inference steps via env var as well
oOraph authored Aug 8, 2024
1 parent 40a21b7 commit db248d4
Showing 2 changed files with 30 additions and 9 deletions.
25 changes: 19 additions & 6 deletions docker_images/diffusers/app/pipelines/image_to_image.py
@@ -45,7 +45,10 @@ def __init__(self, model_id: str):
             if model_id.startswith("hf-internal-testing/")
             else {}
         )
-        if torch.cuda.is_available():
+        env_dtype = os.getenv("TORCH_DTYPE")
+        if env_dtype:
+            kwargs["torch_dtype"] = getattr(torch, env_dtype)
+        elif torch.cuda.is_available():
             kwargs["torch_dtype"] = torch.float16
             if model_id == "stabilityai/stable-diffusion-xl-refiner-1.0":
                 kwargs["variant"] = "fp16"
@@ -189,30 +192,40 @@ def _process_req(self, image, prompt, **kwargs):
             ),
         ):
             if "num_inference_steps" not in kwargs:
-                kwargs["num_inference_steps"] = 25
+                kwargs["num_inference_steps"] = int(
+                    os.getenv("DEFAULT_NUM_INFERENCE_STEPS", "25")
+                )
             images = self.ldm(prompt, image, **kwargs)["images"]
             return images[0]
         elif isinstance(self.ldm, StableDiffusionXLImg2ImgPipeline):
             if "num_inference_steps" not in kwargs:
-                kwargs["num_inference_steps"] = 25
+                kwargs["num_inference_steps"] = int(
+                    os.getenv("DEFAULT_NUM_INFERENCE_STEPS", "25")
+                )
             image = image.convert("RGB")
             images = self.ldm(prompt, image=image, **kwargs)["images"]
             return images[0]
         elif isinstance(self.ldm, (StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline)):
             if "num_inference_steps" not in kwargs:
-                kwargs["num_inference_steps"] = 25
+                kwargs["num_inference_steps"] = int(
+                    os.getenv("DEFAULT_NUM_INFERENCE_STEPS", "25")
+                )
             # image comes first
             images = self.ldm(image, prompt, **kwargs)["images"]
             return images[0]
         elif isinstance(self.ldm, StableDiffusionImageVariationPipeline):
             if "num_inference_steps" not in kwargs:
-                kwargs["num_inference_steps"] = 25
+                kwargs["num_inference_steps"] = int(
+                    os.getenv("DEFAULT_NUM_INFERENCE_STEPS", "25")
+                )
             # only image is needed
             images = self.ldm(image, **kwargs)["images"]
             return images[0]
         elif isinstance(self.ldm, (KandinskyImg2ImgPipeline)):
             if "num_inference_steps" not in kwargs:
-                kwargs["num_inference_steps"] = 100
+                kwargs["num_inference_steps"] = int(
+                    os.getenv("DEFAULT_NUM_INFERENCE_STEPS", "100")
+                )
             # not all args are supported by the prior
             prior_args = {
                 "num_inference_steps": kwargs["num_inference_steps"],
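
Every branch above repeats the same three-line pattern, differing only in the fallback (25 for most pipelines, 100 for Kandinsky img2img). A sketch of that pattern as a standalone function; the helper name default_num_inference_steps is hypothetical and does not exist in the commit.

```python
import os


def default_num_inference_steps(fallback: int) -> int:
    """Hypothetical helper: DEFAULT_NUM_INFERENCE_STEPS, if set, overrides the
    per-pipeline fallback; a non-numeric value raises ValueError, matching the
    inlined int(os.getenv(...)) calls above."""
    return int(os.getenv("DEFAULT_NUM_INFERENCE_STEPS", str(fallback)))


# Equivalent to each branch above, e.g. for the Kandinsky img2img case:
# if "num_inference_steps" not in kwargs:
#     kwargs["num_inference_steps"] = default_num_inference_steps(100)
```
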
14 changes: 11 additions & 3 deletions docker_images/diffusers/app/pipelines/text_to_image.py
@@ -40,7 +40,10 @@ def __init__(self, model_id: str):
             if model_id.startswith("hf-internal-testing/")
             else {}
         )
-        if torch.cuda.is_available():
+        env_dtype = os.getenv("TORCH_DTYPE")
+        if env_dtype:
+            kwargs["torch_dtype"] = getattr(torch, env_dtype)
+        elif torch.cuda.is_available():
             kwargs["torch_dtype"] = torch.float16
 
         has_model_index = any(
@@ -158,8 +161,13 @@ def _process_req(self, inputs, **kwargs):
         # only one image per prompt is supported
         kwargs["num_images_per_prompt"] = 1
 
-        if self.is_karras_compatible and "num_inference_steps" not in kwargs:
-            kwargs["num_inference_steps"] = 20
+        if "num_inference_steps" not in kwargs:
+            default_num_steps = os.getenv("DEFAULT_NUM_INFERENCE_STEPS")
+            if default_num_steps:
+                kwargs["num_inference_steps"] = int(default_num_steps)
+            elif self.is_karras_compatible:
+                kwargs["num_inference_steps"] = 20
+            # Else, don't specify anything, leave the default behaviour
 
         images = self.ldm(inputs, **kwargs)["images"]
         return images[0]
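
Taken together, the resolution order for text_to_image after this change is: an explicit num_inference_steps in the request, then DEFAULT_NUM_INFERENCE_STEPS, then 20 for Karras-compatible schedulers, and otherwise the pipeline's own default. A hedged usage sketch with illustrative values, assuming the variables are set in the process environment before the app starts:

```python
import os

# Illustrative values only; in practice these would be set on the container
# or deployment environment rather than from Python code.
os.environ["TORCH_DTYPE"] = "bfloat16"            # any torch dtype attribute name
os.environ["DEFAULT_NUM_INFERENCE_STEPS"] = "30"  # parsed with int() at request time
```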
