Fix status code: split pipeline load from input parsing
pipeline loading -> 500
input parsing -> 400

Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
oOraph committed Oct 7, 2024
1 parent 45907cd commit 15e25cc
Showing 2 changed files with 29 additions and 2 deletions.
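
For orientation, here is a minimal, self-contained sketch of the status-code split this commit introduces. The helper names `load_pipeline` and `parse_payload` are placeholders, not the repository's actual functions: an unknown task or a malformed payload is treated as a caller error (400), while a failure to load the pipeline is treated as a server error (500).

```python
import json

KNOWN_TASKS = {"text-classification", "audio-classification"}  # stand-in set


def load_pipeline(task: str):
    # Placeholder for the server-side model load; a real load can fail
    # (missing weights, bad revision, ...), and such failures map to 500.
    return lambda inputs: {"task": task, "inputs": inputs}


def parse_payload(raw: bytes, task: str):
    # Placeholder for user-input parsing; raises on malformed payloads.
    return json.loads(raw)


def handle(raw_payload: bytes, task: str):
    if task not in KNOWN_TASKS:
        # The task is not served at all: fail fast with 400,
        # without waiting for a pipeline to load.
        return {"error": f"unknown task {task}"}, 400
    try:
        pipe = load_pipeline(task)
    except Exception as e:
        # Server-side setup problem, not the caller's fault -> 500.
        return {"error": str(e)}, 500
    try:
        inputs = parse_payload(raw_payload, task)
    except Exception as e:
        # The payload itself is bad -> 400.
        return {"error": str(e)}, 400
    return pipe(inputs), 200
```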
30 changes: 28 additions & 2 deletions api_inference_community/routes.py
@@ -13,6 +13,7 @@
IMAGE,
IMAGE_INPUTS,
IMAGE_OUTPUTS,
KNOWN_TASKS,
ffmpeg_convert,
normalize_payload,
parse_accept,
@@ -88,6 +89,15 @@ def already_left(request: Request) -> bool:
async def pipeline_route(request: Request) -> Response:
start = time.time()

task = os.environ["TASK"]

# Shortcut: check early that the task is known; there is no point going any further otherwise, since
# normalize_payload is guaranteed to fail below, and bailing out now avoids waiting for the pipeline to load
if task not in KNOWN_TASKS:
msg = f"The task `{task}` is not recognized by api-inference-community"
logger.error(msg)
return JSONResponse({"error": msg}, status_code=400)

if os.getenv("DISCARD_LEFT", "0").lower() in [
"1",
"true",
@@ -97,16 +107,30 @@ async def pipeline_route(request: Request) -> Response:
return Response(status_code=204)

payload = await request.body()
task = os.environ["TASK"]

if os.getenv("DEBUG", "0") in {"1", "true"}:
pipe = request.app.get_pipeline()

try:
pipe = request.app.get_pipeline()
try:
sampling_rate = pipe.sampling_rate
except Exception:
sampling_rate = None
if task in AUDIO_INPUTS:
msg = f"Sampling rate is expected for model for audio task {task}"
logger.error(msg)
return JSONResponse({"error": msg}, status_code=500)
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)

try:
inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
except EnvironmentError as e:
# Since the environment edge cases are caught earlier above, this should no longer happen here;
# it is harmless to keep, just in case
logger.error("Error while parsing input %s", e)
return JSONResponse({"error": str(e)}, status_code=500)
except ValidationError as e:
errors = []
for error in e.errors():
@@ -120,7 +144,9 @@
)
return JSONResponse({"error": errors}, status_code=400)
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)
# We assume the payload is bad -> 400
logger.warning("Error while parsing input %s", e)
return JSONResponse({"error": str(e)}, status_code=400)

accept = request.headers.get("accept", "")
lora_adapter = request.headers.get("lora")
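
Exercising the sketch from above shows the contract a caller can now rely on; all values here are illustrative:

```python
# Unknown task: rejected immediately with 400, before any pipeline load.
_, status = handle(b"{}", "not-a-real-task")
assert status == 400

# Malformed payload for a known task: parsing fails -> 400.
_, status = handle(b"not json", "text-classification")
assert status == 400

# Well-formed request: 200; a load_pipeline failure would map to 500 instead.
_, status = handle(b'{"inputs": "hello"}', "text-classification")
assert status == 200
```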
1 change: 1 addition & 0 deletions api_inference_community/validation.py
@@ -218,6 +218,7 @@ def check_inputs(inputs, tag):
"zero-shot-classification",
}

KNOWN_TASKS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS)

AUDIO = [
"flac",
