fix status code #459

Merged 1 commit on Oct 7, 2024
33 changes: 31 additions & 2 deletions api_inference_community/routes.py
@@ -13,6 +13,7 @@
IMAGE,
IMAGE_INPUTS,
IMAGE_OUTPUTS,
KNOWN_TASKS,
ffmpeg_convert,
normalize_payload,
parse_accept,
@@ -88,6 +89,18 @@ def already_left(request: Request) -> bool:
async def pipeline_route(request: Request) -> Response:
start = time.time()

task = os.environ["TASK"]

# Shortcut: quickly check that the task is known; there is no need to go any further otherwise, as we know
# for sure that normalize_payload would fail below. This avoids waiting for the pipeline to load just to return an error.
if task not in KNOWN_TASKS:
msg = f"The task `{task}` is not recognized by api-inference-community"
logger.error(msg)
# Special case: although the task comes from the environment (which could be considered a service
# config error, thus triggering a 500), this variable indirectly comes from the user,
# so we choose to return a 400 here
return JSONResponse({"error": msg}, status_code=400)

if os.getenv("DISCARD_LEFT", "0").lower() in [
"1",
"true",
@@ -97,16 +110,30 @@ async def pipeline_route(request: Request) -> Response:
return Response(status_code=204)

payload = await request.body()
task = os.environ["TASK"]

if os.getenv("DEBUG", "0") in {"1", "true"}:
pipe = request.app.get_pipeline()

try:
pipe = request.app.get_pipeline()
try:
sampling_rate = pipe.sampling_rate
except Exception:
sampling_rate = None
if task in AUDIO_INPUTS:
msg = f"Sampling rate is expected for model for audio task {task}"
logger.error(msg)
return JSONResponse({"error": msg}, status_code=500)
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)

try:
inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
except EnvironmentError as e:
# Since we catch the environment edge cases earlier above, this should not happen here anymore;
# it is harmless to keep, just in case
logger.error("Error while parsing input %s", e)
return JSONResponse({"error": str(e)}, status_code=500)
except ValidationError as e:
errors = []
for error in e.errors():
@@ -120,7 +147,9 @@
)
return JSONResponse({"error": errors}, status_code=400)
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)
# We assume the payload is bad -> 400
logger.warning("Error while parsing input %s", e)
return JSONResponse({"error": str(e)}, status_code=400)

accept = request.headers.get("accept", "")
lora_adapter = request.headers.get("lora")
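Below is a minimal, standalone sketch (not the PR's code) of the status-code decisions introduced above: an unrecognized `TASK` is reported as user error (400) before any pipeline is loaded, while a missing sampling rate on an audio task is a service-side problem (500). The `KNOWN_TASKS` and `AUDIO_INPUTS` sets here are small illustrative stand-ins for the real ones in `validation.py`.

```python
from typing import Optional, Tuple

# Illustrative stand-ins for the sets imported from validation.py.
KNOWN_TASKS = {"text-generation", "automatic-speech-recognition"}
AUDIO_INPUTS = {"automatic-speech-recognition"}


def route_status(task: str, sampling_rate: Optional[int]) -> Tuple[dict, int]:
    if task not in KNOWN_TASKS:
        # The task value indirectly comes from the user, hence 400 rather than 500.
        return {"error": f"The task `{task}` is not recognized"}, 400
    if task in AUDIO_INPUTS and sampling_rate is None:
        # The deployed pipeline is misconfigured; the caller cannot fix this, hence 500.
        return {"error": f"Sampling rate is expected for audio task {task}"}, 500
    return {"ok": task}, 200


assert route_status("invalid", None)[1] == 400
assert route_status("automatic-speech-recognition", None)[1] == 500
assert route_status("automatic-speech-recognition", 16000)[1] == 200
```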
1 change: 1 addition & 0 deletions api_inference_community/validation.py
@@ -218,6 +218,7 @@ def check_inputs(inputs, tag):
"zero-shot-classification",
}

KNOWN_TASKS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS)

AUDIO = [
"flac",
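For context, here is a small sketch of how `KNOWN_TASKS` is assembled as the union of the existing per-modality input sets; the task names below are an illustrative subset, not the full lists from `validation.py`.

```python
# Illustrative subset only; the real sets in validation.py contain many more tasks.
AUDIO_INPUTS = {"automatic-speech-recognition", "audio-classification"}
IMAGE_INPUTS = {"image-classification", "object-detection"}
TEXT_INPUTS = {"text-generation", "zero-shot-classification"}

KNOWN_TASKS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS)

assert "zero-shot-classification" in KNOWN_TASKS
assert "invalid" not in KNOWN_TASKS
```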
8 changes: 4 additions & 4 deletions tests/test_routes.py
@@ -62,7 +62,7 @@ async def startup_event():
self.assertEqual(response.headers["x-compute-characters"], "4")
self.assertEqual(response.content, b'{"some":"json serializable"}')

def test_invalid_pipeline(self):
def test_invalid_task(self):
os.environ["TASK"] = "invalid"

class Pipeline:
@@ -99,15 +99,15 @@ async def startup_event():

self.assertEqual(
response.status_code,
500,
400,
)
self.assertEqual(
response.content,
b'{"error":"The task `invalid` is not recognized by api-inference-community"}',
)

def test_invalid_task(self):
os.environ["TASK"] = "invalid"
def test_invalid_pipeline(self):
os.environ["TASK"] = "text-generation"

def get_pipeline():
raise Exception("We cannot load the pipeline")
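To summarize the behaviour the updated tests assert, here is a self-contained toy sketch (assuming Starlette is installed; this is not the real api-inference-community app): an unknown task now yields 400 before the pipeline is touched, while a pipeline that fails to load still yields 500.

```python
import os

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.testclient import TestClient

KNOWN_TASKS = {"text-generation"}  # illustrative stand-in


async def pipeline_route(request: Request) -> JSONResponse:
    task = os.environ["TASK"]
    if task not in KNOWN_TASKS:
        # User-driven configuration problem -> 400.
        return JSONResponse({"error": f"The task `{task}` is not recognized"}, status_code=400)
    try:
        request.app.state.get_pipeline()
    except Exception as e:
        # The service itself failed to load its pipeline -> 500.
        return JSONResponse({"error": str(e)}, status_code=500)
    return JSONResponse({"ok": task})


def make_client(task, get_pipeline):
    os.environ["TASK"] = task
    app = Starlette(routes=[Route("/", pipeline_route, methods=["POST"])])
    app.state.get_pipeline = get_pipeline
    return TestClient(app)


def failing_pipeline():
    raise Exception("We cannot load the pipeline")


assert make_client("invalid", failing_pipeline).post("/").status_code == 400
assert make_client("text-generation", failing_pipeline).post("/").status_code == 500
assert make_client("text-generation", lambda: object()).post("/").status_code == 200
```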