Dev/offline #356

Merged
merged 2 commits into from Nov 30, 2023
1 change: 0 additions & 1 deletion docker_images/diffusers/Dockerfile
@@ -42,7 +42,6 @@ ARG max_workers=1
ENV MAX_WORKERS=$max_workers
ENV HUGGINGFACE_HUB_CACHE=/data
ENV DIFFUSERS_CACHE=/data
-ENV HF_HOME=/data

# Necessary on GPU environment docker.
# TIMEOUT env variable is used by nvcr.io/nvidia/pytorch:xx for another purpose
58 changes: 52 additions & 6 deletions docker_images/diffusers/app/pipelines/image_to_image.py
@@ -3,7 +3,7 @@
import os

import torch
-from app import idle, timing
+from app import idle, timing, validation
from app.pipelines import Pipeline
from diffusers import (
AltDiffusionImg2ImgPipeline,
@@ -26,7 +26,7 @@
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
-from huggingface_hub import hf_hub_download, model_info
+from huggingface_hub import file_download, hf_api, hf_hub_download, model_info, utils
from PIL import Image


@@ -36,7 +36,36 @@
class ImageToImagePipeline(Pipeline):
def __init__(self, model_id: str):
use_auth_token = os.getenv("HF_API_TOKEN")
-        model_data = model_info(model_id, token=use_auth_token)
+        self.use_auth_token = use_auth_token
+        # This should allow the image to work with private models when no token
+        # is provided, as long as the model is already in the local cache
+        self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
+        fetched = False
+        if self.offline_preferred:
+            cache_root = os.getenv(
+                "DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
+            )
+            folder_name = file_download.repo_folder_name(
+                repo_id=model_id, repo_type="model"
+            )
+            folder_path = os.path.join(cache_root, folder_name)
+            logger.debug("Cache folder path %s", folder_path)
+            filename = os.path.join(folder_path, "hub_model_info.json")
+            try:
+                with open(filename, "r") as f:
+                    model_data = json.load(f)
+            except OSError:
+                logger.info(
+                    "No cached model info found in file %s for model %s. Fetching from the Hub",
+                    filename,
+                    model_id,
+                )
+            else:
+                model_data = hf_api.ModelInfo(**model_data)
+                fetched = True
+
+        if not fetched:
+            model_data = model_info(model_id, token=self.use_auth_token)

kwargs = (
{"safety_checker": None}
@@ -55,9 +84,26 @@ def __init__(self, model_id: str):
config_file_name = file_name
break
if config_file_name:
-            config_file = hf_hub_download(
-                model_id, config_file_name, token=use_auth_token
-            )
+            fetched = False
+            if self.offline_preferred:
+                try:
+                    config_file = hf_hub_download(
+                        model_id,
+                        config_file_name,
+                        token=self.use_auth_token,
+                        local_files_only=True,
+                    )
+                except utils.LocalEntryNotFoundError:
+                    logger.info("Unable to fetch the model index from the local cache")
+                else:
+                    fetched = True
+            if not fetched:
+                config_file = hf_hub_download(
+                    model_id,
+                    config_file_name,
+                    token=self.use_auth_token,
+                )

with open(config_file, "r") as f:
config_dict = json.load(f)

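Both hunks above follow the same offline-first pattern: try the local cache with local_files_only=True, and only hit the Hub on a cache miss. A minimal standalone sketch of that pattern, assuming only that huggingface_hub is installed; the helper name download_offline_first is illustrative, not part of this PR:

    import logging

    from huggingface_hub import hf_hub_download, utils

    logger = logging.getLogger(__name__)


    def download_offline_first(model_id, filename, token=None, offline_preferred=False):
        # When offline is preferred, local_files_only=True never touches the
        # network; it raises LocalEntryNotFoundError on a cache miss instead.
        if offline_preferred:
            try:
                return hf_hub_download(
                    model_id, filename, token=token, local_files_only=True
                )
            except utils.LocalEntryNotFoundError:
                logger.info("%s not in the local cache, falling back to the Hub", filename)
        return hf_hub_download(model_id, filename, token=token)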
57 changes: 51 additions & 6 deletions docker_images/diffusers/app/pipelines/text_to_image.py
@@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

import torch
-from app import idle, lora, timing
+from app import idle, lora, timing, validation
from app.pipelines import Pipeline
from diffusers import (
AutoencoderKL,
@@ -13,7 +13,7 @@
EulerAncestralDiscreteScheduler,
)
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
-from huggingface_hub import hf_hub_download, model_info
+from huggingface_hub import file_download, hf_api, hf_hub_download, model_info, utils


logger = logging.getLogger(__name__)
@@ -27,7 +27,35 @@ def __init__(self, model_id: str):
self.current_lora_adapter = None
self.model_id = None
self.use_auth_token = os.getenv("HF_API_TOKEN")
-        model_data = model_info(model_id, token=self.use_auth_token)
+        # This should allow the image to work with private models when no token
+        # is provided, as long as the model is already in the local cache
+        self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
+        fetched = False
+        if self.offline_preferred:
+            cache_root = os.getenv(
+                "DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
+            )
+            folder_name = file_download.repo_folder_name(
+                repo_id=model_id, repo_type="model"
+            )
+            folder_path = os.path.join(cache_root, folder_name)
+            logger.debug("Cache folder path %s", folder_path)
+            filename = os.path.join(folder_path, "hub_model_info.json")
+            try:
+                with open(filename, "r") as f:
+                    model_data = json.load(f)
+            except OSError:
+                logger.info(
+                    "No cached model info found in file %s for model %s. Fetching from the Hub",
+                    filename,
+                    model_id,
+                )
+            else:
+                model_data = hf_api.ModelInfo(**model_data)
+                fetched = True
+
+        if not fetched:
+            model_data = model_info(model_id, token=self.use_auth_token)

kwargs = (
{"safety_checker": None}
@@ -44,9 +72,26 @@ def __init__(self, model_id: str):
if self._is_lora(model_data):
model_type = "LoraModel"
elif has_model_index:
-            config_file = hf_hub_download(
-                model_id, "model_index.json", token=self.use_auth_token
-            )
+            fetched = False
+            if self.offline_preferred:
+                try:
+                    config_file = hf_hub_download(
+                        model_id,
+                        "model_index.json",
+                        token=self.use_auth_token,
+                        local_files_only=True,
+                    )
+                except utils.LocalEntryNotFoundError:
+                    logger.info("Unable to fetch the model index from the local cache")
+                else:
+                    fetched = True
+
+            if not fetched:
+                config_file = hf_hub_download(
+                    model_id,
+                    "model_index.json",
+                    token=self.use_auth_token,
+                )
with open(config_file, "r") as f:
config_dict = json.load(f)
model_type = config_dict.get("_class_name", None)
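The cached model-info lookup duplicated across both pipelines reads naturally as one small helper. A sketch under the PR's assumption that a hub_model_info.json sidecar has already been written into the repo's cache folder by the serving infrastructure (that writer is not part of this diff); load_cached_model_info is an illustrative name:

    import json
    import os

    from huggingface_hub import file_download, hf_api


    def load_cached_model_info(model_id):
        cache_root = os.getenv("DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", ""))
        # repo_folder_name maps e.g. "org/name" to "models--org--name",
        # matching the huggingface_hub cache directory layout.
        folder_name = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
        path = os.path.join(cache_root, folder_name, "hub_model_info.json")
        try:
            with open(path, "r") as f:
                return hf_api.ModelInfo(**json.load(f))
        except OSError:
            return None  # not cached; caller falls back to model_info() on the Hub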
8 changes: 8 additions & 0 deletions docker_images/diffusers/app/validation.py
@@ -0,0 +1,8 @@
+import re
+
+
+STR_TO_BOOL = re.compile(r"^\s*(true|yes|1)\s*$", re.IGNORECASE)
+
+
+def str_to_bool(s):
+    return bool(STR_TO_BOOL.match(str(s)))
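For illustration, the intended truth table, using hypothetical values (note that os.getenv returns None for unset variables, which str() turns into "None" and the pattern rejects):

    import os

    from app.validation import str_to_bool

    assert str_to_bool("true")
    assert str_to_bool(" YES ")
    assert str_to_bool("1")
    assert not str_to_bool("false")
    assert not str_to_bool("")
    assert not str_to_bool(os.getenv("SOME_UNSET_VARIABLE"))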
3 changes: 1 addition & 2 deletions docker_images/diffusers/requirements.txt
@@ -1,8 +1,7 @@
starlette==0.27.0
# to be replaced with api-inference-community==0.0.33 as soon as released
git+https://github.com/huggingface/api-inference-community.git@b3ef3f3a6015ed988ce77f71935e45006be5b054
-# to be replaced with diffusers==0.21.5 as soon as released
-git+https://github.com/huggingface/diffusers.git@ed2f956072a3b446d984f359ba6c427c259ab4ee
+diffusers==0.23.1
transformers==4.31.0
accelerate==0.21.0
hf_transfer==0.1.3