diff --git a/api_inference_community/routes.py b/api_inference_community/routes.py
index 347cf8ea..9107cb5a 100644
--- a/api_inference_community/routes.py
+++ b/api_inference_community/routes.py
@@ -95,18 +95,19 @@ async def pipeline_route(request: Request) -> Response:
     ] and already_left(request):
         logger.info("Discarding request as the caller already left")
         return Response(status_code=204)
-
     payload = await request.body()
+    task = os.environ["TASK"]
+    if os.getenv("DEBUG", "0") in {"1", "true"}:
+        pipe = request.app.get_pipeline()
+
     try:
         pipe = request.app.get_pipeline()
         try:
             sampling_rate = pipe.sampling_rate
         except Exception:
             sampling_rate = None
-        inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
     except ValidationError as e:
         errors = []
         for error in e.errors():
@@ -122,6 +123,16 @@ async def pipeline_route(request: Request) -> Response:
     except Exception as e:
         return JSONResponse({"error": str(e)}, status_code=500)
 
+    try:
+        inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
+    except EnvironmentError as e:
+        logger.error("Error while parsing input %s", e)
+        return JSONResponse({"error": str(e)}, status_code=500)
+    except Exception as e:
+        # We assume the payload is bad -> 400
+        logger.warning("Error while parsing input %s", e)
+        return JSONResponse({"error": str(e)}, status_code=400)
+
     accept = request.headers.get("accept", "")
     lora_adapter = request.headers.get("lora")
     if lora_adapter: