-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
model info that works offline (#371)
* offline model info + hub local file layout helpers can be used to fix #372
- Loading branch information
Showing
3 changed files
with
276 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import json | ||
import logging | ||
import os | ||
import pathlib | ||
import re | ||
from typing import List, Optional | ||
|
||
from huggingface_hub import ModelCard, constants, hf_api, try_to_load_from_cache | ||
from huggingface_hub.file_download import repo_folder_name | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _cached_repo_root_path(cache_dir: pathlib.Path, repo_id: str) -> pathlib.Path:
    """Return the root cache folder (``models--org--name`` layout) for *repo_id* under *cache_dir*."""
    return cache_dir / repo_folder_name(repo_id=repo_id, repo_type="model")
|
||
|
||
def cached_revision_path(cache_dir, repo_id, revision) -> pathlib.Path:
    """
    Resolve the local snapshot directory for a repo revision inside a
    huggingface_hub style cache directory (refs/ + snapshots/ layout).

    Args:
        cache_dir: root of the local hub cache (contains ``models--*`` folders).
        repo_id: repo identifier, e.g. ``"google/t5-efficient-tiny"``.
        revision: branch name, tag or commit sha; ``None`` means ``"main"``.

    Returns:
        Path to the cached snapshot directory for the resolved revision.

    Raises:
        Exception: when the repo, its snapshots dir, or the requested revision
            is not cached locally. We deliberately never fall back to some
            other cached revision.
    """
    # Keep the message built from the *requested* revision (before ref
    # resolution) so the user sees what they asked for.
    error_msg = f"No revision path found for {repo_id}, revision {revision}"

    if revision is None:
        revision = "main"

    repo_cache = _cached_repo_root_path(cache_dir, repo_id)

    if not repo_cache.is_dir():
        msg = f"Local repo {repo_cache} does not exist"
        logger.error(msg)
        raise Exception(msg)

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            # strip(): a trailing newline in the ref file would otherwise
            # prevent the snapshot lookup below from ever matching.
            revision = revision_file.read_text().strip()

    # Check if revision folder exists
    if not snapshots_dir.exists():
        msg = f"No local revision path {snapshots_dir} found for {repo_id}, revision {revision}"
        logger.error(msg)
        raise Exception(msg)

    if revision not in os.listdir(snapshots_dir):
        # No cache for this revision and we won't try to return a random revision
        logger.error(error_msg)
        raise Exception(error_msg)

    return snapshots_dir / revision
|
||
|
||
def _build_offline_model_info(
    repo_id: str, cache_dir: pathlib.Path, revision: str
) -> hf_api.ModelInfo:
    """
    Reconstruct a partial ``hf_api.ModelInfo`` for *repo_id* using only the
    local cache (no network): metadata comes from the cached README model
    card, and the siblings list from the files present in the snapshot dir.
    """
    logger.info("Rebuilding offline model info for repo %s", repo_id)

    # The cached README is our only offline metadata source; without it there
    # is nothing meaningful to rebuild.
    card_path = try_to_load_from_cache(
        repo_id=repo_id,
        filename="README.md",
        cache_dir=cache_dir,
        revision=revision,
    )
    if not isinstance(card_path, str):
        raise Exception(
            "Unable to rebuild offline model info, no README could be found"
        )

    card_path = pathlib.Path(card_path)
    logger.debug("Loading model card from model readme %s", card_path)
    card_data = ModelCard.load(card_path).data.to_dict()

    # The README lives at the root of the snapshot directory, so its parent
    # is the local repo snapshot we enumerate for siblings.
    snapshot_dir = card_path.parent
    logger.debug("Repo path %s", snapshot_dir)
    model_info = hf_api.ModelInfo(
        private=False,
        downloads=0,
        likes=0,
        id=repo_id,
        card_data=card_data,
        siblings=_build_offline_siblings(snapshot_dir),
        **card_data,
    )
    logger.info("Offline model info for repo %s: %s", snapshot_dir, model_info)
    return model_info
|
||
|
||
def _build_offline_siblings(repo: pathlib.Path) -> List[dict]: | ||
siblings = [] | ||
prefix_pattern = re.compile(r"^" + re.escape(str(repo)) + r"(.*)$") | ||
for root, dirs, files in os.walk(repo): | ||
for file in files: | ||
filepath = os.path.join(root, file) | ||
size = os.stat(filepath).st_size | ||
m = prefix_pattern.match(filepath) | ||
if not m: | ||
msg = ( | ||
f"File {filepath} does not match expected pattern {prefix_pattern}" | ||
) | ||
logger.error(msg) | ||
raise Exception(msg) | ||
filepath = m.group(1) | ||
filepath = filepath.strip(os.sep) | ||
sibling = dict(rfilename=filepath, size=size) | ||
siblings.append(sibling) | ||
return siblings | ||
|
||
|
||
def _cached_model_info(
    repo_id: str, revision: str, cache_dir: pathlib.Path
) -> hf_api.ModelInfo:
    """
    Looks for a json file containing prefetched model info in the revision path.
    If none found we just rebuild model info with the local directory files.
    Note that this file is not automatically created by hub_download/snapshot_download.
    It is just a convenience we add here, just in case the offline info we rebuild from
    the local directories would not cover all use cases.
    """
    revision_path = cached_revision_path(cache_dir, repo_id, revision)
    model_info_path = revision_path / "hub_model_info.json"
    logger.info("Checking if there are some cached model info at %s", model_info_path)

    if not os.path.exists(model_info_path):
        logger.debug(
            "No cached model info file %s found, "
            "rebuilding partial model info from cached model files",
            model_info_path,
        )
        # Let's rebuild some partial model info from what we see in cache, info extracted should be enough
        # for most use cases
        return _build_offline_model_info(repo_id, cache_dir, revision)

    with open(model_info_path, "r") as f:
        info = hf_api.ModelInfo(**json.load(f))
    logger.debug("Cached model info from file: %s", info)
    return info
|
||
|
||
def hub_model_info(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[pathlib.Path] = None,
    **kwargs,
) -> hf_api.ModelInfo:
    """
    Get Hub model info with offline support
    """
    effective_revision = "main" if revision is None else revision

    # Online mode: just delegate to the regular hub API call.
    if not constants.HF_HUB_OFFLINE:
        return hf_api.model_info(
            repo_id=repo_id, revision=effective_revision, **kwargs
        )

    logger.info("Model info for offline mode")

    effective_cache_dir = (
        pathlib.Path(constants.HF_HUB_CACHE) if cache_dir is None else cache_dir
    )
    return _cached_model_info(repo_id, effective_revision, effective_cache_dir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import logging | ||
import sys | ||
from unittest import TestCase | ||
|
||
from api_inference_community import hub | ||
from huggingface_hub import constants, hf_api, snapshot_download | ||
|
||
|
||
# Test-module logger: emit DEBUG records to stdout so the test runner
# captures them.
logger = logging.getLogger(__name__)
# Use setLevel() rather than assigning .level directly: setLevel is the
# documented API and also invalidates the logging manager's internal cache.
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
|
||
|
||
class HubTestCase(TestCase):
    """Tests for hub.hub_model_info in both offline and online modes."""

    def _offline_model_info(self, repo_id, revision):
        """
        Prefetch a snapshot of *repo_id* at *revision*, then fetch its model
        info with HF_HUB_OFFLINE forced on, restoring the flag afterwards.
        """
        bak_value = constants.HF_HUB_OFFLINE
        try:
            dirpath = snapshot_download(repo_id=repo_id, revision=revision)
            logger.info("Snapshot downloaded at %s", dirpath)
            constants.HF_HUB_OFFLINE = True
            return hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            # Always restore the global flag so other tests are unaffected.
            constants.HF_HUB_OFFLINE = bak_value

    def test_offline_model_info1(self):
        repo_id = "google/t5-efficient-tiny"
        revision = "3441d7e8bf3f89841f366d39452b95200416e4a9"
        model_info = self._offline_model_info(repo_id, revision)

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        # Offline rebuilds cannot know downloads/likes; they default to 0.
        self.assertEqual(model_info.downloads, 0)
        self.assertEqual(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 12)
        self.assertIn("pytorch_model.bin", [s.rfilename for s in model_info.siblings])
        self.assertFalse(model_info.private)
        self.assertEqual(model_info.license, "apache-2.0")  # noqa
        self.assertEqual(model_info.tags, ["deep-narrow"])
        self.assertIsNone(model_info.library_name)

        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["deep-narrow"])

    def test_offline_model_info2(self):
        repo_id = "dfurman/Mixtral-8x7B-peft-v0.1"
        revision = "8908d586219993ec79949acaef566363a7c7864c"
        model_info = self._offline_model_info(repo_id, revision)

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertEqual(model_info.downloads, 0)
        self.assertEqual(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 9)
        self.assertFalse(model_info.private)
        self.assertEqual(model_info.license, "apache-2.0")  # noqa
        self.assertEqual(model_info.tags, ["mistral"])
        self.assertEqual(model_info.library_name, "peft")
        self.assertEqual(model_info.pipeline_tag, "text-generation")
        self.assertIn(".gitattributes", [s.rfilename for s in model_info.siblings])
        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["mistral"])

    def test_online_model_info(self):
        repo_id = "dfurman/Mixtral-8x7B-Instruct-v0.1"
        # NOTE(review): this sha is the same one used for the peft repo in
        # test_offline_model_info2 — confirm it actually exists on this repo.
        revision = "8908d586219993ec79949acaef566363a7c7864c"
        bak_value = constants.HF_HUB_OFFLINE
        try:
            constants.HF_HUB_OFFLINE = False
            model_info = hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            constants.HF_HUB_OFFLINE = bak_value

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertGreater(model_info.downloads, 0)
        self.assertGreater(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 9)
        self.assertFalse(model_info.private)
        # Fixed: the original used assertGreater on two lists, which performs
        # a lexicographic comparison; the intent is a subset/containment check.
        self.assertTrue(
            {"peft", "safetensors", "mistral"}.issubset(model_info.tags)
        )
        self.assertEqual(model_info.library_name, "peft")
        self.assertEqual(model_info.pipeline_tag, "text-generation")
        self.assertIn(".gitattributes", [s.rfilename for s in model_info.siblings])
        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["mistral"])
        self.assertIsNone(model_info.safetensors)