From 040580150081761d8e8c312cdb0a5d346397b746 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?apolin=C3=A1rio?=
Date: Wed, 6 Dec 2023 13:56:18 +0100
Subject: [PATCH 1/3] Let users swap schedulers

---
 .../diffusers/app/pipelines/text_to_image.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/docker_images/diffusers/app/pipelines/text_to_image.py b/docker_images/diffusers/app/pipelines/text_to_image.py
index 30944ea7..986fb9ae 100644
--- a/docker_images/diffusers/app/pipelines/text_to_image.py
+++ b/docker_images/diffusers/app/pipelines/text_to_image.py
@@ -2,6 +2,7 @@
 import logging
 import os
 from typing import TYPE_CHECKING
+import importlib
 
 import torch
 from app import idle, lora, timing, validation
@@ -65,6 +66,11 @@ def __init__(self, model_id: str):
         if torch.cuda.is_available():
             kwargs["torch_dtype"] = torch.float16
 
+        custom_scheduler = None
+        if "scheduler" in kwargs:
+            custom_scheduler = kwargs["scheduler"]
+            kwargs.pop("scheduler")
+
         has_model_index = any(
             file.rfilename == "model_index.json" for file in model_data.siblings
         )
@@ -127,7 +133,14 @@ def __init__(self, model_id: str):
             self.ldm.__class__.__init__.__annotations__.get("scheduler", None)
             == KarrasDiffusionSchedulers
         )
-        if self.is_karras_compatible:
+        if custom_scheduler:
+            compatibles = self.ldm.compatibles
+            is_compatible_scheduler = [cls for cls in compatibles if cls.__name__ == custom_scheduler]
+            if(is_compatible_scheduler):
+                SchedulerClass = getattr(importlib.import_module("diffusers.schedulers"), custom_scheduler)
+                self.ldm.scheduler = SchedulerClass.from_config(self.ldm.scheduler.config)
+
+        if self.is_karras_compatible and ((not custom_scheduler) or (custom_scheduler and not is_compatible_scheduler)):
             self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(
                 self.ldm.scheduler.config
             )

From 99de135b4908c6c12ddeb7b2e2dc2ff11e409a8a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?apolin=C3=A1rio?=
Date: Wed, 6 Dec 2023 14:38:20 +0100
Subject: [PATCH 2/3] move it to call

---
 .../diffusers/app/pipelines/text_to_image.py | 31 +++++++++++--------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/docker_images/diffusers/app/pipelines/text_to_image.py b/docker_images/diffusers/app/pipelines/text_to_image.py
index 986fb9ae..a9ca7835 100644
--- a/docker_images/diffusers/app/pipelines/text_to_image.py
+++ b/docker_images/diffusers/app/pipelines/text_to_image.py
@@ -66,11 +66,6 @@ def __init__(self, model_id: str):
         if torch.cuda.is_available():
             kwargs["torch_dtype"] = torch.float16
 
-        custom_scheduler = None
-        if "scheduler" in kwargs:
-            custom_scheduler = kwargs["scheduler"]
-            kwargs.pop("scheduler")
-
         has_model_index = any(
             file.rfilename == "model_index.json" for file in model_data.siblings
         )
@@ -133,14 +128,7 @@ def __init__(self, model_id: str):
             self.ldm.__class__.__init__.__annotations__.get("scheduler", None)
             == KarrasDiffusionSchedulers
         )
-        if custom_scheduler:
-            compatibles = self.ldm.compatibles
-            is_compatible_scheduler = [cls for cls in compatibles if cls.__name__ == custom_scheduler]
-            if(is_compatible_scheduler):
-                SchedulerClass = getattr(importlib.import_module("diffusers.schedulers"), custom_scheduler)
-                self.ldm.scheduler = SchedulerClass.from_config(self.ldm.scheduler.config)
-
-        if self.is_karras_compatible and ((not custom_scheduler) or (custom_scheduler and not is_compatible_scheduler)):
+        if self.is_karras_compatible:
             self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(
                 self.ldm.scheduler.config
             )
@@ -174,6 +162,23 @@ def __call__(self, inputs: str, **kwargs) -> "Image.Image":
         Return:
             A :obj:`PIL.Image.Image` with the raw image representation as PIL.
         """
+
+        #Check if users set a custom scheduler and pop if from the kwargs if so
+        custom_scheduler = None
+        if "scheduler" in kwargs:
+            custom_scheduler = kwargs["scheduler"]
+            kwargs.pop("scheduler")
+
+        if custom_scheduler:
+            compatibles = self.ldm.compatibles
+            #Check if the scheduler is compatible
+            is_compatible_scheduler = [cls for cls in compatibles if cls.__name__ == custom_scheduler]
+            #In case of a compatible scheduler, swap to that for inference
+            if(is_compatible_scheduler):
+                #Import the scheduler dynamically
+                SchedulerClass = getattr(importlib.import_module("diffusers.schedulers"), custom_scheduler)
+                self.ldm.scheduler = SchedulerClass.from_config(self.ldm.scheduler.config)
+
         self._load_lora_adapter(kwargs)
 
         if idle.UNLOAD_IDLE:

From 0c0c3eb84a215cbb85cc4698e7def8e73e523340 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?apolin=C3=A1rio?=
Date: Wed, 6 Dec 2023 14:46:07 +0100
Subject: [PATCH 3/3] =?UTF-8?q?=E2=9C=A8=20style=20=E2=9C=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../diffusers/app/pipelines/text_to_image.py | 28 +++++++++++--------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/docker_images/diffusers/app/pipelines/text_to_image.py b/docker_images/diffusers/app/pipelines/text_to_image.py
index a9ca7835..59ef6159 100644
--- a/docker_images/diffusers/app/pipelines/text_to_image.py
+++ b/docker_images/diffusers/app/pipelines/text_to_image.py
@@ -1,8 +1,8 @@
+import importlib
 import json
 import logging
 import os
 from typing import TYPE_CHECKING
-import importlib
 
 import torch
 from app import idle, lora, timing, validation
@@ -163,22 +163,28 @@ def __call__(self, inputs: str, **kwargs) -> "Image.Image":
             A :obj:`PIL.Image.Image` with the raw image representation as PIL.
         """
 
-        #Check if users set a custom scheduler and pop if from the kwargs if so
+        # Check if users set a custom scheduler and pop if from the kwargs if so
         custom_scheduler = None
         if "scheduler" in kwargs:
            custom_scheduler = kwargs["scheduler"]
            kwargs.pop("scheduler")
-
+
         if custom_scheduler:
             compatibles = self.ldm.compatibles
-            #Check if the scheduler is compatible
-            is_compatible_scheduler = [cls for cls in compatibles if cls.__name__ == custom_scheduler]
-            #In case of a compatible scheduler, swap to that for inference
-            if(is_compatible_scheduler):
-                #Import the scheduler dynamically
-                SchedulerClass = getattr(importlib.import_module("diffusers.schedulers"), custom_scheduler)
-                self.ldm.scheduler = SchedulerClass.from_config(self.ldm.scheduler.config)
-
+            # Check if the scheduler is compatible
+            is_compatible_scheduler = [
+                cls for cls in compatibles if cls.__name__ == custom_scheduler
+            ]
+            # In case of a compatible scheduler, swap to that for inference
+            if is_compatible_scheduler:
+                # Import the scheduler dynamically
+                SchedulerClass = getattr(
+                    importlib.import_module("diffusers.schedulers"), custom_scheduler
+                )
+                self.ldm.scheduler = SchedulerClass.from_config(
+                    self.ldm.scheduler.config
+                )
+
         self._load_lora_adapter(kwargs)
 
         if idle.UNLOAD_IDLE: