diff --git a/Cargo.toml b/Cargo.toml index 997cd046..a19013ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ name = "fms-guardrails-orchestr8" path = "src/main.rs" [dependencies] +anyhow = "1.0.83" axum = { version = "0.7.5", features = ["json"] } clap = { version = "4.5.3", features = ["derive", "env"] } futures = "0.3.30" @@ -25,15 +26,15 @@ rustls-webpki = "0.102.2" serde = { version = "1.0.200", features = ["derive"] } serde_json = "1.0.116" serde_yml = "0.0.5" -thiserror = "1.0.59" +thiserror = "1.0.60" tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] } tokio-stream = "0.1.14" tonic = { version = "0.11.0", features = ["tls"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } url = "2.5.0" -validator = { version = "0.18.1", features = ["derive"] } # For API validation uuid = { version = "1.8.0", features = ["v4", "fast-rng"] } +validator = { version = "0.18.1", features = ["derive"] } # For API validation [build-dependencies] tonic-build = "0.11.0" diff --git a/src/clients.rs b/src/clients.rs index 048c6ae3..27e12067 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -1,8 +1,9 @@ #![allow(dead_code)] use std::{collections::HashMap, time::Duration}; -use futures::future::try_join_all; +use futures::future::join_all; use ginepro::LoadBalancedChannel; +use reqwest::StatusCode; use url::Url; use crate::config::{ServiceConfig, Tls}; @@ -26,16 +27,44 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080; const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); +/// Client errors. #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("model not found: {0}")] - ModelNotFound(String), - #[error(transparent)] - ReqwestError(#[from] reqwest::Error), - #[error(transparent)] - TonicError(#[from] tonic::Status), - #[error(transparent)] - IoError(#[from] std::io::Error), + #[error("{}", .0.message())] + Grpc(#[from] tonic::Status), + #[error("{0}")] + Http(#[from] reqwest::Error), + #[error("model not found: {model_id}")] + ModelNotFound { model_id: String }, +} + +impl Error { + /// Returns status code. + pub fn status_code(&self) -> StatusCode { + use tonic::Code::*; + match self { + // Return equivalent http status code for grpc status code + Error::Grpc(error) => match error.code() { + InvalidArgument => StatusCode::BAD_REQUEST, + Internal => StatusCode::INTERNAL_SERVER_ERROR, + NotFound => StatusCode::NOT_FOUND, + DeadlineExceeded => StatusCode::REQUEST_TIMEOUT, + Unimplemented => StatusCode::NOT_IMPLEMENTED, + Unauthenticated => StatusCode::UNAUTHORIZED, + PermissionDenied => StatusCode::FORBIDDEN, + Ok => StatusCode::OK, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + // Return http status code for error responses + // and 500 for other errors + Error::Http(error) => match error.status() { + Some(code) => code, + None => StatusCode::INTERNAL_SERVER_ERROR, + }, + // Return 404 for model not found + Error::ModelNotFound { .. } => StatusCode::NOT_FOUND, + } + } } #[derive(Clone)] @@ -71,7 +100,7 @@ impl std::ops::Deref for HttpClient { pub async fn create_http_clients( default_port: u16, config: &[(String, ServiceConfig)], -) -> Result, Error> { +) -> HashMap { let clients = config .iter() .map(|(name, service_config)| async move { @@ -86,22 +115,25 @@ pub async fn create_http_clients( let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| { panic!("error reading cert from {cert_path:?}: {error}") }); - let identity = reqwest::Identity::from_pem(&cert_pem)?; + let identity = reqwest::Identity::from_pem(&cert_pem) + .unwrap_or_else(|error| panic!("error parsing cert: {error}")); builder = builder.use_rustls_tls().identity(identity); } - let client = builder.build()?; + let client = builder + .build() + .unwrap_or_else(|error| panic!("error creating http client for {name}: {error}")); let client = HttpClient::new(base_url, client); - Ok((name.clone(), client)) as Result<(String, HttpClient), Error> + (name.clone(), client) }) .collect::>(); - Ok(try_join_all(clients).await?.into_iter().collect()) + join_all(clients).await.into_iter().collect() } async fn create_grpc_clients( default_port: u16, config: &[(String, ServiceConfig)], new: fn(LoadBalancedChannel) -> C, -) -> Result, Error> { +) -> HashMap { let clients = config .iter() .map(|(name, service_config)| async move { @@ -140,9 +172,9 @@ async fn create_grpc_clients( if let Some(client_tls_config) = client_tls_config { builder = builder.with_tls(client_tls_config); } - let channel = builder.channel().await.unwrap(); // TODO: handle error - Ok((name.clone(), new(channel))) as Result<(String, C), Error> + let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}")); + (name.clone(), new(channel)) }) .collect::>(); - Ok(try_join_all(clients).await?.into_iter().collect()) + join_all(clients).await.into_iter().collect() } diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 57e3a908..493ad8e6 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -26,16 +26,18 @@ pub struct ChunkerClient { } impl ChunkerClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await; + Self { clients } } fn client(&self, model_id: &str) -> Result, Error> { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .ok_or_else(|| Error::ModelNotFound { + model_id: model_id.to_string(), + })? .clone()) } diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 9d1068ed..6d684e70 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -13,17 +13,18 @@ pub struct DetectorClient { } impl DetectorClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients: HashMap = - create_http_clients(default_port, config).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients: HashMap = create_http_clients(default_port, config).await; + Self { clients } } fn client(&self, model_id: &str) -> Result { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .ok_or_else(|| Error::ModelNotFound { + model_id: model_id.to_string(), + })? .clone()) } diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 49d49573..bfc3499f 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -29,16 +29,18 @@ pub struct NlpClient { } impl NlpClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await; + Self { clients } } fn client(&self, model_id: &str) -> Result, Error> { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .ok_or_else(|| Error::ModelNotFound { + model_id: model_id.to_string(), + })? .clone()) } diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index a4306e11..b1dd3942 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -21,10 +21,9 @@ pub struct TgisClient { } impl TgisClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients = - create_grpc_clients(default_port, config, GenerationServiceClient::new).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await; + Self { clients } } fn client( @@ -36,7 +35,9 @@ impl TgisClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .ok_or_else(|| Error::ModelNotFound { + model_id: model_id.to_string(), + })? .clone()) } diff --git a/src/config.rs b/src/config.rs index 7161544f..95c42baf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -108,8 +108,10 @@ impl OrchestratorConfig { todo!() } - pub fn get_chunker_id(&self, detector_id: &str) -> String { - self.detectors.get(detector_id).unwrap().chunker_id.clone() + pub fn get_chunker_id(&self, detector_id: &str) -> Option { + self.detectors + .get(detector_id) + .map(|detector_config| detector_config.chunker_id.clone()) } } @@ -126,8 +128,9 @@ fn service_tls_name_to_config( #[cfg(test)] mod tests { + use anyhow::Error; + use super::*; - use crate::Error; #[test] fn test_deserialize_config() -> Result<(), Error> { diff --git a/src/lib.rs b/src/lib.rs index 2167701d..337bfd89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,43 +1,8 @@ #![allow(clippy::iter_kv_map, clippy::enum_variant_names)] -use axum::{http::StatusCode, Json}; - mod clients; mod config; mod models; mod orchestrator; mod pb; pub mod server; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error(transparent)] - ClientError(#[from] crate::clients::Error), - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error(transparent)] - YamlError(#[from] serde_yml::Error), -} - -// TODO: create better errors and properly convert -impl From for (StatusCode, Json) { - fn from(value: Error) -> Self { - use Error::*; - match value { - ClientError(error) => match error { - clients::Error::ModelNotFound(message) => { - (StatusCode::UNPROCESSABLE_ENTITY, Json(message)) - } - clients::Error::ReqwestError(error) => { - (StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string())) - } - clients::Error::TonicError(error) => { - (StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string())) - } - clients::Error::IoError(_) => todo!(), - }, - IoError(_) => todo!(), - YamlError(_) => todo!(), - } - } -} diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 03c74bbc..4f6ec473 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -1,11 +1,13 @@ +pub mod errors; use std::{collections::HashMap, sync::Arc}; +pub use errors::Error; use futures::{ future::try_join_all, stream::{self, StreamExt}, }; use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, info}; +use tracing::{debug, error, info}; use uuid::Uuid; use crate::{ @@ -20,13 +22,14 @@ use crate::{ InputWarningReason, TextGenTokenClassificationResults, TokenClassificationResult, }, pb::{ - caikit::runtime::chunkers::TokenizationTaskRequest as ChunkersTokenizationTaskRequest, - caikit::runtime::nlp::{TextGenerationTaskRequest, TokenizationTaskRequest}, + caikit::runtime::{ + chunkers::TokenizationTaskRequest as ChunkersTokenizationTaskRequest, + nlp::{TextGenerationTaskRequest, TokenizationTaskRequest}, + }, fmaas::{ BatchedGenerationRequest, BatchedTokenizeRequest, GenerationRequest, TokenizeRequest, }, }, - Error, }; const UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. \ @@ -47,7 +50,7 @@ pub struct Orchestrator { impl Orchestrator { pub async fn new(config: OrchestratorConfig) -> Result { - let (generation_client, chunker_client, detector_client) = create_clients(&config).await?; + let (generation_client, chunker_client, detector_client) = create_clients(&config).await; let ctx = Arc::new(Context { config, generation_client, @@ -69,7 +72,7 @@ impl Orchestrator { "handling unary task" ); let ctx = self.ctx.clone(); - tokio::spawn(async move { + let task_handle = tokio::spawn(async move { let masks = task.guardrails_config.input_masks(); let input_detectors = task.guardrails_config.input_detectors(); let output_detectors = task.guardrails_config.output_detectors(); @@ -137,9 +140,22 @@ impl Orchestrator { } Ok(generation_results) } - }) - .await - .unwrap() + }); + match task_handle.await { + // Task completed successfully + Ok(Ok(result)) => Ok(result), + // Task failed, return error propagated from child task that failed + Ok(Err(error)) => { + error!(request_id = ?task.request_id, %error, "unary task failed"); + Err(error) + } + // Task cancelled or panicked + Err(error) => { + let error = error.into(); + error!(request_id = ?task.request_id, %error, "unary task failed"); + Err(error) + } + } } /// Handles streaming tasks. @@ -164,7 +180,6 @@ async fn chunk_and_detect( text: String, masks: Option<&[(usize, usize)]>, ) -> Result, Error> { - // TODO: propogate errors // Apply masks let text_with_offsets = masks .map(|masks| apply_masks(&text, masks)) @@ -172,8 +187,16 @@ async fn chunk_and_detect( // Create a list of required chunkers let chunker_ids = detectors .keys() - .map(|detector_id| ctx.config.get_chunker_id(detector_id)) - .collect::>(); + .map(|detector_id| { + let chunker_id = + ctx.config + .get_chunker_id(detector_id) + .ok_or_else(|| Error::DetectorNotFound { + detector_id: detector_id.clone(), + })?; + Ok::(chunker_id) + }) + .collect::, Error>>()?; // Spawn chunking tasks, returning a map of chunker_id->chunks. let chunks = chunk(ctx.clone(), chunker_ids, text_with_offsets).await?; // Spawn detection tasks @@ -187,24 +210,18 @@ async fn chunk( chunker_ids: Vec, text_with_offsets: Vec<(usize, String)>, ) -> Result>, Error> { - // TODO: propogate errors let tasks = chunker_ids .into_iter() .map(|chunker_id| { let ctx = ctx.clone(); let text_with_offsets = text_with_offsets.clone(); - tokio::spawn(async move { - handle_chunk_task(ctx, chunker_id, text_with_offsets) - .await - .unwrap() - }) + tokio::spawn(async move { handle_chunk_task(ctx, chunker_id, text_with_offsets).await }) }) .collect::>(); let results = try_join_all(tasks) - .await - .unwrap() + .await? .into_iter() - .collect::>(); + .collect::, Error>>()?; Ok(results) } @@ -220,18 +237,22 @@ async fn detect( let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - let chunker_id = ctx.config.get_chunker_id(&detector_id); + let chunker_id = + ctx.config + .get_chunker_id(&detector_id) + .ok_or_else(|| Error::DetectorNotFound { + detector_id: detector_id.clone(), + })?; let chunks = chunks.get(&chunker_id).unwrap().clone(); - tokio::spawn(async move { - handle_detection_task(ctx, detector_id, detector_params, chunks) - .await - .unwrap() - }) + Ok(tokio::spawn(async move { + handle_detection_task(ctx, detector_id, detector_params, chunks).await + })) }) - .collect::>(); + .collect::, Error>>()?; let results = try_join_all(tasks) - .await - .unwrap() + .await? + .into_iter() + .collect::, Error>>()? .into_iter() .flatten() .collect::>(); @@ -244,7 +265,6 @@ async fn handle_chunk_task( chunker_id: String, text_with_offsets: Vec<(usize, String)>, ) -> Result<(String, Vec), Error> { - // TODO: propogate errors let chunks = stream::iter(text_with_offsets) .map(|(offset, text)| { let ctx = ctx.clone(); @@ -256,23 +276,36 @@ async fn handle_chunk_task( ?request, "sending chunker request" ); - ctx.chunker_client + let response = ctx + .chunker_client .tokenization_task_predict(&chunker_id, request) .await - .unwrap() + .map_err(|error| Error::ChunkerRequestFailed { + chunker_id: chunker_id.clone(), + error, + })?; + debug!( + %chunker_id, + ?response, + "received chunker response" + ); + let results = response .results .into_iter() .map(|token| Chunk { offset, text: token.text, }) - .collect::>() + .collect::>(); + Ok::, Error>(results) } }) .buffer_unordered(5) .collect::>() .await .into_iter() + .collect::, Error>>()? + .into_iter() .flatten() .collect::>(); Ok((chunker_id, chunks)) @@ -285,7 +318,6 @@ async fn handle_detection_task( detector_params: DetectorParams, chunks: Vec, ) -> Result, Error> { - // TODO: propogate errors let detections = stream::iter(chunks) .map(|chunk| { let ctx = ctx.clone(); @@ -302,8 +334,16 @@ async fn handle_detection_task( .detector_client .classify(&detector_id, request) .await - .unwrap(); - response + .map_err(|error| Error::DetectorRequestFailed { + detector_id: detector_id.clone(), + error, + })?; + debug!( + %detector_id, + ?response, + "received detector response" + ); + let results = response .detections .into_iter() .map(|detection| { @@ -312,13 +352,16 @@ async fn handle_detection_task( result.end += chunk.offset as i32; result }) - .collect::>() + .collect::>(); + Ok::, Error>(results) } }) .buffer_unordered(5) .collect::>() .await .into_iter() + .collect::, Error>>()? + .into_iter() .flatten() .collect::>(); Ok(detections) @@ -346,7 +389,20 @@ async fn tokenize( ?request, "sending tokenize request" ); - let mut response = client.tokenize(request).await?; + let mut response = + client + .tokenize(request) + .await + .map_err(|error| Error::TokenizeRequestFailed { + model_id: model_id.clone(), + error, + })?; + debug!( + %model_id, + provider = "tgis", + ?response, + "received tokenize response" + ); let response = response.responses.swap_remove(0); Ok((response.token_count, response.tokens)) } @@ -358,7 +414,19 @@ async fn tokenize( ?request, "sending tokenize request" ); - let response = client.tokenization_task_predict(&model_id, request).await?; + let response = client + .tokenization_task_predict(&model_id, request) + .await + .map_err(|error| Error::TokenizeRequestFailed { + model_id: model_id.clone(), + error, + })?; + debug!( + %model_id, + provider = "nlp", + ?response, + "received tokenize response" + ); let tokens = response .results .into_iter() @@ -392,7 +460,20 @@ async fn generate( ?request, "sending generate request" ); - let mut response = client.generate(request).await?; + let mut response = + client + .generate(request) + .await + .map_err(|error| Error::GenerateRequestFailed { + model_id: model_id.clone(), + error, + })?; + debug!( + %model_id, + provider = "tgis", + ?response, + "received generate response" + ); let response = response.responses.swap_remove(0); Ok(ClassifiedGeneratedTextResult { generated_text: Some(response.text.clone()), @@ -456,7 +537,17 @@ async fn generate( ); let response = client .text_generation_task_predict(&model_id, request) - .await?; + .await + .map_err(|error| Error::GenerateRequestFailed { + model_id: model_id.clone(), + error, + })?; + debug!( + %model_id, + provider = "nlp", + ?response, + "received generate response" + ); Ok(ClassifiedGeneratedTextResult { generated_text: Some(response.generated_text.clone()), finish_reason: Some(response.finish_reason().into()), @@ -497,7 +588,7 @@ fn apply_masks(text: &str, masks: &[(usize, usize)]) -> Vec<(usize, String)> { async fn create_clients( config: &OrchestratorConfig, -) -> Result<(GenerationClient, ChunkerClient, DetectorClient), Error> { +) -> (GenerationClient, ChunkerClient, DetectorClient) { // TODO: create better solution for routers let generation_client = match config.generation.provider { GenerationProvider::Tgis => { @@ -505,7 +596,7 @@ async fn create_clients( clients::DEFAULT_TGIS_PORT, &[("tgis-router".to_string(), config.generation.service.clone())], ) - .await?; + .await; GenerationClient::Tgis(client) } GenerationProvider::Nlp => { @@ -513,7 +604,7 @@ async fn create_clients( clients::DEFAULT_CAIKIT_NLP_PORT, &[("tgis-router".to_string(), config.generation.service.clone())], ) - .await?; + .await; GenerationClient::Nlp(client) } }; @@ -523,7 +614,7 @@ async fn create_clients( .iter() .map(|(chunker_id, config)| (chunker_id.clone(), config.service.clone())) .collect::>(); - let chunker_client = ChunkerClient::new(clients::DEFAULT_CHUNKER_PORT, &chunker_config).await?; + let chunker_client = ChunkerClient::new(clients::DEFAULT_CHUNKER_PORT, &chunker_config).await; let detector_config = config .detectors @@ -531,9 +622,9 @@ async fn create_clients( .map(|(detector_id, config)| (detector_id.clone(), config.service.clone())) .collect::>(); let detector_client = - DetectorClient::new(clients::DEFAULT_DETECTOR_PORT, &detector_config).await?; + DetectorClient::new(clients::DEFAULT_DETECTOR_PORT, &detector_config).await; - Ok((generation_client, chunker_client, detector_client)) + (generation_client, chunker_client, detector_client) } #[derive(Debug, Clone)] diff --git a/src/orchestrator/errors.rs b/src/orchestrator/errors.rs new file mode 100644 index 00000000..b20d3b89 --- /dev/null +++ b/src/orchestrator/errors.rs @@ -0,0 +1,42 @@ +use crate::clients; + +/// Orchestrator errors. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("detector not found: {detector_id}")] + DetectorNotFound { detector_id: String }, + #[error("detector request failed for detector_id={detector_id}: {error}")] + DetectorRequestFailed { + detector_id: String, + error: clients::Error, + }, + #[error("chunker request failed for chunker_id={chunker_id}: {error}")] + ChunkerRequestFailed { + chunker_id: String, + error: clients::Error, + }, + #[error("generate request failed for model_id={model_id}: {error}")] + GenerateRequestFailed { + model_id: String, + error: clients::Error, + }, + #[error("tokenize request failed for model_id={model_id}: {error}")] + TokenizeRequestFailed { + model_id: String, + error: clients::Error, + }, + #[error("task cancelled")] + Cancelled, + #[error("{0}")] + Other(String), +} + +impl From for Error { + fn from(error: tokio::task::JoinError) -> Self { + if error.is_cancelled() { + Self::Cancelled + } else { + Self::Other(format!("task panicked: {error}")) + } + } +} diff --git a/src/server.rs b/src/server.rs index a05e4bad..33ce5c90 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,21 +5,22 @@ use axum::{ http::StatusCode, response::{ sse::{Event, KeepAlive, Sse}, - IntoResponse, + IntoResponse, Response, }, routing::{get, post}, Json, Router, }; use futures::StreamExt; use tokio::{net::TcpListener, signal}; -use tracing::info; +use tracing::{error, info}; use uuid::Uuid; use crate::{ config::OrchestratorConfig, models, - orchestrator::{ClassificationWithGenTask, Orchestrator, StreamingClassificationWithGenTask}, - Error, + orchestrator::{ + self, ClassificationWithGenTask, Orchestrator, StreamingClassificationWithGenTask, + }, }; const API_PREFIX: &str = r#"/api/v1/task"#; @@ -59,7 +60,9 @@ pub async fn run( ) .with_state(shared_state); - let listener = TcpListener::bind(&http_addr).await?; + let listener = TcpListener::bind(&http_addr) + .await + .unwrap_or_else(|_| panic!("failed to bind to {http_addr}")); let server = axum::serve(listener, app.into_make_service()).with_graceful_shutdown(shutdown_signal()); @@ -77,7 +80,7 @@ async fn health() -> Result<(), ()> { async fn classification_with_gen( State(state): State>, Json(request): Json, -) -> Result)> { +) -> Result { let request_id = Uuid::new_v4(); let task = ClassificationWithGenTask::new(request_id, request); match state @@ -93,7 +96,7 @@ async fn classification_with_gen( async fn stream_classification_with_gen( State(state): State>, Json(request): Json, -) -> Result)> { +) -> Result { let request_id = Uuid::new_v4(); let task = StreamingClassificationWithGenTask::new(request_id, request); let response_stream = state @@ -131,3 +134,50 @@ async fn shutdown_signal() { info!("signal received, starting graceful shutdown"); } + +/// High-level errors to return to clients. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("{0}")] + Validation(String), + #[error("{0}")] + NotFound(String), + #[error("unexpected error occured while processing request")] + Unexpected, +} + +impl From for Error { + fn from(error: orchestrator::Error) -> Self { + use orchestrator::Error::*; + match error { + DetectorNotFound { .. } => Self::NotFound(error.to_string()), + DetectorRequestFailed { error, .. } + | ChunkerRequestFailed { error, .. } + | GenerateRequestFailed { error, .. } + | TokenizeRequestFailed { error, .. } => match error.status_code() { + StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY => { + Self::Validation(error.to_string()) + } + StatusCode::NOT_FOUND => Self::NotFound(error.to_string()), + _ => Self::Unexpected, + }, + _ => Self::Unexpected, + } + } +} + +impl IntoResponse for Error { + fn into_response(self) -> Response { + use Error::*; + let (code, message) = match self { + Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), + NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), + Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + }; + let error = serde_json::json!({ + "code": code.as_u16(), + "details": message, + }); + (code, Json(error)).into_response() + } +}