Dev/offline (#356)
* update diffusers requirements.txt
* diffusers image: possibility to prefer offline mode

Useful to load private models without providing any token, when the model in question is already in the local cache
oOraph authored Nov 30, 2023
1 parent 3a64955 commit a49c34d
Showing 5 changed files with 112 additions and 15 deletions.
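At its core the change is a cache-first lookup with a network fallback, gated by a new OFFLINE_PREFERRED environment variable. A minimal sketch of the pattern, assuming huggingface_hub is installed (fetch_config is an illustrative name, and the plain truthiness test on the variable stands in for the str_to_bool helper the diff actually adds):

import os

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import LocalEntryNotFoundError


def fetch_config(model_id, filename, token=None):
    # Cache-first: with OFFLINE_PREFERRED set, try the local cache before
    # touching the network. local_files_only=True raises
    # LocalEntryNotFoundError instead of downloading, so a private model
    # that is already cached loads without any token.
    if os.getenv("OFFLINE_PREFERRED"):
        try:
            return hf_hub_download(
                model_id, filename, token=token, local_files_only=True
            )
        except LocalEntryNotFoundError:
            pass  # not cached; fall back to a regular download
    return hf_hub_download(model_id, filename, token=token)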
1 change: 0 additions & 1 deletion docker_images/diffusers/Dockerfile
@@ -42,7 +42,6 @@ ARG max_workers=1
 ENV MAX_WORKERS=$max_workers
 ENV HUGGINGFACE_HUB_CACHE=/data
 ENV DIFFUSERS_CACHE=/data
-ENV HF_HOME=/data
 
 # Necessary on GPU environment docker.
 # TIMEOUT env variable is used by nvcr.io/nvidia/pytorch:xx for another purpose
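The pipelines below resolve their cache root from the same variables this Dockerfile sets; the lookup, as it appears in both pipeline hunks:

import os

# DIFFUSERS_CACHE takes precedence, then HUGGINGFACE_HUB_CACHE
# (both point at /data in this image).
cache_root = os.getenv("DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", ""))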
58 changes: 52 additions & 6 deletions docker_images/diffusers/app/pipelines/image_to_image.py
@@ -3,7 +3,7 @@
 import os
 
 import torch
-from app import idle, timing
+from app import idle, timing, validation
 from app.pipelines import Pipeline
 from diffusers import (
     AltDiffusionImg2ImgPipeline,
@@ -26,7 +26,7 @@
     StableUnCLIPImg2ImgPipeline,
     StableUnCLIPPipeline,
 )
-from huggingface_hub import hf_hub_download, model_info
+from huggingface_hub import file_download, hf_api, hf_hub_download, model_info, utils
 from PIL import Image

@@ -36,7 +36,36 @@
 class ImageToImagePipeline(Pipeline):
     def __init__(self, model_id: str):
         use_auth_token = os.getenv("HF_API_TOKEN")
-        model_data = model_info(model_id, token=use_auth_token)
+        self.use_auth_token = use_auth_token
+        # Allow the image to work with private models without a token,
+        # provided the model is already in the local cache
+        self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
+        fetched = False
+        if self.offline_preferred:
+            cache_root = os.getenv(
+                "DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
+            )
+            folder_name = file_download.repo_folder_name(
+                repo_id=model_id, repo_type="model"
+            )
+            folder_path = os.path.join(cache_root, folder_name)
+            logger.debug("Cache folder path %s", folder_path)
+            filename = os.path.join(folder_path, "hub_model_info.json")
+            try:
+                with open(filename, "r") as f:
+                    model_data = json.load(f)
+            except OSError:
+                logger.info(
+                    "No cached model info found in %s for model %s. Fetching from the Hub",
+                    filename,
+                    model_id,
+                )
+            else:
+                model_data = hf_api.ModelInfo(**model_data)
+                fetched = True
+
+        if not fetched:
+            model_data = model_info(model_id, token=self.use_auth_token)
 
         kwargs = (
             {"safety_checker": None}
@@ -55,9 +84,26 @@ def __init__(self, model_id: str):
                 config_file_name = file_name
                 break
         if config_file_name:
-            config_file = hf_hub_download(
-                model_id, config_file_name, token=use_auth_token
-            )
+            fetched = False
+            if self.offline_preferred:
+                try:
+                    config_file = hf_hub_download(
+                        model_id,
+                        config_file_name,
+                        token=self.use_auth_token,
+                        local_files_only=True,
+                    )
+                except utils.LocalEntryNotFoundError:
+                    logger.info("Unable to find model index in local cache")
+                else:
+                    fetched = True
+
+            if not fetched:
+                config_file = hf_hub_download(
+                    model_id,
+                    config_file_name,
+                    token=self.use_auth_token,
+                )
 
             with open(config_file, "r") as f:
                 config_dict = json.load(f)
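The __init__ hunk above rebuilds a ModelInfo object from a hub_model_info.json file stored in the model's cache folder; how that file gets written there is not part of this diff. A condensed sketch of the read path, with load_cached_model_info as an illustrative name:

import json
import os

from huggingface_hub import file_download, hf_api


def load_cached_model_info(model_id, cache_root):
    # repo_folder_name maps e.g. "org/model" to "models--org--model",
    # the directory layout of the huggingface_hub cache.
    folder = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
    path = os.path.join(cache_root, folder, "hub_model_info.json")
    try:
        with open(path, "r") as f:
            return hf_api.ModelInfo(**json.load(f))  # as the diff does
    except OSError:
        return None  # caller falls back to model_info() over the network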
57 changes: 51 additions & 6 deletions docker_images/diffusers/app/pipelines/text_to_image.py
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING
 
 import torch
-from app import idle, lora, timing
+from app import idle, lora, timing, validation
 from app.pipelines import Pipeline
 from diffusers import (
     AutoencoderKL,
@@ -13,7 +13,7 @@
     EulerAncestralDiscreteScheduler,
 )
 from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
-from huggingface_hub import hf_hub_download, model_info
+from huggingface_hub import file_download, hf_api, hf_hub_download, model_info, utils
 
 
 logger = logging.getLogger(__name__)
@@ -27,7 +27,35 @@ def __init__(self, model_id: str):
         self.current_lora_adapter = None
         self.model_id = None
         self.use_auth_token = os.getenv("HF_API_TOKEN")
-        model_data = model_info(model_id, token=self.use_auth_token)
+        # Allow the image to work with private models without a token,
+        # provided the model is already in the local cache
+        self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
+        fetched = False
+        if self.offline_preferred:
+            cache_root = os.getenv(
+                "DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
+            )
+            folder_name = file_download.repo_folder_name(
+                repo_id=model_id, repo_type="model"
+            )
+            folder_path = os.path.join(cache_root, folder_name)
+            logger.debug("Cache folder path %s", folder_path)
+            filename = os.path.join(folder_path, "hub_model_info.json")
+            try:
+                with open(filename, "r") as f:
+                    model_data = json.load(f)
+            except OSError:
+                logger.info(
+                    "No cached model info found in %s for model %s. Fetching from the Hub",
+                    filename,
+                    model_id,
+                )
+            else:
+                model_data = hf_api.ModelInfo(**model_data)
+                fetched = True
+
+        if not fetched:
+            model_data = model_info(model_id, token=self.use_auth_token)
 
         kwargs = (
             {"safety_checker": None}
@@ -44,9 +72,26 @@ def __init__(self, model_id: str):
         if self._is_lora(model_data):
             model_type = "LoraModel"
         elif has_model_index:
-            config_file = hf_hub_download(
-                model_id, "model_index.json", token=self.use_auth_token
-            )
+            fetched = False
+            if self.offline_preferred:
+                try:
+                    config_file = hf_hub_download(
+                        model_id,
+                        "model_index.json",
+                        token=self.use_auth_token,
+                        local_files_only=True,
+                    )
+                except utils.LocalEntryNotFoundError:
+                    logger.info("Unable to find model index in local cache")
+                else:
+                    fetched = True
+
+            if not fetched:
+                config_file = hf_hub_download(
+                    model_id,
+                    "model_index.json",
+                    token=self.use_auth_token,
+                )
             with open(config_file, "r") as f:
                 config_dict = json.load(f)
             model_type = config_dict.get("_class_name", None)
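Downstream, the _class_name field of model_index.json picks the diffusers pipeline class (the tail of the hunk above). The actual dispatch logic is outside this diff; a hypothetical sketch of how such a name could be resolved:

import json

import diffusers


def resolve_pipeline_class(config_file):
    # model_index.json records the originating pipeline class name,
    # e.g. "StableDiffusionPipeline", under "_class_name".
    with open(config_file, "r") as f:
        model_type = json.load(f).get("_class_name", None)
    # Hypothetical lookup of the class on the diffusers module.
    return getattr(diffusers, model_type, None) if model_type else None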
8 changes: 8 additions & 0 deletions docker_images/diffusers/app/validation.py
@@ -0,0 +1,8 @@
+import re
+
+
+STR_TO_BOOL = re.compile(r"^\s*(?:true|yes|1)\s*$", re.IGNORECASE)
+
+
+def str_to_bool(s):
+    return STR_TO_BOOL.match(str(s)) is not None
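A quick illustrative check of the helper's intended behavior (assuming the app package is on the path; the values are examples only):

from app.validation import str_to_bool

# Whole-token truthy strings are accepted; anything else is rejected.
assert str_to_bool("true") and str_to_bool(" YES ") and str_to_bool("1")
assert not str_to_bool("false") and not str_to_bool("0") and not str_to_bool(None)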
3 changes: 1 addition & 2 deletions docker_images/diffusers/requirements.txt
@@ -1,8 +1,7 @@
 starlette==0.27.0
 # to be replaced with api-inference-community==0.0.33 as soon as released
 git+https://github.com/huggingface/api-inference-community.git@b3ef3f3a6015ed988ce77f71935e45006be5b054
-# to be replaced with diffusers==0.21.5 as soon as released
-git+https://github.com/huggingface/diffusers.git@ed2f956072a3b446d984f359ba6c427c259ab4ee
+diffusers==0.23.1
 transformers==4.31.0
 accelerate==0.21.0
 hf_transfer==0.1.3
