From 9cce8b057dc74ff818f3939e5bbc07dad0c1e7bc Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 30 May 2024 12:33:27 -0600 Subject: [PATCH] :recycle: Require default_threshold for detectors Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/config.rs | 23 ++++++----------------- src/models.rs | 4 ++++ src/orchestrator.rs | 33 ++++++++++++++++++--------------- 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/src/config.rs b/src/config.rs index 6078bbf4..441318c5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -66,13 +66,6 @@ pub struct ChunkerConfig { pub service: ServiceConfig, } -/// Configuration parameters applicable to each detector -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct DetectorConfigParams { - /// Default threshold with which to filter detector results by score - pub default_threshold: Option, -} - /// Configuration for each detector #[derive(Debug, Clone, Deserialize)] pub struct DetectorConfig { @@ -80,8 +73,8 @@ pub struct DetectorConfig { pub service: ServiceConfig, /// ID of chunker that this detector will use pub chunker_id: String, - /// Optional detector configuration parameters - pub config: Option, + /// Default threshold with which to filter detector results by score + pub default_threshold: f32, } /// Overall orchestrator server configuration @@ -147,12 +140,9 @@ impl OrchestratorConfig { /// Get default threshold of a particular detector pub fn get_default_threshold(&self, detector_id: &str) -> Option { - self.detectors.get(detector_id).and_then(|detector_config| { - detector_config - .config - .as_ref() - .and_then(|config| config.default_threshold) - }) + self.detectors + .get(detector_id) + .map(|detector_config| detector_config.default_threshold) } } @@ -198,8 +188,7 @@ detectors: hostname: localhost port: 9000 chunker_id: sentence-en - config: - default_threshold: 0.5 + default_threshold: 0.5 tls: {} "#; let config: OrchestratorConfig = serde_yml::from_str(s)?; diff --git a/src/models.rs b/src/models.rs index f3bd40f4..5b48522c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -3,6 +3,10 @@ use crate::pb; use std::collections::HashMap; +// TODO: When detector API is updated, consider if fields +// like 'threshold' can be named options instead of the +// use a generic HashMap with Values here +// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 pub type DetectorParams = HashMap; /// User request to orchestrator diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 63f95f92..00fb8b49 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -237,13 +237,15 @@ async fn detect( .map(|(detector_id, detector_params)| { let ctx = ctx.clone(); let detector_id = detector_id.clone(); - let mut detector_params = detector_params.clone(); - if let Some(default_threshold) = ctx.config.get_default_threshold(&detector_id) { - // Use a default threshold if threshold is not provided by the user - detector_params - .entry("threshold".into()) - .or_insert(default_threshold.into()); - } + let detector_params = detector_params.clone(); + // Get the default threshold to use if threshold is not provided by the user + let default_threshold = + ctx.config + .get_default_threshold(&detector_id) + .ok_or_else(|| Error::DetectorNotFound { + detector_id: detector_id.clone(), + })?; + // Get chunker for detector let chunker_id = ctx.config .get_chunker_id(&detector_id) @@ -252,7 +254,8 @@ async fn detect( })?; let chunks = chunks.get(&chunker_id).unwrap().clone(); Ok(tokio::spawn(async move { - handle_detection_task(ctx, detector_id, detector_params, chunks).await + handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks) + .await })) }) .collect::, Error>>()?; @@ -322,6 +325,7 @@ async fn handle_chunk_task( async fn handle_detection_task( ctx: Arc, detector_id: String, + default_threshold: f32, detector_params: DetectorParams, chunks: Vec, ) -> Result, Error> { @@ -332,7 +336,7 @@ async fn handle_detection_task( let detector_params = detector_params.clone(); async move { // NOTE: The detector request is expected to change and not actually - // take parameters. However, any parameters will be ignored for now + // take parameters. Any parameters will be ignored for now // ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone()); debug!( @@ -361,12 +365,11 @@ async fn handle_detection_task( let mut result: TokenClassificationResult = detection.into(); result.start += chunk.offset as u32; result.end += chunk.offset as u32; - let threshold = detector_params.get("threshold").and_then(|v| v.as_f64()); - if threshold.is_some_and(|value| result.score < value) { - None - } else { - Some(result) - } + let threshold = detector_params + .get("threshold") + .and_then(|v| v.as_f64()) + .unwrap_or(default_threshold as f64); + (result.score >= threshold).then_some(result) }) .collect::>(); Ok::, Error>(results)