model info that works offline (#371)
* offline model info + hub local file layout helpers
can be used to fix #372
oOraph authored Jan 11, 2024
1 parent 7deb0e3 commit f06a71e
Showing 3 changed files with 276 additions and 1 deletion.
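At a glance, the new helper lets `hub_model_info` keep working when `HF_HUB_OFFLINE` is set, as long as the repo has already been cached. A minimal usage sketch (the repo id is only an example, and the snapshot must have been downloaded while online):

import os

# Offline mode must be set before huggingface_hub reads its constants
os.environ["HF_HUB_OFFLINE"] = "1"

from api_inference_community.hub import hub_model_info

# Works without network access, provided the repo is already in the local cache
info = hub_model_info("google/t5-efficient-tiny")
print(info.id, [s.rfilename for s in info.siblings])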
172 changes: 172 additions & 0 deletions api_inference_community/hub.py
@@ -0,0 +1,172 @@
import json
import logging
import os
import pathlib
import re
from typing import List, Optional

from huggingface_hub import ModelCard, constants, hf_api, try_to_load_from_cache
from huggingface_hub.file_download import repo_folder_name


logger = logging.getLogger(__name__)


def _cached_repo_root_path(cache_dir: pathlib.Path, repo_id: str) -> pathlib.Path:
    folder = pathlib.Path(repo_folder_name(repo_id=repo_id, repo_type="model"))
    return cache_dir / folder


def cached_revision_path(cache_dir, repo_id, revision) -> pathlib.Path:
    error_msg = f"No revision path found for {repo_id}, revision {revision}"

    if revision is None:
        revision = "main"

    repo_cache = _cached_repo_root_path(cache_dir, repo_id)

    if not repo_cache.is_dir():
        msg = f"Local repo {repo_cache} does not exist"
        logger.error(msg)
        raise Exception(msg)

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if revision folder exists
    if not snapshots_dir.exists():
        msg = f"No local revision path {snapshots_dir} found for {repo_id}, revision {revision}"
        logger.error(msg)
        raise Exception(msg)

    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        logger.error(error_msg)
        raise Exception(error_msg)

    return snapshots_dir / revision
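# For reference, the standard huggingface_hub cache layout this function walks
# (a sketch; the paths are illustrative, the layout is not defined by this commit):
#
#   <cache_dir>/models--<org>--<name>/
#       refs/main              <- text file containing a commit sha
#       snapshots/<sha>/...    <- the revision path returned above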


def _build_offline_model_info(
    repo_id: str, cache_dir: pathlib.Path, revision: str
) -> hf_api.ModelInfo:
    logger.info("Rebuilding offline model info for repo %s", repo_id)

    # Let's rebuild some partial model info from what we see in cache; the info
    # extracted should be enough for most use cases
    card_path = try_to_load_from_cache(
        repo_id=repo_id,
        filename="README.md",
        cache_dir=cache_dir,
        revision=revision,
    )
    if not isinstance(card_path, str):
        raise Exception(
            "Unable to rebuild offline model info, no README could be found"
        )

    card_path = pathlib.Path(card_path)
    logger.debug("Loading model card from model readme %s", card_path)
    model_card = ModelCard.load(card_path)
    card_data = model_card.data.to_dict()

    repo = card_path.parent
    logger.debug("Repo path %s", repo)
    siblings = _build_offline_siblings(repo)
    model_info = hf_api.ModelInfo(
        private=False,
        downloads=0,
        likes=0,
        id=repo_id,
        card_data=card_data,
        siblings=siblings,
        **card_data,
    )
    logger.info("Offline model info for repo %s: %s", repo, model_info)
    return model_info


def _build_offline_siblings(repo: pathlib.Path) -> List[dict]:
    siblings = []
    prefix_pattern = re.compile(r"^" + re.escape(str(repo)) + r"(.*)$")
    for root, dirs, files in os.walk(repo):
        for file in files:
            filepath = os.path.join(root, file)
            size = os.stat(filepath).st_size
            m = prefix_pattern.match(filepath)
            if not m:
                msg = (
                    f"File {filepath} does not match expected pattern {prefix_pattern}"
                )
                logger.error(msg)
                raise Exception(msg)
            filepath = m.group(1)
            filepath = filepath.strip(os.sep)
            sibling = dict(rfilename=filepath, size=size)
            siblings.append(sibling)
    return siblings
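# Example (hypothetical): for a snapshot directory containing README.md and
# config.json, the result would look like
#   [{"rfilename": "README.md", "size": 1234},
#    {"rfilename": "config.json", "size": 567}]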


def _cached_model_info(
    repo_id: str, revision: str, cache_dir: pathlib.Path
) -> hf_api.ModelInfo:
    """
    Looks for a json file containing prefetched model info in the revision path.
    If none is found, we just rebuild the model info from the local directory files.
    Note that this file is not automatically created by hub_download/snapshot_download.
    It is just a convenience we add here, in case the offline info we rebuild from
    the local directories does not cover all use cases.
    """
    revision_path = cached_revision_path(cache_dir, repo_id, revision)
    model_info_basename = "hub_model_info.json"
    model_info_path = revision_path / model_info_basename
    logger.info("Checking for cached model info at %s", model_info_path)
    if os.path.exists(model_info_path):
        with open(model_info_path, "r") as f:
            o = json.load(f)
        r = hf_api.ModelInfo(**o)
        logger.debug("Cached model info from file: %s", r)
    else:
        logger.debug(
            "No cached model info file %s found, "
            "rebuilding partial model info from cached model files",
            model_info_path,
        )
        # Rebuild some partial model info from what we see in cache; it should be
        # enough for most use cases
        r = _build_offline_model_info(repo_id, cache_dir, revision)

    return r


def hub_model_info(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[pathlib.Path] = None,
    **kwargs,
) -> hf_api.ModelInfo:
    """
    Get Hub model info with offline support
    """
    if revision is None:
        revision = "main"

    if not constants.HF_HUB_OFFLINE:
        return hf_api.model_info(repo_id=repo_id, revision=revision, **kwargs)

    logger.info("Model info for offline mode")

    if cache_dir is None:
        cache_dir = pathlib.Path(constants.HF_HUB_CACHE)

    return _cached_model_info(repo_id, revision, cache_dir)
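The `hub_model_info.json` file that `_cached_model_info` probes for is not written by `snapshot_download`; it has to be prefetched while online. A sketch of one way to produce it (the direct call to the Hub REST endpoint and the dump format are assumptions, not part of this commit):

import json
import pathlib

import requests
from api_inference_community.hub import cached_revision_path
from huggingface_hub import constants, snapshot_download

repo_id = "google/t5-efficient-tiny"  # example repo

# While online: cache a snapshot, then fetch the raw model info payload,
# which is the shape hf_api.ModelInfo(**o) is later reloaded from.
snapshot_download(repo_id=repo_id)
payload = requests.get(f"https://huggingface.co/api/models/{repo_id}").json()

# Drop it next to the snapshot so _cached_model_info can pick it up offline
revision_path = cached_revision_path(
    pathlib.Path(constants.HF_HUB_CACHE), repo_id, None
)
(revision_path / "hub_model_info.json").write_text(json.dumps(payload))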
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,7 +3,7 @@ numpy>=1.18.0
pydantic>=2
parameterized>=0.8.1
pillow>=8.2.0
-huggingface_hub>=0.5.1
+huggingface_hub>=0.20.2
datasets>=2.2
pytest
httpx
103 changes: 103 additions & 0 deletions tests/test_hub.py
@@ -0,0 +1,103 @@
import logging
import sys
from unittest import TestCase

from api_inference_community import hub
from huggingface_hub import constants, hf_api, snapshot_download


logger = logging.getLogger(__name__)
logger.level = logging.DEBUG
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class HubTestCase(TestCase):
    def test_offline_model_info1(self):
        repo_id = "google/t5-efficient-tiny"
        revision = "3441d7e8bf3f89841f366d39452b95200416e4a9"
        bak_value = constants.HF_HUB_OFFLINE
        try:
            # with tempfile.TemporaryDirectory() as cache_dir:
            #     logger.info("Cache directory %s", cache_dir)
            dirpath = snapshot_download(repo_id=repo_id, revision=revision)
            logger.info("Snapshot downloaded at %s", dirpath)
            constants.HF_HUB_OFFLINE = True
            model_info = hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            constants.HF_HUB_OFFLINE = bak_value

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertEqual(model_info.downloads, 0)
        self.assertEqual(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 12)
        self.assertIn("pytorch_model.bin", [s.rfilename for s in model_info.siblings])
        self.assertFalse(model_info.private)
        self.assertEqual(model_info.license, "apache-2.0")  # noqa
        self.assertEqual(model_info.tags, ["deep-narrow"])
        self.assertIsNone(model_info.library_name)

        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["deep-narrow"])

    def test_offline_model_info2(self):
        repo_id = "dfurman/Mixtral-8x7B-peft-v0.1"
        revision = "8908d586219993ec79949acaef566363a7c7864c"
        bak_value = constants.HF_HUB_OFFLINE
        try:
            # with tempfile.TemporaryDirectory() as cache_dir:
            #     logger.info("Cache directory %s", cache_dir)
            dirpath = snapshot_download(repo_id=repo_id, revision=revision)
            logger.info("Snapshot downloaded at %s", dirpath)
            constants.HF_HUB_OFFLINE = True
            model_info = hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            constants.HF_HUB_OFFLINE = bak_value

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertEqual(model_info.downloads, 0)
        self.assertEqual(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 9)
        self.assertFalse(model_info.private)
        self.assertEqual(model_info.license, "apache-2.0")  # noqa
        self.assertEqual(model_info.tags, ["mistral"])
        self.assertEqual(model_info.library_name, "peft")
        self.assertEqual(model_info.pipeline_tag, "text-generation")
        self.assertIn(".gitattributes", [s.rfilename for s in model_info.siblings])
        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["mistral"])

    def test_online_model_info(self):
        repo_id = "dfurman/Mixtral-8x7B-Instruct-v0.1"
        revision = "8908d586219993ec79949acaef566363a7c7864c"
        bak_value = constants.HF_HUB_OFFLINE
        try:
            constants.HF_HUB_OFFLINE = False
            model_info = hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            constants.HF_HUB_OFFLINE = bak_value

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertGreater(model_info.downloads, 0)
        self.assertGreater(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 9)
        self.assertFalse(model_info.private)
        # Compare as sets: the live tag list is a strict superset of these
        self.assertGreater(set(model_info.tags), {"peft", "safetensors", "mistral"})
        self.assertEqual(model_info.library_name, "peft")
        self.assertEqual(model_info.pipeline_tag, "text-generation")
        self.assertIn(".gitattributes", [s.rfilename for s in model_info.siblings])
        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["mistral"])
        self.assertIsNone(model_info.safetensors)
