From b57dae0e9b4a8598b91f2a81bbe4362e323e1b0d Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 23 May 2024 17:07:19 -0600 Subject: [PATCH 01/12] :bulb::construction: Detector threshold comments Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/config.rs | 2 ++ src/orchestrator.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/config.rs b/src/config.rs index 95c42baf..2619ac2c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -58,7 +58,9 @@ pub struct ChunkerConfig { pub struct DetectorConfig { pub service: ServiceConfig, pub chunker_id: String, + // Put threshold here _in_ config -> need to change type //pub config: HashMap, + // or threshold could be at this level but then would have to be optional } #[derive(Debug, Clone, Deserialize)] diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 4f6ec473..e1a36717 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -236,6 +236,7 @@ async fn detect( .map(|(detector_id, detector_params)| { let ctx = ctx.clone(); let detector_id = detector_id.clone(); + // Use default threshold here (from ctx?/detector config?) if not present in detector_params let detector_params = detector_params.clone(); let chunker_id = ctx.config @@ -343,6 +344,7 @@ async fn handle_detection_task( ?response, "received detector response" ); + // Filter results based on threshold (if applicable) here let results = response .detections .into_iter() From 7a5837cc02a09d7ff9dd63b6770ab510712dc1ba Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Fri, 24 May 2024 12:29:28 -0600 Subject: [PATCH 02/12] :recycle::goal_net: Refactor validation error Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Co-authored-by: declark1 --- src/models.rs | 18 +++++++++++++----- src/server.rs | 6 ++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/models.rs b/src/models.rs index 276cb143..f3bd40f4 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,6 +1,6 @@ #![allow(unused_qualifications)] -use crate::{pb, server}; +use crate::pb; use std::collections::HashMap; pub type DetectorParams = HashMap; @@ -28,15 +28,23 @@ pub struct GuardrailsHttpRequest { pub text_gen_parameters: Option, } +#[derive(Debug, thiserror::Error)] +pub enum ValidationError { + #[error("`{0}` is required")] + Required(String), + #[error("{0}")] + Invalid(String), +} + impl GuardrailsHttpRequest { /// Upfront validation of user request - pub fn validate(&self) -> Result<(), server::Error> { + pub fn validate(&self) -> Result<(), ValidationError> { // Validate required parameters if self.model_id.is_empty() { - return Err(server::Error::Validation("`model_id` is required".into())); + return Err(ValidationError::Required("model_id".into())); } if self.inputs.is_empty() { - return Err(server::Error::Validation("`inputs` is required".into())); + return Err(ValidationError::Required("inputs".into())); } // Validate masks let input_range = 0..self.inputs.len(); @@ -48,7 +56,7 @@ impl GuardrailsHttpRequest { if !input_masks.iter().all(|(start, end)| { input_range.contains(start) && input_range.contains(end) && start < end }) { - return Err(server::Error::Validation("invalid masks".into())); + return Err(ValidationError::Invalid("invalid masks".into())); } } Ok(()) diff --git a/src/server.rs b/src/server.rs index 61468815..028294f2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -183,3 +183,9 @@ impl IntoResponse for Error { (code, Json(error)).into_response() } } + +impl From for Error { + fn from(value: models::ValidationError) -> Self { + Self::Validation(value.to_string()) + } +} From 398250f18bf690daa963600e185463bc36d2ccf4 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Fri, 24 May 2024 13:07:55 -0600 Subject: [PATCH 03/12] :label::memo: Add detector config params and type descriptions Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/config.rs | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/config.rs b/src/config.rs index 2619ac2c..b023084f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,6 +6,8 @@ use std::{ use serde::Deserialize; use tracing::debug; +/// Configuration for service needed for +/// orchestrator to communicate with it #[derive(Debug, Clone, Deserialize)] pub struct ServiceConfig { pub hostname: String, @@ -13,6 +15,7 @@ pub struct ServiceConfig { pub tls: Option, } +/// TLS provider #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] pub enum Tls { @@ -20,6 +23,7 @@ pub enum Tls { Config(TlsConfig), } +/// Client TLS configuration #[derive(Debug, Clone, Deserialize)] pub struct TlsConfig { pub cert_path: Option, @@ -27,6 +31,7 @@ pub struct TlsConfig { pub client_ca_cert_path: Option, } +/// Generation service provider #[derive(Debug, Clone, Copy, Deserialize)] #[serde(rename_all = "lowercase")] pub enum GenerationProvider { @@ -34,12 +39,16 @@ pub enum GenerationProvider { Nlp, } +/// Generate service configuration #[derive(Debug, Clone, Deserialize)] pub struct GenerationConfig { + /// Generation service provider pub provider: GenerationProvider, + /// Generation service connection information pub service: ServiceConfig, } +/// Chunker parser type #[derive(Debug, Clone, Copy, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ChunkerType { @@ -47,31 +56,50 @@ pub enum ChunkerType { All, } +/// Configuration for each chunker #[allow(dead_code)] #[derive(Debug, Clone, Deserialize)] pub struct ChunkerConfig { + /// Chunker type pub r#type: ChunkerType, + /// Chunker service connection information pub service: ServiceConfig, } +/// Configuration parameters applicable to each detector +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct DetectorConfigParams { + /// Default threshold with which to filter detector results by score + pub default_threshold: Option, +} + +/// Configuration for each detector #[derive(Debug, Clone, Deserialize)] pub struct DetectorConfig { + /// Detector service connection information pub service: ServiceConfig, + /// ID of chunker that this detector will use pub chunker_id: String, - // Put threshold here _in_ config -> need to change type - //pub config: HashMap, - // or threshold could be at this level but then would have to be optional + /// Optional detector configuration parameters + pub config: Option, } +/// Overall orchestrator server configuration #[derive(Debug, Clone, Deserialize)] pub struct OrchestratorConfig { + /// Generation service and associated configuration pub generation: GenerationConfig, + /// Chunker services and associated configurations pub chunkers: HashMap, + /// Detector services and associated configurations pub detectors: HashMap, + /// Map of TLS connections, allowing reuse across services + /// that may require the same TLS information pub tls: HashMap, } impl OrchestratorConfig { + /// Load overall orchestrator server configuration pub async fn load(path: impl AsRef) -> Self { let path = path.as_ref(); let s = tokio::fs::read_to_string(path) @@ -110,6 +138,7 @@ impl OrchestratorConfig { todo!() } + /// Get ID of chunker associated with a particular detector pub fn get_chunker_id(&self, detector_id: &str) -> Option { self.detectors .get(detector_id) From 85dec7bb0b9d0e235350c06137568c4ebc09c918 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Fri, 24 May 2024 16:22:29 -0600 Subject: [PATCH 04/12] :wrench: Update config yaml Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- config/config.yaml | 3 ++- src/config.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 7e1a4b63..fd7491ef 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -17,7 +17,8 @@ detectors: port: 8080 tls: caikit chunker_id: en_regex - config: {} + config: + detector_threshold: 0.5 tls: caikit: cert_path: /path/to/tls.crt diff --git a/src/config.rs b/src/config.rs index b023084f..28a939ed 100644 --- a/src/config.rs +++ b/src/config.rs @@ -188,7 +188,8 @@ detectors: hostname: localhost port: 9000 chunker_id: sentence-en - config: {} + config: + default_threshold: 0.5 tls: {} "#; let config: OrchestratorConfig = serde_yml::from_str(s)?; From 9537e9ad91e3c3d7f23e5d9598576d45fe95ff85 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Fri, 24 May 2024 17:26:01 -0600 Subject: [PATCH 05/12] :sparkles: Get default threshold from config Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/config.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/config.rs b/src/config.rs index 28a939ed..6078bbf4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -144,6 +144,16 @@ impl OrchestratorConfig { .get(detector_id) .map(|detector_config| detector_config.chunker_id.clone()) } + + /// Get default threshold of a particular detector + pub fn get_default_threshold(&self, detector_id: &str) -> Option { + self.detectors.get(detector_id).and_then(|detector_config| { + detector_config + .config + .as_ref() + .and_then(|config| config.default_threshold) + }) + } } fn service_tls_name_to_config( From fb73c0e6cfd944f5e237cf6169a082df3b86bb36 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 28 May 2024 09:41:07 -0600 Subject: [PATCH 06/12] :bug::wrench: Use default threshold Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- config/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.yaml b/config/config.yaml index fd7491ef..1434cf29 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -18,7 +18,7 @@ detectors: tls: caikit chunker_id: en_regex config: - detector_threshold: 0.5 + default_threshold: 0.5 tls: caikit: cert_path: /path/to/tls.crt From 513b25dbf8ddbf8fe20f03316f062f0e4a214e25 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 28 May 2024 16:32:55 -0600 Subject: [PATCH 07/12] :sparkles: Pass along default threshold in detector params Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/orchestrator.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/orchestrator.rs b/src/orchestrator.rs index c5004ca8..94154aef 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -237,8 +237,13 @@ async fn detect( .map(|(detector_id, detector_params)| { let ctx = ctx.clone(); let detector_id = detector_id.clone(); - // Use default threshold here (from ctx?/detector config?) if not present in detector_params - let detector_params = detector_params.clone(); + let mut detector_params = detector_params.clone(); + if let Some(default_threshold) = ctx.config.get_default_threshold(&detector_id) { + // Use a default threshold if threshold is not provided by the user + detector_params + .entry("threshold".into()) + .or_insert(default_threshold.into()); + } let chunker_id = ctx.config .get_chunker_id(&detector_id) @@ -326,6 +331,9 @@ async fn handle_detection_task( let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); async move { + // NOTE: The detector request is expected to change and not actually + // take parameters. However, any parameters will be ignored for now + // ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 let request = DetectorRequest::new(chunk.text.clone(), detector_params); debug!( %detector_id, From 31bbdbd31037705063fdc777a25f6c1a2403a324 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 28 May 2024 16:52:01 -0600 Subject: [PATCH 08/12] :construction: Filtering based on threshold Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/orchestrator.rs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 94154aef..bb384470 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -334,7 +334,7 @@ async fn handle_detection_task( // NOTE: The detector request is expected to change and not actually // take parameters. However, any parameters will be ignored for now // ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 - let request = DetectorRequest::new(chunk.text.clone(), detector_params); + let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone()); debug!( %detector_id, ?request, @@ -357,11 +357,20 @@ async fn handle_detection_task( let results = response .detections .into_iter() - .map(|detection| { + .filter_map(|detection| { let mut result: TokenClassificationResult = detection.into(); result.start += chunk.offset as u32; result.end += chunk.offset as u32; - result + if let Some(threshold_value) = detector_params.get("threshold") { + if let Some(threshold) = threshold_value.as_f64() { + if result.score >= threshold { + return Some(result); + } + } + } else { + return Some(result); + } + None }) .collect::>(); Ok::, Error>(results) From 448c6a1958a2b5edbf8485f11f62a83832b64644 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 29 May 2024 12:20:21 -0600 Subject: [PATCH 09/12] :recycle: Reduce nested ifs Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Co-authored-by: declark1 --- src/orchestrator.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/orchestrator.rs b/src/orchestrator.rs index bb384470..63f95f92 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -361,16 +361,12 @@ async fn handle_detection_task( let mut result: TokenClassificationResult = detection.into(); result.start += chunk.offset as u32; result.end += chunk.offset as u32; - if let Some(threshold_value) = detector_params.get("threshold") { - if let Some(threshold) = threshold_value.as_f64() { - if result.score >= threshold { - return Some(result); - } - } + let threshold = detector_params.get("threshold").and_then(|v| v.as_f64()); + if threshold.is_some_and(|value| result.score < value) { + None } else { - return Some(result); + Some(result) } - None }) .collect::>(); Ok::, Error>(results) From 0655a7df614cdafa2238cf8743affc1897fc87dc Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 30 May 2024 12:18:53 -0600 Subject: [PATCH 10/12] :wrench: Change default threshold under detector Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- config/config.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 1434cf29..db040233 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -17,8 +17,7 @@ detectors: port: 8080 tls: caikit chunker_id: en_regex - config: - default_threshold: 0.5 + default_threshold: 0.5 tls: caikit: cert_path: /path/to/tls.crt From 9cce8b057dc74ff818f3939e5bbc07dad0c1e7bc Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 30 May 2024 12:33:27 -0600 Subject: [PATCH 11/12] :recycle: Require default_threshold for detectors Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/config.rs | 23 ++++++----------------- src/models.rs | 4 ++++ src/orchestrator.rs | 33 ++++++++++++++++++--------------- 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/src/config.rs b/src/config.rs index 6078bbf4..441318c5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -66,13 +66,6 @@ pub struct ChunkerConfig { pub service: ServiceConfig, } -/// Configuration parameters applicable to each detector -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct DetectorConfigParams { - /// Default threshold with which to filter detector results by score - pub default_threshold: Option, -} - /// Configuration for each detector #[derive(Debug, Clone, Deserialize)] pub struct DetectorConfig { @@ -80,8 +73,8 @@ pub struct DetectorConfig { pub service: ServiceConfig, /// ID of chunker that this detector will use pub chunker_id: String, - /// Optional detector configuration parameters - pub config: Option, + /// Default threshold with which to filter detector results by score + pub default_threshold: f32, } /// Overall orchestrator server configuration @@ -147,12 +140,9 @@ impl OrchestratorConfig { /// Get default threshold of a particular detector pub fn get_default_threshold(&self, detector_id: &str) -> Option { - self.detectors.get(detector_id).and_then(|detector_config| { - detector_config - .config - .as_ref() - .and_then(|config| config.default_threshold) - }) + self.detectors + .get(detector_id) + .map(|detector_config| detector_config.default_threshold) } } @@ -198,8 +188,7 @@ detectors: hostname: localhost port: 9000 chunker_id: sentence-en - config: - default_threshold: 0.5 + default_threshold: 0.5 tls: {} "#; let config: OrchestratorConfig = serde_yml::from_str(s)?; diff --git a/src/models.rs b/src/models.rs index f3bd40f4..5b48522c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -3,6 +3,10 @@ use crate::pb; use std::collections::HashMap; +// TODO: When detector API is updated, consider if fields +// like 'threshold' can be named options instead of the +// use a generic HashMap with Values here +// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 pub type DetectorParams = HashMap; /// User request to orchestrator diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 63f95f92..00fb8b49 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -237,13 +237,15 @@ async fn detect( .map(|(detector_id, detector_params)| { let ctx = ctx.clone(); let detector_id = detector_id.clone(); - let mut detector_params = detector_params.clone(); - if let Some(default_threshold) = ctx.config.get_default_threshold(&detector_id) { - // Use a default threshold if threshold is not provided by the user - detector_params - .entry("threshold".into()) - .or_insert(default_threshold.into()); - } + let detector_params = detector_params.clone(); + // Get the default threshold to use if threshold is not provided by the user + let default_threshold = + ctx.config + .get_default_threshold(&detector_id) + .ok_or_else(|| Error::DetectorNotFound { + detector_id: detector_id.clone(), + })?; + // Get chunker for detector let chunker_id = ctx.config .get_chunker_id(&detector_id) @@ -252,7 +254,8 @@ async fn detect( })?; let chunks = chunks.get(&chunker_id).unwrap().clone(); Ok(tokio::spawn(async move { - handle_detection_task(ctx, detector_id, detector_params, chunks).await + handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks) + .await })) }) .collect::, Error>>()?; @@ -322,6 +325,7 @@ async fn handle_chunk_task( async fn handle_detection_task( ctx: Arc, detector_id: String, + default_threshold: f32, detector_params: DetectorParams, chunks: Vec, ) -> Result, Error> { @@ -332,7 +336,7 @@ async fn handle_detection_task( let detector_params = detector_params.clone(); async move { // NOTE: The detector request is expected to change and not actually - // take parameters. However, any parameters will be ignored for now + // take parameters. Any parameters will be ignored for now // ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone()); debug!( @@ -361,12 +365,11 @@ async fn handle_detection_task( let mut result: TokenClassificationResult = detection.into(); result.start += chunk.offset as u32; result.end += chunk.offset as u32; - let threshold = detector_params.get("threshold").and_then(|v| v.as_f64()); - if threshold.is_some_and(|value| result.score < value) { - None - } else { - Some(result) - } + let threshold = detector_params + .get("threshold") + .and_then(|v| v.as_f64()) + .unwrap_or(default_threshold as f64); + (result.score >= threshold).then_some(result) }) .collect::>(); Ok::, Error>(results) From 41de8f9bf98b400e95a0ed9492241b76e45c6b4f Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 30 May 2024 13:34:22 -0600 Subject: [PATCH 12/12] :recycle: Fetch detector config Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Co-authored-by: declark1 --- src/config.rs | 7 ------- src/orchestrator.rs | 18 ++++++++---------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/config.rs b/src/config.rs index 441318c5..ef64b43e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -137,13 +137,6 @@ impl OrchestratorConfig { .get(detector_id) .map(|detector_config| detector_config.chunker_id.clone()) } - - /// Get default threshold of a particular detector - pub fn get_default_threshold(&self, detector_id: &str) -> Option { - self.detectors - .get(detector_id) - .map(|detector_config| detector_config.default_threshold) - } } fn service_tls_name_to_config( diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 00fb8b49..6328cd5a 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -238,21 +238,19 @@ async fn detect( let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - // Get the default threshold to use if threshold is not provided by the user - let default_threshold = + // Get the detector config + let detector_config = ctx.config - .get_default_threshold(&detector_id) + .detectors + .get(&detector_id) .ok_or_else(|| Error::DetectorNotFound { detector_id: detector_id.clone(), })?; + // Get the default threshold to use if threshold is not provided by the user + let default_threshold = detector_config.default_threshold; // Get chunker for detector - let chunker_id = - ctx.config - .get_chunker_id(&detector_id) - .ok_or_else(|| Error::DetectorNotFound { - detector_id: detector_id.clone(), - })?; - let chunks = chunks.get(&chunker_id).unwrap().clone(); + let chunker_id = detector_config.chunker_id.as_str(); + let chunks = chunks.get(chunker_id).unwrap().clone(); Ok(tokio::spawn(async move { handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks) .await