diff --git a/docker_images/diffusers/app/pipelines/text_to_image.py b/docker_images/diffusers/app/pipelines/text_to_image.py index 30944ea7..59ef6159 100644 --- a/docker_images/diffusers/app/pipelines/text_to_image.py +++ b/docker_images/diffusers/app/pipelines/text_to_image.py @@ -1,3 +1,4 @@ +import importlib import json import logging import os @@ -161,6 +162,29 @@ def __call__(self, inputs: str, **kwargs) -> "Image.Image": Return: A :obj:`PIL.Image.Image` with the raw image representation as PIL. """ + + # Check if users set a custom scheduler and pop if from the kwargs if so + custom_scheduler = None + if "scheduler" in kwargs: + custom_scheduler = kwargs["scheduler"] + kwargs.pop("scheduler") + + if custom_scheduler: + compatibles = self.ldm.compatibles + # Check if the scheduler is compatible + is_compatible_scheduler = [ + cls for cls in compatibles if cls.__name__ == custom_scheduler + ] + # In case of a compatible scheduler, swap to that for inference + if is_compatible_scheduler: + # Import the scheduler dynamically + SchedulerClass = getattr( + importlib.import_module("diffusers.schedulers"), custom_scheduler + ) + self.ldm.scheduler = SchedulerClass.from_config( + self.ldm.scheduler.config + ) + self._load_lora_adapter(kwargs) if idle.UNLOAD_IDLE: