From 5c5346e9180b80c2329e159dca14984edefef0a1 Mon Sep 17 00:00:00 2001 From: declark1 Date: Mon, 13 May 2024 09:07:00 -0700 Subject: [PATCH 1/5] Implement error handling Signed-off-by: declark1 --- Cargo.toml | 2 +- src/clients.rs | 21 +++---- src/clients/chunker.rs | 5 +- src/clients/detector.rs | 5 +- src/clients/nlp.rs | 5 +- src/clients/tgis.rs | 5 +- src/config.rs | 9 ++- src/lib.rs | 35 ------------ src/orchestrator.rs | 120 ++++++++++++++++++++++++++++------------ src/server.rs | 37 +++++++++++-- 10 files changed, 142 insertions(+), 102 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 997cd046..9672fe1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ rustls-webpki = "0.102.2" serde = { version = "1.0.200", features = ["derive"] } serde_json = "1.0.116" serde_yml = "0.0.5" -thiserror = "1.0.59" tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] } tokio-stream = "0.1.14" tonic = { version = "0.11.0", features = ["tls"] } @@ -34,6 +33,7 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } url = "2.5.0" validator = { version = "0.18.1", features = ["derive"] } # For API validation uuid = { version = "1.8.0", features = ["v4", "fast-rng"] } +anyhow = "1.0.83" [build-dependencies] tonic-build = "0.11.0" diff --git a/src/clients.rs b/src/clients.rs index 048c6ae3..6f08c059 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use std::{collections::HashMap, time::Duration}; +use anyhow::{Context, Error}; use futures::future::try_join_all; use ginepro::LoadBalancedChannel; use url::Url; @@ -26,18 +27,6 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080; const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("model not found: {0}")] - ModelNotFound(String), - #[error(transparent)] - 
ReqwestError(#[from] reqwest::Error), - #[error(transparent)] - TonicError(#[from] tonic::Status), - #[error(transparent)] - IoError(#[from] std::io::Error), -} - #[derive(Clone)] pub enum GenerationClient { Tgis(TgisClient), @@ -89,7 +78,11 @@ pub async fn create_http_clients( let identity = reqwest::Identity::from_pem(&cert_pem)?; builder = builder.use_rustls_tls().identity(identity); } - let client = builder.build()?; + let client = builder.build().with_context(|| { + format!( + "error creating http client, name={name}, service_config={service_config:?}" + ) + })?; let client = HttpClient::new(base_url, client); Ok((name.clone(), client)) as Result<(String, HttpClient), Error> }) @@ -140,7 +133,7 @@ async fn create_grpc_clients( if let Some(client_tls_config) = client_tls_config { builder = builder.with_tls(client_tls_config); } - let channel = builder.channel().await.unwrap(); // TODO: handle error + let channel = builder.channel().await.with_context(|| format!("error creating grpc client, name={name}, service_config={service_config:?}"))?; Ok((name.clone(), new(channel))) as Result<(String, C), Error> }) .collect::>(); diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 57e3a908..67959fc9 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -1,12 +1,13 @@ use std::{collections::HashMap, pin::Pin}; +use anyhow::{Context, Error}; use futures::{Stream, StreamExt}; use ginepro::LoadBalancedChannel; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tonic::Request; -use super::{create_grpc_clients, Error}; +use super::create_grpc_clients; use crate::{ config::ServiceConfig, pb::{ @@ -35,7 +36,7 @@ impl ChunkerClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .context(format!("model not found, model_id={model_id}"))? 
.clone()) } diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 9d1068ed..4fde9d74 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -1,8 +1,9 @@ use std::collections::HashMap; +use anyhow::{Context, Error}; use serde::{Deserialize, Serialize}; -use super::{create_http_clients, Error, HttpClient}; +use super::{create_http_clients, HttpClient}; use crate::config::ServiceConfig; const DETECTOR_ID_HEADER_NAME: &str = "detector-id"; @@ -23,7 +24,7 @@ impl DetectorClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .context(format!("model not found, model_id={model_id}"))? .clone()) } diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 49d49573..61965e25 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -1,12 +1,13 @@ use std::collections::HashMap; +use anyhow::{Context, Error}; use futures::StreamExt; use ginepro::LoadBalancedChannel; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tonic::Request; -use super::{create_grpc_clients, Error}; +use super::create_grpc_clients; use crate::{ config::ServiceConfig, pb::{ @@ -38,7 +39,7 @@ impl NlpClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .context(format!("model not found, model_id={model_id}"))? .clone()) } diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index a4306e11..fa9b3b59 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -1,11 +1,12 @@ use std::collections::HashMap; +use anyhow::{Context, Error}; use futures::StreamExt; use ginepro::LoadBalancedChannel; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; -use super::{create_grpc_clients, Error}; +use super::create_grpc_clients; use crate::{ config::ServiceConfig, pb::fmaas::{ @@ -36,7 +37,7 @@ impl TgisClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::ModelNotFound(model_id.into()))? + .context(format!("model not found, model_id={model_id}"))? 
.clone()) } diff --git a/src/config.rs b/src/config.rs index 7161544f..95c42baf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -108,8 +108,10 @@ impl OrchestratorConfig { todo!() } - pub fn get_chunker_id(&self, detector_id: &str) -> String { - self.detectors.get(detector_id).unwrap().chunker_id.clone() + pub fn get_chunker_id(&self, detector_id: &str) -> Option { + self.detectors + .get(detector_id) + .map(|detector_config| detector_config.chunker_id.clone()) } } @@ -126,8 +128,9 @@ fn service_tls_name_to_config( #[cfg(test)] mod tests { + use anyhow::Error; + use super::*; - use crate::Error; #[test] fn test_deserialize_config() -> Result<(), Error> { diff --git a/src/lib.rs b/src/lib.rs index 2167701d..337bfd89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,43 +1,8 @@ #![allow(clippy::iter_kv_map, clippy::enum_variant_names)] -use axum::{http::StatusCode, Json}; - mod clients; mod config; mod models; mod orchestrator; mod pb; pub mod server; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error(transparent)] - ClientError(#[from] crate::clients::Error), - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error(transparent)] - YamlError(#[from] serde_yml::Error), -} - -// TODO: create better errors and properly convert -impl From for (StatusCode, Json) { - fn from(value: Error) -> Self { - use Error::*; - match value { - ClientError(error) => match error { - clients::Error::ModelNotFound(message) => { - (StatusCode::UNPROCESSABLE_ENTITY, Json(message)) - } - clients::Error::ReqwestError(error) => { - (StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string())) - } - clients::Error::TonicError(error) => { - (StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string())) - } - clients::Error::IoError(_) => todo!(), - }, - IoError(_) => todo!(), - YamlError(_) => todo!(), - } - } -} diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 03c74bbc..c6f5dbb1 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -1,5 +1,6 @@ use 
std::{collections::HashMap, sync::Arc}; +use anyhow::{Context as _, Error}; use futures::{ future::try_join_all, stream::{self, StreamExt}, @@ -26,7 +27,6 @@ use crate::{ BatchedGenerationRequest, BatchedTokenizeRequest, GenerationRequest, TokenizeRequest, }, }, - Error, }; const UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. \ @@ -138,8 +138,7 @@ impl Orchestrator { Ok(generation_results) } }) - .await - .unwrap() + .await? } /// Handles streaming tasks. @@ -164,7 +163,6 @@ async fn chunk_and_detect( text: String, masks: Option<&[(usize, usize)]>, ) -> Result, Error> { - // TODO: propogate errors // Apply masks let text_with_offsets = masks .map(|masks| apply_masks(&text, masks)) @@ -172,8 +170,14 @@ async fn chunk_and_detect( // Create a list of required chunkers let chunker_ids = detectors .keys() - .map(|detector_id| ctx.config.get_chunker_id(detector_id)) - .collect::>(); + .map(|detector_id| { + let chunker_id = ctx + .config + .get_chunker_id(detector_id) + .context(format!("detector not found, detector_id={detector_id}"))?; + Ok::(chunker_id) + }) + .collect::, Error>>()?; // Spawn chunking tasks, returning a map of chunker_id->chunks. let chunks = chunk(ctx.clone(), chunker_ids, text_with_offsets).await?; // Spawn detection tasks @@ -187,24 +191,18 @@ async fn chunk( chunker_ids: Vec, text_with_offsets: Vec<(usize, String)>, ) -> Result>, Error> { - // TODO: propogate errors let tasks = chunker_ids .into_iter() .map(|chunker_id| { let ctx = ctx.clone(); let text_with_offsets = text_with_offsets.clone(); - tokio::spawn(async move { - handle_chunk_task(ctx, chunker_id, text_with_offsets) - .await - .unwrap() - }) + tokio::spawn(async move { handle_chunk_task(ctx, chunker_id, text_with_offsets).await }) }) .collect::>(); let results = try_join_all(tasks) - .await - .unwrap() + .await? 
.into_iter() - .collect::>(); + .collect::, Error>>()?; Ok(results) } @@ -220,18 +218,20 @@ async fn detect( let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - let chunker_id = ctx.config.get_chunker_id(&detector_id); + let chunker_id = ctx + .config + .get_chunker_id(&detector_id) + .context(format!("detector not found, detector_id={detector_id}"))?; let chunks = chunks.get(&chunker_id).unwrap().clone(); - tokio::spawn(async move { - handle_detection_task(ctx, detector_id, detector_params, chunks) - .await - .unwrap() - }) + Ok(tokio::spawn(async move { + handle_detection_task(ctx, detector_id, detector_params, chunks).await + })) }) - .collect::>(); + .collect::, Error>>()?; let results = try_join_all(tasks) - .await - .unwrap() + .await? + .into_iter() + .collect::, Error>>()? .into_iter() .flatten() .collect::>(); @@ -244,7 +244,6 @@ async fn handle_chunk_task( chunker_id: String, text_with_offsets: Vec<(usize, String)>, ) -> Result<(String, Vec), Error> { - // TODO: propogate errors let chunks = stream::iter(text_with_offsets) .map(|(offset, text)| { let ctx = ctx.clone(); @@ -256,23 +255,33 @@ async fn handle_chunk_task( ?request, "sending chunker request" ); - ctx.chunker_client + let response = ctx + .chunker_client .tokenization_task_predict(&chunker_id, request) .await - .unwrap() + .with_context(|| format!("chunker request failed, chunker_id={chunker_id}"))?; + debug!( + %chunker_id, + ?response, + "received chunker response" + ); + let results = response .results .into_iter() .map(|token| Chunk { offset, text: token.text, }) - .collect::>() + .collect::>(); + Ok::, Error>(results) } }) .buffer_unordered(5) .collect::>() .await .into_iter() + .collect::, Error>>()? 
+ .into_iter() .flatten() .collect::>(); Ok((chunker_id, chunks)) @@ -285,7 +294,6 @@ async fn handle_detection_task( detector_params: DetectorParams, chunks: Vec, ) -> Result, Error> { - // TODO: propogate errors let detections = stream::iter(chunks) .map(|chunk| { let ctx = ctx.clone(); @@ -302,8 +310,15 @@ async fn handle_detection_task( .detector_client .classify(&detector_id, request) .await - .unwrap(); - response + .with_context(|| { + format!("detector request failed, detector_id={detector_id}") + })?; + debug!( + %detector_id, + ?response, + "received detector response" + ); + let results = response .detections .into_iter() .map(|detection| { @@ -312,13 +327,16 @@ async fn handle_detection_task( result.end += chunk.offset as i32; result }) - .collect::>() + .collect::>(); + Ok::, Error>(results) } }) .buffer_unordered(5) .collect::>() .await .into_iter() + .collect::, Error>>()? + .into_iter() .flatten() .collect::>(); Ok(detections) @@ -346,7 +364,16 @@ async fn tokenize( ?request, "sending tokenize request" ); - let mut response = client.tokenize(request).await?; + let mut response = client + .tokenize(request) + .await + .with_context(|| format!("tokenize request failed, model_id={model_id}"))?; + debug!( + %model_id, + provider = "tgis", + ?response, + "received tokenize response" + ); let response = response.responses.swap_remove(0); Ok((response.token_count, response.tokens)) } @@ -359,6 +386,12 @@ async fn tokenize( "sending tokenize request" ); let response = client.tokenization_task_predict(&model_id, request).await?; + debug!( + %model_id, + provider = "nlp", + ?response, + "received tokenize response" + ); let tokens = response .results .into_iter() @@ -392,7 +425,15 @@ async fn generate( ?request, "sending generate request" ); - let mut response = client.generate(request).await?; + let mut response = client.generate(request).await.with_context(|| { + format!("generate request failed, model_id={model_id}, provider=tgis") + })?; + debug!( + 
%model_id, + provider = "tgis", + ?response, + "received generate response" + ); let response = response.responses.swap_remove(0); Ok(ClassifiedGeneratedTextResult { generated_text: Some(response.text.clone()), @@ -456,7 +497,16 @@ async fn generate( ); let response = client .text_generation_task_predict(&model_id, request) - .await?; + .await + .with_context(|| { + format!("generate request failed, model_id={model_id}, provider=nlp") + })?; + debug!( + %model_id, + provider = "nlp", + ?response, + "received generate response" + ); Ok(ClassifiedGeneratedTextResult { generated_text: Some(response.generated_text.clone()), finish_reason: Some(response.finish_reason().into()), diff --git a/src/server.rs b/src/server.rs index a05e4bad..68c3e5fe 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,25 +1,25 @@ use std::{net::SocketAddr, path::PathBuf, sync::Arc}; +use anyhow::Error; use axum::{ extract::State, http::StatusCode, response::{ sse::{Event, KeepAlive, Sse}, - IntoResponse, + IntoResponse, Response, }, routing::{get, post}, Json, Router, }; use futures::StreamExt; use tokio::{net::TcpListener, signal}; -use tracing::info; +use tracing::{error, info}; use uuid::Uuid; use crate::{ config::OrchestratorConfig, models, orchestrator::{ClassificationWithGenTask, Orchestrator, StreamingClassificationWithGenTask}, - Error, }; const API_PREFIX: &str = r#"/api/v1/task"#; @@ -77,7 +77,7 @@ async fn health() -> Result<(), ()> { async fn classification_with_gen( State(state): State>, Json(request): Json, -) -> Result)> { +) -> Result { let request_id = Uuid::new_v4(); let task = ClassificationWithGenTask::new(request_id, request); match state @@ -86,14 +86,17 @@ async fn classification_with_gen( .await { Ok(response) => Ok(Json(response).into_response()), - Err(error) => Err(error.into()), + Err(error) => { + error!(%request_id, "{error:#}"); + Err(error.into()) + } } } async fn stream_classification_with_gen( State(state): State>, Json(request): Json, -) -> Result)> { +) -> 
Result { let request_id = Uuid::new_v4(); let task = StreamingClassificationWithGenTask::new(request_id, request); let response_stream = state @@ -131,3 +134,25 @@ async fn shutdown_signal() { info!("signal received, starting graceful shutdown"); } + +pub struct ServerError(anyhow::Error); + +impl IntoResponse for ServerError { + fn into_response(self) -> Response { + let code = StatusCode::INTERNAL_SERVER_ERROR; + let error = serde_json::json!({ + "code": code.as_u16(), + "message": self.0.to_string(), + }); + (code, Json(error)).into_response() + } +} + +impl From for ServerError +where + E: Into, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} From 65e62156fcb8f11cd5e558bdef98d58eabbdddb4 Mon Sep 17 00:00:00 2001 From: declark1 Date: Tue, 21 May 2024 10:37:52 -0700 Subject: [PATCH 2/5] Refactor error handling Signed-off-by: declark1 --- Cargo.toml | 1 + src/clients.rs | 77 ++++++++++++++++++++------ src/clients/chunker.rs | 13 ++--- src/clients/detector.rs | 14 ++--- src/clients/nlp.rs | 13 ++--- src/clients/tgis.rs | 14 ++--- src/orchestrator.rs | 108 +++++++++++++++++++++++++------------ src/orchestrator/errors.rs | 64 ++++++++++++++++++++++ src/server.rs | 57 ++++++++++++-------- 9 files changed, 263 insertions(+), 98 deletions(-) create mode 100644 src/orchestrator/errors.rs diff --git a/Cargo.toml b/Cargo.toml index 9672fe1d..619f002c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ url = "2.5.0" validator = { version = "0.18.1", features = ["derive"] } # For API validation uuid = { version = "1.8.0", features = ["v4", "fast-rng"] } anyhow = "1.0.83" +thiserror = "1.0.60" [build-dependencies] tonic-build = "0.11.0" diff --git a/src/clients.rs b/src/clients.rs index 6f08c059..2509dfe0 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -1,9 +1,9 @@ #![allow(dead_code)] use std::{collections::HashMap, time::Duration}; -use anyhow::{Context, Error}; -use futures::future::try_join_all; +use futures::future::join_all; use 
ginepro::LoadBalancedChannel; +use reqwest::StatusCode; use url::Url; use crate::config::{ServiceConfig, Tls}; @@ -27,6 +27,54 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080; const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); +/// Client errors. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("{}", .0.message())] + Grpc(#[from] tonic::Status), + #[error("{0}")] + Http(#[from] reqwest::Error), + #[error("invalid model id: {model_id}")] + InvalidModelId { model_id: String }, +} + +impl Error { + /// Returns status code. + pub fn status_code(&self) -> StatusCode { + use tonic::Code::*; + match self { + // Return equivalent http status code for grpc status code + Error::Grpc(error) => match error.code() { + InvalidArgument => StatusCode::BAD_REQUEST, + Internal => StatusCode::INTERNAL_SERVER_ERROR, + NotFound => StatusCode::NOT_FOUND, + DeadlineExceeded => StatusCode::REQUEST_TIMEOUT, + Unimplemented => StatusCode::NOT_IMPLEMENTED, + Unauthenticated => StatusCode::UNAUTHORIZED, + PermissionDenied => StatusCode::FORBIDDEN, + Ok => StatusCode::OK, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + // Return http status code for error responses + // and 500 for other errors + Error::Http(error) => match error.status() { + Some(code) => code, + None => StatusCode::INTERNAL_SERVER_ERROR, + }, + // Return 422 for invalid model id + Error::InvalidModelId { .. } => StatusCode::UNPROCESSABLE_ENTITY, + } + } + + /// Returns true for validation-type errors (400/422) and false for other types. 
+ pub fn is_validation_error(&self) -> bool { + matches!( + self.status_code(), + StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY + ) + } +} + #[derive(Clone)] pub enum GenerationClient { Tgis(TgisClient), @@ -60,7 +108,7 @@ impl std::ops::Deref for HttpClient { pub async fn create_http_clients( default_port: u16, config: &[(String, ServiceConfig)], -) -> Result, Error> { +) -> HashMap { let clients = config .iter() .map(|(name, service_config)| async move { @@ -75,26 +123,25 @@ pub async fn create_http_clients( let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| { panic!("error reading cert from {cert_path:?}: {error}") }); - let identity = reqwest::Identity::from_pem(&cert_pem)?; + let identity = reqwest::Identity::from_pem(&cert_pem) + .unwrap_or_else(|error| panic!("error creating identity: {error}")); builder = builder.use_rustls_tls().identity(identity); } - let client = builder.build().with_context(|| { - format!( - "error creating http client, name={name}, service_config={service_config:?}" - ) - })?; + let client = builder + .build() + .unwrap_or_else(|error| panic!("error creating http client for {name}: {error}")); let client = HttpClient::new(base_url, client); - Ok((name.clone(), client)) as Result<(String, HttpClient), Error> + (name.clone(), client) }) .collect::>(); - Ok(try_join_all(clients).await?.into_iter().collect()) + join_all(clients).await.into_iter().collect() } async fn create_grpc_clients( default_port: u16, config: &[(String, ServiceConfig)], new: fn(LoadBalancedChannel) -> C, -) -> Result, Error> { +) -> HashMap { let clients = config .iter() .map(|(name, service_config)| async move { @@ -133,9 +180,9 @@ async fn create_grpc_clients( if let Some(client_tls_config) = client_tls_config { builder = builder.with_tls(client_tls_config); } - let channel = builder.channel().await.with_context(|| format!("error creating grpc client, name={name}, service_config={service_config:?}"))?; - Ok((name.clone(), 
new(channel))) as Result<(String, C), Error> + let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}")); + (name.clone(), new(channel)) }) .collect::>(); - Ok(try_join_all(clients).await?.into_iter().collect()) + join_all(clients).await.into_iter().collect() } diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 67959fc9..a1f01766 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -1,13 +1,12 @@ use std::{collections::HashMap, pin::Pin}; -use anyhow::{Context, Error}; use futures::{Stream, StreamExt}; use ginepro::LoadBalancedChannel; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tonic::Request; -use super::create_grpc_clients; +use super::{create_grpc_clients, Error}; use crate::{ config::ServiceConfig, pb::{ @@ -27,16 +26,18 @@ pub struct ChunkerClient { } impl ChunkerClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await; + Self { clients } } fn client(&self, model_id: &str) -> Result, Error> { Ok(self .clients .get(model_id) - .context(format!("model not found, model_id={model_id}"))? + .ok_or_else(|| Error::InvalidModelId { + model_id: model_id.to_string(), + })? 
.clone()) } diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 4fde9d74..341f7d89 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -1,9 +1,8 @@ use std::collections::HashMap; -use anyhow::{Context, Error}; use serde::{Deserialize, Serialize}; -use super::{create_http_clients, HttpClient}; +use super::{create_http_clients, Error, HttpClient}; use crate::config::ServiceConfig; const DETECTOR_ID_HEADER_NAME: &str = "detector-id"; @@ -14,17 +13,18 @@ pub struct DetectorClient { } impl DetectorClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients: HashMap = - create_http_clients(default_port, config).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients: HashMap = create_http_clients(default_port, config).await; + Self { clients } } fn client(&self, model_id: &str) -> Result { Ok(self .clients .get(model_id) - .context(format!("model not found, model_id={model_id}"))? + .ok_or_else(|| Error::InvalidModelId { + model_id: model_id.to_string(), + })? 
.clone()) } diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 61965e25..ba8faf8d 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; -use anyhow::{Context, Error}; use futures::StreamExt; use ginepro::LoadBalancedChannel; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tonic::Request; -use super::create_grpc_clients; +use super::{create_grpc_clients, Error}; use crate::{ config::ServiceConfig, pb::{ @@ -30,16 +29,18 @@ pub struct NlpClient { } impl NlpClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await; + Self { clients } } fn client(&self, model_id: &str) -> Result, Error> { Ok(self .clients .get(model_id) - .context(format!("model not found, model_id={model_id}"))? + .ok_or_else(|| Error::InvalidModelId { + model_id: model_id.to_string(), + })? 
.clone()) } diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index fa9b3b59..23193b1a 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -1,12 +1,11 @@ use std::collections::HashMap; -use anyhow::{Context, Error}; use futures::StreamExt; use ginepro::LoadBalancedChannel; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; -use super::create_grpc_clients; +use super::{create_grpc_clients, Error}; use crate::{ config::ServiceConfig, pb::fmaas::{ @@ -22,10 +21,9 @@ pub struct TgisClient { } impl TgisClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result { - let clients = - create_grpc_clients(default_port, config, GenerationServiceClient::new).await?; - Ok(Self { clients }) + pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { + let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await; + Self { clients } } fn client( @@ -37,7 +35,9 @@ impl TgisClient { Ok(self .clients .get(model_id) - .context(format!("model not found, model_id={model_id}"))? + .ok_or_else(|| Error::InvalidModelId { + model_id: model_id.to_string(), + })? 
.clone()) } diff --git a/src/orchestrator.rs b/src/orchestrator.rs index c6f5dbb1..68236a4d 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -1,12 +1,13 @@ +pub mod errors; use std::{collections::HashMap, sync::Arc}; -use anyhow::{Context as _, Error}; +pub use errors::Error; use futures::{ future::try_join_all, stream::{self, StreamExt}, }; use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, info}; +use tracing::{debug, error, info}; use uuid::Uuid; use crate::{ @@ -21,8 +22,10 @@ use crate::{ InputWarningReason, TextGenTokenClassificationResults, TokenClassificationResult, }, pb::{ - caikit::runtime::chunkers::TokenizationTaskRequest as ChunkersTokenizationTaskRequest, - caikit::runtime::nlp::{TextGenerationTaskRequest, TokenizationTaskRequest}, + caikit::runtime::{ + chunkers::TokenizationTaskRequest as ChunkersTokenizationTaskRequest, + nlp::{TextGenerationTaskRequest, TokenizationTaskRequest}, + }, fmaas::{ BatchedGenerationRequest, BatchedTokenizeRequest, GenerationRequest, TokenizeRequest, }, @@ -47,7 +50,7 @@ pub struct Orchestrator { impl Orchestrator { pub async fn new(config: OrchestratorConfig) -> Result { - let (generation_client, chunker_client, detector_client) = create_clients(&config).await?; + let (generation_client, chunker_client, detector_client) = create_clients(&config).await; let ctx = Arc::new(Context { config, generation_client, @@ -69,7 +72,7 @@ impl Orchestrator { "handling unary task" ); let ctx = self.ctx.clone(); - tokio::spawn(async move { + let task_handle = tokio::spawn(async move { let masks = task.guardrails_config.input_masks(); let input_detectors = task.guardrails_config.input_detectors(); let output_detectors = task.guardrails_config.output_detectors(); @@ -137,8 +140,20 @@ impl Orchestrator { } Ok(generation_results) } - }) - .await? 
+ }); + match task_handle.await { + Ok(Ok(result)) => Ok(result), + Ok(Err(error)) => { + error!(request_id = ?task.request_id, %error, "unary task failed"); + Err(error) + } + Err(error) => { + // Task failed due to cancellation or panic + let error = error.into(); + error!(request_id = ?task.request_id, %error, "unary task failed"); + Err(error) + } + } } /// Handles streaming tasks. @@ -171,10 +186,12 @@ async fn chunk_and_detect( let chunker_ids = detectors .keys() .map(|detector_id| { - let chunker_id = ctx - .config - .get_chunker_id(detector_id) - .context(format!("detector not found, detector_id={detector_id}"))?; + let chunker_id = + ctx.config + .get_chunker_id(detector_id) + .ok_or_else(|| Error::InvalidDetectorId { + detector_id: detector_id.clone(), + })?; Ok::(chunker_id) }) .collect::, Error>>()?; @@ -218,10 +235,11 @@ async fn detect( let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - let chunker_id = ctx - .config - .get_chunker_id(&detector_id) - .context(format!("detector not found, detector_id={detector_id}"))?; + let chunker_id = ctx.config.get_chunker_id(&detector_id).ok_or_else(|| { + Error::InvalidDetectorId { + detector_id: detector_id.clone(), + } + })?; let chunks = chunks.get(&chunker_id).unwrap().clone(); Ok(tokio::spawn(async move { handle_detection_task(ctx, detector_id, detector_params, chunks).await @@ -259,7 +277,10 @@ async fn handle_chunk_task( .chunker_client .tokenization_task_predict(&chunker_id, request) .await - .with_context(|| format!("chunker request failed, chunker_id={chunker_id}"))?; + .map_err(|error| Error::ChunkerRequestFailed { + chunker_id: chunker_id.clone(), + error, + })?; debug!( %chunker_id, ?response, @@ -310,8 +331,9 @@ async fn handle_detection_task( .detector_client .classify(&detector_id, request) .await - .with_context(|| { - format!("detector request failed, detector_id={detector_id}") + .map_err(|error| Error::DetectorRequestFailed { + 
detector_id: detector_id.clone(), + error, })?; debug!( %detector_id, @@ -364,10 +386,14 @@ async fn tokenize( ?request, "sending tokenize request" ); - let mut response = client - .tokenize(request) - .await - .with_context(|| format!("tokenize request failed, model_id={model_id}"))?; + let mut response = + client + .tokenize(request) + .await + .map_err(|error| Error::TokenizeRequestFailed { + model_id: model_id.clone(), + error, + })?; debug!( %model_id, provider = "tgis", @@ -385,7 +411,13 @@ async fn tokenize( ?request, "sending tokenize request" ); - let response = client.tokenization_task_predict(&model_id, request).await?; + let response = client + .tokenization_task_predict(&model_id, request) + .await + .map_err(|error| Error::TokenizeRequestFailed { + model_id: model_id.clone(), + error, + })?; debug!( %model_id, provider = "nlp", @@ -425,9 +457,14 @@ async fn generate( ?request, "sending generate request" ); - let mut response = client.generate(request).await.with_context(|| { - format!("generate request failed, model_id={model_id}, provider=tgis") - })?; + let mut response = + client + .generate(request) + .await + .map_err(|error| Error::GenerateRequestFailed { + model_id: model_id.clone(), + error, + })?; debug!( %model_id, provider = "tgis", @@ -498,8 +535,9 @@ async fn generate( let response = client .text_generation_task_predict(&model_id, request) .await - .with_context(|| { - format!("generate request failed, model_id={model_id}, provider=nlp") + .map_err(|error| Error::GenerateRequestFailed { + model_id: model_id.clone(), + error, })?; debug!( %model_id, @@ -547,7 +585,7 @@ fn apply_masks(text: &str, masks: &[(usize, usize)]) -> Vec<(usize, String)> { async fn create_clients( config: &OrchestratorConfig, -) -> Result<(GenerationClient, ChunkerClient, DetectorClient), Error> { +) -> (GenerationClient, ChunkerClient, DetectorClient) { // TODO: create better solution for routers let generation_client = match config.generation.provider { 
GenerationProvider::Tgis => { @@ -555,7 +593,7 @@ async fn create_clients( clients::DEFAULT_TGIS_PORT, &[("tgis-router".to_string(), config.generation.service.clone())], ) - .await?; + .await; GenerationClient::Tgis(client) } GenerationProvider::Nlp => { @@ -563,7 +601,7 @@ async fn create_clients( clients::DEFAULT_CAIKIT_NLP_PORT, &[("tgis-router".to_string(), config.generation.service.clone())], ) - .await?; + .await; GenerationClient::Nlp(client) } }; @@ -573,7 +611,7 @@ async fn create_clients( .iter() .map(|(chunker_id, config)| (chunker_id.clone(), config.service.clone())) .collect::>(); - let chunker_client = ChunkerClient::new(clients::DEFAULT_CHUNKER_PORT, &chunker_config).await?; + let chunker_client = ChunkerClient::new(clients::DEFAULT_CHUNKER_PORT, &chunker_config).await; let detector_config = config .detectors @@ -581,9 +619,9 @@ async fn create_clients( .map(|(detector_id, config)| (detector_id.clone(), config.service.clone())) .collect::>(); let detector_client = - DetectorClient::new(clients::DEFAULT_DETECTOR_PORT, &detector_config).await?; + DetectorClient::new(clients::DEFAULT_DETECTOR_PORT, &detector_config).await; - Ok((generation_client, chunker_client, detector_client)) + (generation_client, chunker_client, detector_client) } #[derive(Debug, Clone)] diff --git a/src/orchestrator/errors.rs b/src/orchestrator/errors.rs new file mode 100644 index 00000000..d344b0ff --- /dev/null +++ b/src/orchestrator/errors.rs @@ -0,0 +1,64 @@ +use reqwest::StatusCode; + +use crate::clients; + +/// Orchestrator errors. 
+#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("invalid detector: detector_id={detector_id}")] + InvalidDetectorId { detector_id: String }, + #[error("invalid chunker: chunker_id={chunker_id}")] + InvalidChunkerId { chunker_id: String }, + #[error("detector request failed for detector_id={detector_id}: {error}")] + DetectorRequestFailed { + detector_id: String, + error: clients::Error, + }, + #[error("chunker request failed for chunker_id={chunker_id}: {error}")] + ChunkerRequestFailed { + chunker_id: String, + error: clients::Error, + }, + #[error("generate request failed for model_id={model_id}: {error}")] + GenerateRequestFailed { + model_id: String, + error: clients::Error, + }, + #[error("tokenize request failed for model_id={model_id}: {error}")] + TokenizeRequestFailed { + model_id: String, + error: clients::Error, + }, + #[error("task cancelled")] + Cancelled, + #[error("{0}")] + Other(String), +} + +impl Error { + /// Returns true for validation-type errors and false for other types. + pub fn is_validation_error(&self) -> bool { + use Error::*; + match self { + InvalidDetectorId { .. } | InvalidChunkerId { .. } => true, + DetectorRequestFailed { error, .. } + | ChunkerRequestFailed { error, .. } + | GenerateRequestFailed { error, .. } + | TokenizeRequestFailed { error, .. 
} => matches!( + error.status_code(), + StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY + ), + _ => false, + } + } +} + +impl From for Error { + fn from(error: tokio::task::JoinError) -> Self { + if error.is_cancelled() { + Self::Cancelled + } else { + Self::Other(format!("task panicked: {error}")) + } + } +} diff --git a/src/server.rs b/src/server.rs index 68c3e5fe..3848ba6a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,5 @@ use std::{net::SocketAddr, path::PathBuf, sync::Arc}; -use anyhow::Error; use axum::{ extract::State, http::StatusCode, @@ -19,7 +18,9 @@ use uuid::Uuid; use crate::{ config::OrchestratorConfig, models, - orchestrator::{ClassificationWithGenTask, Orchestrator, StreamingClassificationWithGenTask}, + orchestrator::{ + self, ClassificationWithGenTask, Orchestrator, StreamingClassificationWithGenTask, + }, }; const API_PREFIX: &str = r#"/api/v1/task"#; @@ -59,7 +60,9 @@ pub async fn run( ) .with_state(shared_state); - let listener = TcpListener::bind(&http_addr).await?; + let listener = TcpListener::bind(&http_addr) + .await + .unwrap_or_else(|_| panic!("failed to bind to {http_addr}")); let server = axum::serve(listener, app.into_make_service()).with_graceful_shutdown(shutdown_signal()); @@ -77,7 +80,7 @@ async fn health() -> Result<(), ()> { async fn classification_with_gen( State(state): State>, Json(request): Json, -) -> Result { +) -> Result { let request_id = Uuid::new_v4(); let task = ClassificationWithGenTask::new(request_id, request); match state @@ -86,17 +89,14 @@ async fn classification_with_gen( .await { Ok(response) => Ok(Json(response).into_response()), - Err(error) => { - error!(%request_id, "{error:#}"); - Err(error.into()) - } + Err(error) => Err(error.into()), } } async fn stream_classification_with_gen( State(state): State>, Json(request): Json, -) -> Result { +) -> Result { let request_id = Uuid::new_v4(); let task = StreamingClassificationWithGenTask::new(request_id, request); let response_stream = 
state @@ -135,24 +135,37 @@ async fn shutdown_signal() { info!("signal received, starting graceful shutdown"); } -pub struct ServerError(anyhow::Error); +/// High-level errors to return to clients. +/// Validation errors are forwarded from downstream clients. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("{0}")] + ValidationError(String), + #[error("unexpected error occured while processing request")] + UnexpectedError, +} -impl IntoResponse for ServerError { +impl From for Error { + fn from(error: orchestrator::Error) -> Self { + if error.is_validation_error() { + Self::ValidationError(error.to_string()) + } else { + Self::UnexpectedError + } + } +} + +impl IntoResponse for Error { fn into_response(self) -> Response { - let code = StatusCode::INTERNAL_SERVER_ERROR; + use Error::*; + let (code, message) = match self { + ValidationError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), + UnexpectedError => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + }; let error = serde_json::json!({ "code": code.as_u16(), - "message": self.0.to_string(), + "message": message, }); (code, Json(error)).into_response() } } - -impl From for ServerError -where - E: Into, -{ - fn from(err: E) -> Self { - Self(err.into()) - } -} From 792133eb6a737f72b5c24b2f2f371252cd0671ff Mon Sep 17 00:00:00 2001 From: declark1 Date: Tue, 21 May 2024 13:21:30 -0700 Subject: [PATCH 3/5] Update server::Error variants and conversion, rename Invalid to NotFound and return 404, drop is_validation_error helper Signed-off-by: declark1 --- src/clients.rs | 16 ++++------------ src/clients/chunker.rs | 2 +- src/clients/detector.rs | 2 +- src/clients/nlp.rs | 2 +- src/clients/tgis.rs | 2 +- src/orchestrator.rs | 13 +++++++------ src/orchestrator/errors.rs | 26 ++------------------------ src/server.rs | 32 ++++++++++++++++++++++---------- 8 files changed, 39 insertions(+), 56 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 2509dfe0..76f20a00 100644 ---
a/src/clients.rs +++ b/src/clients.rs @@ -34,8 +34,8 @@ pub enum Error { Grpc(#[from] tonic::Status), #[error("{0}")] Http(#[from] reqwest::Error), - #[error("invalid model id: {model_id}")] - InvalidModelId { model_id: String }, + #[error("model not found: {model_id}")] + ModelNotFound { model_id: String }, } impl Error { @@ -61,18 +61,10 @@ impl Error { Some(code) => code, None => StatusCode::INTERNAL_SERVER_ERROR, }, - // Return 422 for invalid model id - Error::InvalidModelId { .. } => StatusCode::UNPROCESSABLE_ENTITY, + // Return 404 for model not found + Error::ModelNotFound { .. } => StatusCode::NOT_FOUND, } } - - /// Returns true for validation-type errors (400/422) and false for other types. - pub fn is_validation_error(&self) -> bool { - matches!( - self.status_code(), - StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY - ) - } } #[derive(Clone)] diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index a1f01766..493ad8e6 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -35,7 +35,7 @@ impl ChunkerClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::InvalidModelId { + .ok_or_else(|| Error::ModelNotFound { model_id: model_id.to_string(), })? .clone()) diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 341f7d89..6d684e70 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -22,7 +22,7 @@ impl DetectorClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::InvalidModelId { + .ok_or_else(|| Error::ModelNotFound { model_id: model_id.to_string(), })? .clone()) diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index ba8faf8d..bfc3499f 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -38,7 +38,7 @@ impl NlpClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::InvalidModelId { + .ok_or_else(|| Error::ModelNotFound { model_id: model_id.to_string(), })? 
.clone()) diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index 23193b1a..b1dd3942 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -35,7 +35,7 @@ impl TgisClient { Ok(self .clients .get(model_id) - .ok_or_else(|| Error::InvalidModelId { + .ok_or_else(|| Error::ModelNotFound { model_id: model_id.to_string(), })? .clone()) diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 68236a4d..1a3dfda5 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -189,7 +189,7 @@ async fn chunk_and_detect( let chunker_id = ctx.config .get_chunker_id(detector_id) - .ok_or_else(|| Error::InvalidDetectorId { + .ok_or_else(|| Error::DetectorNotFound { detector_id: detector_id.clone(), })?; Ok::(chunker_id) @@ -235,11 +235,12 @@ async fn detect( let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - let chunker_id = ctx.config.get_chunker_id(&detector_id).ok_or_else(|| { - Error::InvalidDetectorId { - detector_id: detector_id.clone(), - } - })?; + let chunker_id = + ctx.config + .get_chunker_id(&detector_id) + .ok_or_else(|| Error::DetectorNotFound { + detector_id: detector_id.clone(), + })?; let chunks = chunks.get(&chunker_id).unwrap().clone(); Ok(tokio::spawn(async move { handle_detection_task(ctx, detector_id, detector_params, chunks).await diff --git a/src/orchestrator/errors.rs b/src/orchestrator/errors.rs index d344b0ff..b20d3b89 100644 --- a/src/orchestrator/errors.rs +++ b/src/orchestrator/errors.rs @@ -1,14 +1,10 @@ -use reqwest::StatusCode; - use crate::clients; /// Orchestrator errors. 
#[derive(Debug, thiserror::Error)] pub enum Error { - #[error("invalid detector: detector_id={detector_id}")] - InvalidDetectorId { detector_id: String }, - #[error("invalid chunker: chunker_id={chunker_id}")] - InvalidChunkerId { chunker_id: String }, + #[error("detector not found: {detector_id}")] + DetectorNotFound { detector_id: String }, #[error("detector request failed for detector_id={detector_id}: {error}")] DetectorRequestFailed { detector_id: String, @@ -35,24 +31,6 @@ pub enum Error { Other(String), } -impl Error { - /// Returns true for validation-type errors and false for other types. - pub fn is_validation_error(&self) -> bool { - use Error::*; - match self { - InvalidDetectorId { .. } | InvalidChunkerId { .. } => true, - DetectorRequestFailed { error, .. } - | ChunkerRequestFailed { error, .. } - | GenerateRequestFailed { error, .. } - | TokenizeRequestFailed { error, .. } => matches!( - error.status_code(), - StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY - ), - _ => false, - } - } -} - impl From for Error { fn from(error: tokio::task::JoinError) -> Self { if error.is_cancelled() { diff --git a/src/server.rs b/src/server.rs index 3848ba6a..33ce5c90 100644 --- a/src/server.rs +++ b/src/server.rs @@ -136,21 +136,32 @@ async fn shutdown_signal() { } /// High-level errors to return to clients. -/// Validation errors are forwarded from downstream clients. #[derive(Debug, thiserror::Error)] pub enum Error { #[error("{0}")] - ValidationError(String), + Validation(String), + #[error("{0}")] + NotFound(String), #[error("unexpected error occured while processing request")] - UnexpectedError, + Unexpected, } impl From for Error { fn from(error: orchestrator::Error) -> Self { - if error.is_validation_error() { - Self::ValidationError(error.to_string()) - } else { - Self::UnexpectedError + use orchestrator::Error::*; + match error { + DetectorNotFound { .. } => Self::NotFound(error.to_string()), + DetectorRequestFailed { error, .. 
} + | ChunkerRequestFailed { error, .. } + | GenerateRequestFailed { error, .. } + | TokenizeRequestFailed { error, .. } => match error.status_code() { + StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY => { + Self::Validation(error.to_string()) + } + StatusCode::NOT_FOUND => Self::NotFound(error.to_string()), + _ => Self::Unexpected, + }, + _ => Self::Unexpected, } } } @@ -159,12 +170,13 @@ impl IntoResponse for Error { fn into_response(self) -> Response { use Error::*; let (code, message) = match self { - ValidationError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - UnexpectedError => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), + NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), + Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), }; let error = serde_json::json!({ "code": code.as_u16(), - "message": message, + "details": message, }); (code, Json(error)).into_response() } From 572e34fd5766fc5359b90a91fb253bfcce9b8781 Mon Sep 17 00:00:00 2001 From: declark1 Date: Wed, 22 May 2024 10:16:06 -0700 Subject: [PATCH 4/5] Address review comments Signed-off-by: declark1 --- src/clients.rs | 2 +- src/orchestrator.rs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 76f20a00..27e12067 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -116,7 +116,7 @@ pub async fn create_http_clients( panic!("error reading cert from {cert_path:?}: {error}") }); let identity = reqwest::Identity::from_pem(&cert_pem) - .unwrap_or_else(|error| panic!("error creating identity: {error}")); + .unwrap_or_else(|error| panic!("error parsing cert: {error}")); builder = builder.use_rustls_tls().identity(identity); } let client = builder diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 1a3dfda5..4f6ec473 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -142,13 +142,15 @@ impl 
Orchestrator { } }); match task_handle.await { + // Task completed successfully Ok(Ok(result)) => Ok(result), + // Task failed, return error propagated from child task that failed Ok(Err(error)) => { error!(request_id = ?task.request_id, %error, "unary task failed"); Err(error) } + // Task cancelled or panicked Err(error) => { - // Task failed due to cancellation or panic let error = error.into(); error!(request_id = ?task.request_id, %error, "unary task failed"); Err(error) From ef49caf079d2f42a9556d0b92242e9cf1d321ea1 Mon Sep 17 00:00:00 2001 From: declark1 Date: Wed, 22 May 2024 10:20:38 -0700 Subject: [PATCH 5/5] Run cargo-sort to sort dependencies Signed-off-by: declark1 --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 619f002c..a19013ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ name = "fms-guardrails-orchestr8" path = "src/main.rs" [dependencies] +anyhow = "1.0.83" axum = { version = "0.7.5", features = ["json"] } clap = { version = "4.5.3", features = ["derive", "env"] } futures = "0.3.30" @@ -25,16 +26,15 @@ rustls-webpki = "0.102.2" serde = { version = "1.0.200", features = ["derive"] } serde_json = "1.0.116" serde_yml = "0.0.5" +thiserror = "1.0.60" tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] } tokio-stream = "0.1.14" tonic = { version = "0.11.0", features = ["tls"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } url = "2.5.0" -validator = { version = "0.18.1", features = ["derive"] } # For API validation uuid = { version = "1.8.0", features = ["v4", "fast-rng"] } -anyhow = "1.0.83" -thiserror = "1.0.60" +validator = { version = "0.18.1", features = ["derive"] } # For API validation [build-dependencies] tonic-build = "0.11.0"