diff --git a/src/config.rs b/src/config.rs index 6078bbf4..d46b4991 100644 --- a/src/config.rs +++ b/src/config.rs @@ -66,13 +66,6 @@ pub struct ChunkerConfig { pub service: ServiceConfig, } -/// Configuration parameters applicable to each detector -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct DetectorConfigParams { - /// Default threshold with which to filter detector results by score - pub default_threshold: Option, -} - /// Configuration for each detector #[derive(Debug, Clone, Deserialize)] pub struct DetectorConfig { @@ -80,8 +73,8 @@ pub struct DetectorConfig { pub service: ServiceConfig, /// ID of chunker that this detector will use pub chunker_id: String, - /// Optional detector configuration parameters - pub config: Option, + /// Default threshold with which to filter detector results by score + pub default_threshold: f32, } /// Overall orchestrator server configuration @@ -147,12 +140,9 @@ impl OrchestratorConfig { /// Get default threshold of a particular detector pub fn get_default_threshold(&self, detector_id: &str) -> Option { - self.detectors.get(detector_id).and_then(|detector_config| { - detector_config - .config - .as_ref() - .and_then(|config| config.default_threshold) - }) + self.detectors + .get(detector_id) + .map(|detector_config| detector_config.default_threshold) } } diff --git a/src/main.rs b/src/main.rs index 25ad39ed..9e98aae0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { - #[clap(default_value = "8033", long, env)] + #[clap(default_value = "8081", long, env)] http_port: u16, #[clap(long, env)] json_output: bool, diff --git a/src/models.rs b/src/models.rs index f3bd40f4..5b48522c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -3,6 +3,10 @@ use crate::pb; use std::collections::HashMap; +// TODO: When detector API is updated, consider if fields +// like 'threshold' can be named options instead of the +// use a generic HashMap with Values here +// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 pub type DetectorParams = HashMap; /// User request to orchestrator diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 63f95f92..a31b4d58 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -237,13 +237,15 @@ async fn detect( .map(|(detector_id, detector_params)| { let ctx = ctx.clone(); let detector_id = detector_id.clone(); - let mut detector_params = detector_params.clone(); - if let Some(default_threshold) = ctx.config.get_default_threshold(&detector_id) { - // Use a default threshold if threshold is not provided by the user - detector_params - .entry("threshold".into()) - .or_insert(default_threshold.into()); - } + let detector_params: HashMap = detector_params.clone(); + // Get the default threshold to use if threshold is not provided by the user + let default_threshold = + ctx.config + .get_default_threshold(&detector_id) + .ok_or_else(|| Error::DetectorNotFound { + detector_id: detector_id.clone(), + })?; + // Get chunker for detector let chunker_id = ctx.config .get_chunker_id(&detector_id) @@ -252,7 +254,8 @@ async fn detect( })?; let chunks = chunks.get(&chunker_id).unwrap().clone(); Ok(tokio::spawn(async move { - handle_detection_task(ctx, detector_id, detector_params, chunks).await + handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks) + .await })) }) .collect::, Error>>()?; @@ -322,6 +325,7 @@ async fn handle_chunk_task( async fn handle_detection_task( ctx: Arc, detector_id: String, + default_threshold: f32, detector_params: DetectorParams, chunks: Vec, ) -> Result, Error> { @@ -332,7 +336,7 @@ async fn handle_detection_task( let detector_params = detector_params.clone(); async move { // NOTE: The detector request is expected to change and not actually - // take parameters. However, any parameters will be ignored for now + // take parameters. Any parameters will be ignored for now // ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37 let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone()); debug!( @@ -361,12 +365,11 @@ async fn handle_detection_task( let mut result: TokenClassificationResult = detection.into(); result.start += chunk.offset as u32; result.end += chunk.offset as u32; - let threshold = detector_params.get("threshold").and_then(|v| v.as_f64()); - if threshold.is_some_and(|value| result.score < value) { - None - } else { - Some(result) - } + let threshold = detector_params + .get("threshold") + .and_then(|v| v.as_f64()) + .unwrap_or(default_threshold as f64); + (result.score >= threshold).then_some(result) }) .collect::>(); Ok::, Error>(results)