Skip to content

Commit

Permalink
Fix setfit in offline mode (#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
oOraph authored Jan 11, 2024
1 parent f06a71e commit e74c506
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
20 changes: 18 additions & 2 deletions docker_images/setfit/app/main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
import functools
import logging
import os
import pathlib
from typing import Dict, Type

from api_inference_community import hub
from api_inference_community.routes import pipeline_route, status_ok
from app.pipelines import Pipeline, TextClassificationPipeline
from huggingface_hub import constants
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Route


TASK = os.getenv("TASK")
MODEL_ID = os.getenv("MODEL_ID")


def get_model_id():
m_id = os.getenv("MODEL_ID")
# Workaround, when sentence_transformers handles properly this env variable
# this should not be needed anymore
if constants.HF_HUB_OFFLINE:
cache_dir = pathlib.Path(constants.HF_HUB_CACHE)
m_id = hub.cached_revision_path(
cache_dir=cache_dir, repo_id=m_id, revision=os.getenv("REVISION")
)
return m_id


MODEL_ID = get_model_id()

logger = logging.getLogger(__name__)


Expand All @@ -40,7 +56,7 @@
@functools.lru_cache()
def get_pipeline() -> Pipeline:
task = os.environ["TASK"]
model_id = os.environ["MODEL_ID"]
model_id = MODEL_ID
if task not in ALLOWED_TASKS:
raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}")
return ALLOWED_TASKS[task](model_id)
Expand Down
6 changes: 3 additions & 3 deletions docker_images/setfit/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
starlette==0.27.0
api-inference-community==0.0.32
huggingface_hub==0.19.4
setfit==1.0.1
git+https://github.com/huggingface/api-inference-community.git@f06a71e72e92caeebabaeced979eacb3542bf2ca
huggingface_hub==0.20.2
setfit==1.0.1

0 comments on commit e74c506

Please sign in to comment.