Skip to content

Commit

Permalink
♻️ Require default_threshold for detectors
Browse files Browse the repository at this point in the history
Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
  • Loading branch information
evaline-ju committed May 30, 2024
1 parent 0655a7d commit 9cce8b0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 32 deletions.
23 changes: 6 additions & 17 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,15 @@ pub struct ChunkerConfig {
pub service: ServiceConfig,
}

/// Configuration parameters applicable to each detector
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DetectorConfigParams {
/// Default threshold with which to filter detector results by score
pub default_threshold: Option<f32>,
}

/// Configuration for each detector
#[derive(Debug, Clone, Deserialize)]
pub struct DetectorConfig {
/// Detector service connection information
pub service: ServiceConfig,
/// ID of chunker that this detector will use
pub chunker_id: String,
/// Optional detector configuration parameters
pub config: Option<DetectorConfigParams>,
/// Default threshold with which to filter detector results by score
pub default_threshold: f32,
}

/// Overall orchestrator server configuration
Expand Down Expand Up @@ -147,12 +140,9 @@ impl OrchestratorConfig {

/// Get default threshold of a particular detector
pub fn get_default_threshold(&self, detector_id: &str) -> Option<f32> {
self.detectors.get(detector_id).and_then(|detector_config| {
detector_config
.config
.as_ref()
.and_then(|config| config.default_threshold)
})
self.detectors
.get(detector_id)
.map(|detector_config| detector_config.default_threshold)
}
}

Expand Down Expand Up @@ -198,8 +188,7 @@ detectors:
hostname: localhost
port: 9000
chunker_id: sentence-en
config:
default_threshold: 0.5
default_threshold: 0.5
tls: {}
"#;
let config: OrchestratorConfig = serde_yml::from_str(s)?;
Expand Down
4 changes: 4 additions & 0 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
use crate::pb;
use std::collections::HashMap;

// TODO: When detector API is updated, consider if fields
// like 'threshold' can be named options instead of the
// use a generic HashMap with Values here
// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
pub type DetectorParams = HashMap<String, serde_json::Value>;

/// User request to orchestrator
Expand Down
33 changes: 18 additions & 15 deletions src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,15 @@ async fn detect(
.map(|(detector_id, detector_params)| {
let ctx = ctx.clone();
let detector_id = detector_id.clone();
let mut detector_params = detector_params.clone();
if let Some(default_threshold) = ctx.config.get_default_threshold(&detector_id) {
// Use a default threshold if threshold is not provided by the user
detector_params
.entry("threshold".into())
.or_insert(default_threshold.into());
}
let detector_params = detector_params.clone();
// Get the default threshold to use if threshold is not provided by the user
let default_threshold =
ctx.config
.get_default_threshold(&detector_id)
.ok_or_else(|| Error::DetectorNotFound {
detector_id: detector_id.clone(),
})?;
// Get chunker for detector
let chunker_id =
ctx.config
.get_chunker_id(&detector_id)
Expand All @@ -252,7 +254,8 @@ async fn detect(
})?;
let chunks = chunks.get(&chunker_id).unwrap().clone();
Ok(tokio::spawn(async move {
handle_detection_task(ctx, detector_id, detector_params, chunks).await
handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks)
.await
}))
})
.collect::<Result<Vec<_>, Error>>()?;
Expand Down Expand Up @@ -322,6 +325,7 @@ async fn handle_chunk_task(
async fn handle_detection_task(
ctx: Arc<Context>,
detector_id: String,
default_threshold: f32,
detector_params: DetectorParams,
chunks: Vec<Chunk>,
) -> Result<Vec<TokenClassificationResult>, Error> {
Expand All @@ -332,7 +336,7 @@ async fn handle_detection_task(
let detector_params = detector_params.clone();
async move {
// NOTE: The detector request is expected to change and not actually
// take parameters. However, any parameters will be ignored for now
// take parameters. Any parameters will be ignored for now
// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone());
debug!(
Expand Down Expand Up @@ -361,12 +365,11 @@ async fn handle_detection_task(
let mut result: TokenClassificationResult = detection.into();
result.start += chunk.offset as u32;
result.end += chunk.offset as u32;
let threshold = detector_params.get("threshold").and_then(|v| v.as_f64());
if threshold.is_some_and(|value| result.score < value) {
None
} else {
Some(result)
}
let threshold = detector_params
.get("threshold")
.and_then(|v| v.as_f64())
.unwrap_or(default_threshold as f64);
(result.score >= threshold).then_some(result)
})
.collect::<Vec<_>>();
Ok::<Vec<TokenClassificationResult>, Error>(results)
Expand Down

0 comments on commit 9cce8b0

Please sign in to comment.