Skip to content

Commit

Permalink
Implement error handling (#31)
Browse files Browse the repository at this point in the history
Signed-off-by: declark1 <daniel.clark@ibm.com>
  • Loading branch information
declark1 authored and gkumbhat committed May 22, 2024
1 parent 8faa26d commit 6ca2e22
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 131 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ name = "fms-guardrails-orchestr8"
path = "src/main.rs"

[dependencies]
anyhow = "1.0.83"
axum = { version = "0.7.5", features = ["json"] }
clap = { version = "4.5.3", features = ["derive", "env"] }
futures = "0.3.30"
Expand All @@ -25,15 +26,15 @@ rustls-webpki = "0.102.2"
serde = { version = "1.0.200", features = ["derive"] }
serde_json = "1.0.116"
serde_yml = "0.0.5"
thiserror = "1.0.59"
thiserror = "1.0.60"
tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
tokio-stream = "0.1.14"
tonic = { version = "0.11.0", features = ["tls"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
url = "2.5.0"
validator = { version = "0.18.1", features = ["derive"] } # For API validation
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
validator = { version = "0.18.1", features = ["derive"] } # For API validation

[build-dependencies]
tonic-build = "0.11.0"
Expand Down
68 changes: 50 additions & 18 deletions src/clients.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![allow(dead_code)]
use std::{collections::HashMap, time::Duration};

use futures::future::try_join_all;
use futures::future::join_all;
use ginepro::LoadBalancedChannel;
use reqwest::StatusCode;
use url::Url;

use crate::config::{ServiceConfig, Tls};
Expand All @@ -26,16 +27,44 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

/// Client errors.
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("model not found: {0}")]
ModelNotFound(String),
#[error(transparent)]
ReqwestError(#[from] reqwest::Error),
#[error(transparent)]
TonicError(#[from] tonic::Status),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("{}", .0.message())]
Grpc(#[from] tonic::Status),
#[error("{0}")]
Http(#[from] reqwest::Error),
#[error("model not found: {model_id}")]
ModelNotFound { model_id: String },
}

impl Error {
/// Returns status code.
pub fn status_code(&self) -> StatusCode {
use tonic::Code::*;
match self {
// Return equivalent http status code for grpc status code
Error::Grpc(error) => match error.code() {
InvalidArgument => StatusCode::BAD_REQUEST,
Internal => StatusCode::INTERNAL_SERVER_ERROR,
NotFound => StatusCode::NOT_FOUND,
DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
Unimplemented => StatusCode::NOT_IMPLEMENTED,
Unauthenticated => StatusCode::UNAUTHORIZED,
PermissionDenied => StatusCode::FORBIDDEN,
Ok => StatusCode::OK,
_ => StatusCode::INTERNAL_SERVER_ERROR,
},
// Return http status code for error responses
// and 500 for other errors
Error::Http(error) => match error.status() {
Some(code) => code,
None => StatusCode::INTERNAL_SERVER_ERROR,
},
// Return 404 for model not found
Error::ModelNotFound { .. } => StatusCode::NOT_FOUND,
}
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -71,7 +100,7 @@ impl std::ops::Deref for HttpClient {
pub async fn create_http_clients(
default_port: u16,
config: &[(String, ServiceConfig)],
) -> Result<HashMap<String, HttpClient>, Error> {
) -> HashMap<String, HttpClient> {
let clients = config
.iter()
.map(|(name, service_config)| async move {
Expand All @@ -86,22 +115,25 @@ pub async fn create_http_clients(
let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| {
panic!("error reading cert from {cert_path:?}: {error}")
});
let identity = reqwest::Identity::from_pem(&cert_pem)?;
let identity = reqwest::Identity::from_pem(&cert_pem)
.unwrap_or_else(|error| panic!("error parsing cert: {error}"));
builder = builder.use_rustls_tls().identity(identity);
}
let client = builder.build()?;
let client = builder
.build()
.unwrap_or_else(|error| panic!("error creating http client for {name}: {error}"));
let client = HttpClient::new(base_url, client);
Ok((name.clone(), client)) as Result<(String, HttpClient), Error>
(name.clone(), client)
})
.collect::<Vec<_>>();
Ok(try_join_all(clients).await?.into_iter().collect())
join_all(clients).await.into_iter().collect()
}

async fn create_grpc_clients<C>(
default_port: u16,
config: &[(String, ServiceConfig)],
new: fn(LoadBalancedChannel) -> C,
) -> Result<HashMap<String, C>, Error> {
) -> HashMap<String, C> {
let clients = config
.iter()
.map(|(name, service_config)| async move {
Expand Down Expand Up @@ -140,9 +172,9 @@ async fn create_grpc_clients<C>(
if let Some(client_tls_config) = client_tls_config {
builder = builder.with_tls(client_tls_config);
}
let channel = builder.channel().await.unwrap(); // TODO: handle error
Ok((name.clone(), new(channel))) as Result<(String, C), Error>
let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}"));
(name.clone(), new(channel))
})
.collect::<Vec<_>>();
Ok(try_join_all(clients).await?.into_iter().collect())
join_all(clients).await.into_iter().collect()
}
10 changes: 6 additions & 4 deletions src/clients/chunker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ pub struct ChunkerClient {
}

impl ChunkerClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await;
Self { clients }
}

fn client(&self, model_id: &str) -> Result<ChunkersServiceClient<LoadBalancedChannel>, Error> {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
11 changes: 6 additions & 5 deletions src/clients/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@ pub struct DetectorClient {
}

impl DetectorClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients: HashMap<String, HttpClient> =
create_http_clients(default_port, config).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients: HashMap<String, HttpClient> = create_http_clients(default_port, config).await;
Self { clients }
}

fn client(&self, model_id: &str) -> Result<HttpClient, Error> {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
10 changes: 6 additions & 4 deletions src/clients/nlp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,18 @@ pub struct NlpClient {
}

impl NlpClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await;
Self { clients }
}

fn client(&self, model_id: &str) -> Result<NlpServiceClient<LoadBalancedChannel>, Error> {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
11 changes: 6 additions & 5 deletions src/clients/tgis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ pub struct TgisClient {
}

impl TgisClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients =
create_grpc_clients(default_port, config, GenerationServiceClient::new).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await;
Self { clients }
}

fn client(
Expand All @@ -36,7 +35,9 @@ impl TgisClient {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
9 changes: 6 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ impl OrchestratorConfig {
todo!()
}

pub fn get_chunker_id(&self, detector_id: &str) -> String {
self.detectors.get(detector_id).unwrap().chunker_id.clone()
pub fn get_chunker_id(&self, detector_id: &str) -> Option<String> {
self.detectors
.get(detector_id)
.map(|detector_config| detector_config.chunker_id.clone())
}
}

Expand All @@ -126,8 +128,9 @@ fn service_tls_name_to_config(

#[cfg(test)]
mod tests {
use anyhow::Error;

use super::*;
use crate::Error;

#[test]
fn test_deserialize_config() -> Result<(), Error> {
Expand Down
35 changes: 0 additions & 35 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,8 @@
#![allow(clippy::iter_kv_map, clippy::enum_variant_names)]

use axum::{http::StatusCode, Json};

mod clients;
mod config;
mod models;
mod orchestrator;
mod pb;
pub mod server;

#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
ClientError(#[from] crate::clients::Error),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
YamlError(#[from] serde_yml::Error),
}

// TODO: create better errors and properly convert
impl From<Error> for (StatusCode, Json<String>) {
fn from(value: Error) -> Self {
use Error::*;
match value {
ClientError(error) => match error {
clients::Error::ModelNotFound(message) => {
(StatusCode::UNPROCESSABLE_ENTITY, Json(message))
}
clients::Error::ReqwestError(error) => {
(StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string()))
}
clients::Error::TonicError(error) => {
(StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string()))
}
clients::Error::IoError(_) => todo!(),
},
IoError(_) => todo!(),
YamlError(_) => todo!(),
}
}
}
Loading

0 comments on commit 6ca2e22

Please sign in to comment.