Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
  • Loading branch information
oOraph committed Nov 29, 2023
1 parent 040c5db commit 08a869d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 5 additions & 2 deletions docker_images/diffusers/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,21 @@ def __init__(self, model_id: str):

# This should allow us to make the image work with private models when no token is provided, if the said model
# is already in local cache
cache_root = os.getenv("DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE"))
cache_root = os.getenv(
"DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
)
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
folder_path = os.path.join(cache_root, folder_name)
logger.debug("Cache folder path %s", folder_path)
filename = os.path.join(folder_path, "hub_model_info.json")
try:
with open(filename, "r") as f:
model_data = json.load(f)
except OSError:
logger.info(
"No cached model info found in %s for model %s. Fetching on the hub",
"No cached model info found in file %s found for model %s. Fetching on the hub",
filename,
model_id,
)
Expand Down
7 changes: 5 additions & 2 deletions docker_images/diffusers/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,21 @@ def __init__(self, model_id: str):

# This should allow us to make the image work with private models when no token is provided, if the said model
# is already in local cache
cache_root = os.getenv("DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE"))
cache_root = os.getenv(
"DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
)
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
folder_path = os.path.join(cache_root, folder_name)
logger.debug("Cache folder path %s", folder_path)
filename = os.path.join(folder_path, "hub_model_info.json")
try:
with open(filename, "r") as f:
model_data = json.load(f)
except OSError:
logger.info(
"No cached model info found in %s for model %s. Fetching on the hub",
"No cached model info found in file %s found for model %s. Fetching on the hub",
filename,
model_id,
)
Expand Down

0 comments on commit 08a869d

Please sign in to comment.