Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
  • Loading branch information
oOraph committed Nov 29, 2023
1 parent e5aa9e7 commit b069b0f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
12 changes: 9 additions & 3 deletions docker_images/diffusers/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,19 @@ def __init__(self, model_id: str):

# This should allow us to make the image work with private models when no token is provided, if the said model
# is already in local cache
cache_folder_name = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
cache_folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
filename = os.path.join(cache_folder_name, "hub_model_info.json")
try:
with open(filename, 'r') as f:
with open(filename, "r") as f:
model_data = json.load(f)
except OSError:
logger.info("No cached model info found in %s for model %s. Fetching on the hub", filename, model_id)
logger.info(
"No cached model info found in %s for model %s. Fetching on the hub",
filename,
model_id,
)
model_data = model_info(model_id, token=self.use_auth_token)
else:
model_data = hf_api.ModelInfo(**model_data)
Expand Down
16 changes: 12 additions & 4 deletions docker_images/diffusers/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ def __init__(self, model_id: str):

# This should allow us to make the image work with private models when no token is provided, if the said model
# is already in local cache
cache_folder_name = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
cache_folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
filename = os.path.join(cache_folder_name, "hub_model_info.json")
try:
with open(filename, 'r') as f:
with open(filename, "r") as f:
model_data = json.load(f)
except OSError:
logger.info("No cached model info found in %s for model %s. Fetching on the hub", filename, model_id)
logger.info(
"No cached model info found in %s for model %s. Fetching on the hub",
filename,
model_id,
)
model_data = model_info(model_id, token=self.use_auth_token)
else:
model_data = hf_api.ModelInfo(**model_data)
Expand All @@ -57,7 +63,9 @@ def __init__(self, model_id: str):
model_type = "LoraModel"
elif has_model_index:
config_file = hf_hub_download(
model_id, "model_index.json", token=self.use_auth_token, local_files_only=self.offline,
model_id,
"model_index.json",
token=self.use_auth_token,
)
with open(config_file, "r") as f:
config_dict = json.load(f)
Expand Down

0 comments on commit b069b0f

Please sign in to comment.