Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement error handling #31

Merged
merged 5 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ name = "fms-guardrails-orchestr8"
path = "src/main.rs"

[dependencies]
anyhow = "1.0.83"
axum = { version = "0.7.5", features = ["json"] }
clap = { version = "4.5.3", features = ["derive", "env"] }
futures = "0.3.30"
Expand All @@ -25,15 +26,15 @@ rustls-webpki = "0.102.2"
serde = { version = "1.0.200", features = ["derive"] }
serde_json = "1.0.116"
serde_yml = "0.0.5"
thiserror = "1.0.59"
thiserror = "1.0.60"
tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
tokio-stream = "0.1.14"
tonic = { version = "0.11.0", features = ["tls"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
url = "2.5.0"
validator = { version = "0.18.1", features = ["derive"] } # For API validation
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
validator = { version = "0.18.1", features = ["derive"] } # For API validation

[build-dependencies]
tonic-build = "0.11.0"
Expand Down
68 changes: 50 additions & 18 deletions src/clients.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![allow(dead_code)]
use std::{collections::HashMap, time::Duration};

use futures::future::try_join_all;
use futures::future::join_all;
use ginepro::LoadBalancedChannel;
use reqwest::StatusCode;
use url::Url;

use crate::config::{ServiceConfig, Tls};
Expand All @@ -26,16 +27,44 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

/// Client errors.
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("model not found: {0}")]
ModelNotFound(String),
#[error(transparent)]
ReqwestError(#[from] reqwest::Error),
#[error(transparent)]
TonicError(#[from] tonic::Status),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("{}", .0.message())]
Grpc(#[from] tonic::Status),
#[error("{0}")]
Http(#[from] reqwest::Error),
#[error("model not found: {model_id}")]
ModelNotFound { model_id: String },
}

impl Error {
/// Returns status code.
pub fn status_code(&self) -> StatusCode {
use tonic::Code::*;
match self {
// Return equivalent http status code for grpc status code
Error::Grpc(error) => match error.code() {
InvalidArgument => StatusCode::BAD_REQUEST,
Internal => StatusCode::INTERNAL_SERVER_ERROR,
NotFound => StatusCode::NOT_FOUND,
DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
Unimplemented => StatusCode::NOT_IMPLEMENTED,
Unauthenticated => StatusCode::UNAUTHORIZED,
PermissionDenied => StatusCode::FORBIDDEN,
Ok => StatusCode::OK,
_ => StatusCode::INTERNAL_SERVER_ERROR,
},
// Return http status code for error responses
// and 500 for other errors
Error::Http(error) => match error.status() {
Some(code) => code,
None => StatusCode::INTERNAL_SERVER_ERROR,
},
// Return 404 for model not found
Error::ModelNotFound { .. } => StatusCode::NOT_FOUND,
}
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -71,7 +100,7 @@ impl std::ops::Deref for HttpClient {
pub async fn create_http_clients(
default_port: u16,
config: &[(String, ServiceConfig)],
) -> Result<HashMap<String, HttpClient>, Error> {
) -> HashMap<String, HttpClient> {
let clients = config
.iter()
.map(|(name, service_config)| async move {
Expand All @@ -86,22 +115,25 @@ pub async fn create_http_clients(
let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| {
panic!("error reading cert from {cert_path:?}: {error}")
});
let identity = reqwest::Identity::from_pem(&cert_pem)?;
let identity = reqwest::Identity::from_pem(&cert_pem)
.unwrap_or_else(|error| panic!("error parsing cert: {error}"));
builder = builder.use_rustls_tls().identity(identity);
}
let client = builder.build()?;
let client = builder
.build()
.unwrap_or_else(|error| panic!("error creating http client for {name}: {error}"));
let client = HttpClient::new(base_url, client);
Ok((name.clone(), client)) as Result<(String, HttpClient), Error>
(name.clone(), client)
})
.collect::<Vec<_>>();
Ok(try_join_all(clients).await?.into_iter().collect())
join_all(clients).await.into_iter().collect()
}

async fn create_grpc_clients<C>(
default_port: u16,
config: &[(String, ServiceConfig)],
new: fn(LoadBalancedChannel) -> C,
) -> Result<HashMap<String, C>, Error> {
) -> HashMap<String, C> {
let clients = config
.iter()
.map(|(name, service_config)| async move {
Expand Down Expand Up @@ -140,9 +172,9 @@ async fn create_grpc_clients<C>(
if let Some(client_tls_config) = client_tls_config {
builder = builder.with_tls(client_tls_config);
}
let channel = builder.channel().await.unwrap(); // TODO: handle error
Ok((name.clone(), new(channel))) as Result<(String, C), Error>
let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}"));
(name.clone(), new(channel))
})
.collect::<Vec<_>>();
Ok(try_join_all(clients).await?.into_iter().collect())
join_all(clients).await.into_iter().collect()
}
10 changes: 6 additions & 4 deletions src/clients/chunker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ pub struct ChunkerClient {
}

impl ChunkerClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await;
Self { clients }
}

fn client(&self, model_id: &str) -> Result<ChunkersServiceClient<LoadBalancedChannel>, Error> {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
11 changes: 6 additions & 5 deletions src/clients/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@ pub struct DetectorClient {
}

impl DetectorClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients: HashMap<String, HttpClient> =
create_http_clients(default_port, config).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients: HashMap<String, HttpClient> = create_http_clients(default_port, config).await;
Self { clients }
}

fn client(&self, model_id: &str) -> Result<HttpClient, Error> {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
10 changes: 6 additions & 4 deletions src/clients/nlp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,18 @@ pub struct NlpClient {
}

impl NlpClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await;
Self { clients }
}

fn client(&self, model_id: &str) -> Result<NlpServiceClient<LoadBalancedChannel>, Error> {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
11 changes: 6 additions & 5 deletions src/clients/tgis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ pub struct TgisClient {
}

impl TgisClient {
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
let clients =
create_grpc_clients(default_port, config, GenerationServiceClient::new).await?;
Ok(Self { clients })
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await;
Self { clients }
}

fn client(
Expand All @@ -36,7 +35,9 @@ impl TgisClient {
Ok(self
.clients
.get(model_id)
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
.ok_or_else(|| Error::ModelNotFound {
model_id: model_id.to_string(),
})?
.clone())
}

Expand Down
9 changes: 6 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ impl OrchestratorConfig {
todo!()
}

pub fn get_chunker_id(&self, detector_id: &str) -> String {
self.detectors.get(detector_id).unwrap().chunker_id.clone()
pub fn get_chunker_id(&self, detector_id: &str) -> Option<String> {
self.detectors
.get(detector_id)
.map(|detector_config| detector_config.chunker_id.clone())
}
}

Expand All @@ -126,8 +128,9 @@ fn service_tls_name_to_config(

#[cfg(test)]
mod tests {
use anyhow::Error;

use super::*;
use crate::Error;

#[test]
fn test_deserialize_config() -> Result<(), Error> {
Expand Down
35 changes: 0 additions & 35 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,8 @@
#![allow(clippy::iter_kv_map, clippy::enum_variant_names)]

use axum::{http::StatusCode, Json};

mod clients;
mod config;
mod models;
mod orchestrator;
mod pb;
pub mod server;

#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
ClientError(#[from] crate::clients::Error),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
YamlError(#[from] serde_yml::Error),
}

// TODO: create better errors and properly convert
impl From<Error> for (StatusCode, Json<String>) {
fn from(value: Error) -> Self {
use Error::*;
match value {
ClientError(error) => match error {
clients::Error::ModelNotFound(message) => {
(StatusCode::UNPROCESSABLE_ENTITY, Json(message))
}
clients::Error::ReqwestError(error) => {
(StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string()))
}
clients::Error::TonicError(error) => {
(StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string()))
}
clients::Error::IoError(_) => todo!(),
},
IoError(_) => todo!(),
YamlError(_) => todo!(),
}
}
}
Loading