Skip to content

Commit

Permalink
✨ Add thresholding for detector results (#52)
Browse files Browse the repository at this point in the history
* 💡🚧 Detector threshold comments

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* ♻️🥅 Refactor validation error

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

Co-authored-by: declark1 <daniel.clark@ibm.com>

* 🏷️📝 Add detector config params and type descriptions

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* 🔧 Update config yaml

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* ✨ Get default threshold from config

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* 🐛🔧 Use default threshold

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* ✨ Pass along default threshold in detector params

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* 🚧 Filtering based on threshold

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* ♻️ Reduce nested ifs

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

Co-authored-by: declark1 <daniel.clark@ibm.com>

* 🔧 Change default threshold under detector

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* ♻️ Require default_threshold for detectors

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* ♻️ Fetch detector config

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

Co-authored-by: declark1 <daniel.clark@ibm.com>

---------

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Co-authored-by: declark1 <daniel.clark@ibm.com>
  • Loading branch information
evaline-ju and declark1 authored May 30, 2024
1 parent 63b5d85 commit d29f6c4
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 15 deletions.
2 changes: 1 addition & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ detectors:
port: 8080
tls: caikit
chunker_id: en_regex
config: {}
default_threshold: 0.5
tls:
caikit:
cert_path: /path/to/tls.crt
Expand Down
28 changes: 26 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,70 +6,93 @@ use std::{
use serde::Deserialize;
use tracing::debug;

/// Connection details the orchestrator needs in order to
/// communicate with a backing service
#[derive(Clone, Debug, Deserialize)]
pub struct ServiceConfig {
    /// Hostname of the service
    pub hostname: String,
    /// Port of the service, when one is configured
    pub port: Option<u16>,
    /// TLS settings for the connection, when configured
    pub tls: Option<Tls>,
}

/// TLS provider
///
/// Deserialized as an untagged enum: a plain string value becomes
/// [`Tls::Name`], while a mapping is read as an inline [`TlsConfig`].
/// Variants are tried in declaration order.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum Tls {
    /// String form; presumably the key of an entry in the
    /// top-level `tls` map (verify against the config loader)
    Name(String),
    /// Inline client TLS configuration
    Config(TlsConfig),
}

/// Client-side TLS configuration; every path is optional
#[derive(Clone, Debug, Deserialize)]
pub struct TlsConfig {
    /// Path to the certificate, if provided
    pub cert_path: Option<PathBuf>,
    /// Path to the key, if provided
    pub key_path: Option<PathBuf>,
    /// Path to the client CA certificate, if provided
    pub client_ca_cert_path: Option<PathBuf>,
}

/// Provider backing the generation service
///
/// Variant names are matched in lowercase when deserialized
/// (`tgis`, `nlp`).
#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum GenerationProvider {
    Tgis,
    Nlp,
}

/// Generation service configuration
#[derive(Clone, Debug, Deserialize)]
pub struct GenerationConfig {
    /// Which provider serves generation requests
    pub provider: GenerationProvider,
    /// Connection information for the generation service
    pub service: ServiceConfig,
}

/// Parser type used by a chunker
///
/// Variant names are matched in lowercase when deserialized
/// (`sentence`, `all`).
#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChunkerType {
    Sentence,
    All,
}

/// Per-chunker configuration
// NOTE(review): `dead_code` is allowed here without a stated reason;
// presumably some field is not read yet — confirm and document.
#[allow(dead_code)]
#[derive(Clone, Debug, Deserialize)]
pub struct ChunkerConfig {
    /// Type of chunker (deserialized from the `type` key)
    pub r#type: ChunkerType,
    /// Connection information for the chunker service
    pub service: ServiceConfig,
}

/// Per-detector configuration
#[derive(Clone, Debug, Deserialize)]
pub struct DetectorConfig {
    /// Connection information for the detector service
    pub service: ServiceConfig,
    /// ID of the chunker whose chunks this detector consumes
    pub chunker_id: String,
    //pub config: HashMap<String, String>,
    /// Score threshold used to filter detector results when the
    /// request does not supply one. Required: there is no serde default.
    pub default_threshold: f32,
}

/// Top-level configuration for the orchestrator server
#[derive(Clone, Debug, Deserialize)]
pub struct OrchestratorConfig {
    /// Generation service and its configuration
    pub generation: GenerationConfig,
    /// Chunker configurations, keyed by chunker ID
    pub chunkers: HashMap<String, ChunkerConfig>,
    /// Detector configurations, keyed by detector ID
    pub detectors: HashMap<String, DetectorConfig>,
    /// Named TLS configurations, so that services requiring the
    /// same TLS information can share a single entry
    pub tls: HashMap<String, TlsConfig>,
}

impl OrchestratorConfig {
/// Load overall orchestrator server configuration
pub async fn load(path: impl AsRef<Path>) -> Self {
let path = path.as_ref();
let s = tokio::fs::read_to_string(path)
Expand Down Expand Up @@ -108,6 +131,7 @@ impl OrchestratorConfig {
todo!()
}

/// Get ID of chunker associated with a particular detector
pub fn get_chunker_id(&self, detector_id: &str) -> Option<String> {
self.detectors
.get(detector_id)
Expand Down Expand Up @@ -157,7 +181,7 @@ detectors:
hostname: localhost
port: 9000
chunker_id: sentence-en
config: {}
default_threshold: 0.5
tls: {}
"#;
let config: OrchestratorConfig = serde_yml::from_str(s)?;
Expand Down
22 changes: 17 additions & 5 deletions src/models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#![allow(unused_qualifications)]

use crate::{pb, server};
use crate::pb;
use std::collections::HashMap;

// TODO: When the detector API is updated, consider whether fields
// like 'threshold' can be named options instead of
// using a generic HashMap with Values here
// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
pub type DetectorParams = HashMap<String, serde_json::Value>;

/// User request to orchestrator
Expand All @@ -28,15 +32,23 @@ pub struct GuardrailsHttpRequest {
pub text_gen_parameters: Option<GuardrailsTextGenerationParameters>,
}

/// Errors produced by upfront validation of a user request
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
    /// A required parameter is missing or empty; the payload is the
    /// parameter name, which is wrapped in backticks in the message
    #[error("`{0}` is required")]
    Required(String),
    /// A parameter failed validation; the payload is the full message
    #[error("{0}")]
    Invalid(String),
}

impl GuardrailsHttpRequest {
/// Upfront validation of user request
pub fn validate(&self) -> Result<(), server::Error> {
pub fn validate(&self) -> Result<(), ValidationError> {
// Validate required parameters
if self.model_id.is_empty() {
return Err(server::Error::Validation("`model_id` is required".into()));
return Err(ValidationError::Required("model_id".into()));
}
if self.inputs.is_empty() {
return Err(server::Error::Validation("`inputs` is required".into()));
return Err(ValidationError::Required("inputs".into()));
}
// Validate masks
let input_range = 0..self.inputs.len();
Expand All @@ -48,7 +60,7 @@ impl GuardrailsHttpRequest {
if !input_masks.iter().all(|(start, end)| {
input_range.contains(start) && input_range.contains(end) && start < end
}) {
return Err(server::Error::Validation("invalid masks".into()));
return Err(ValidationError::Invalid("invalid masks".into()));
}
}
Ok(())
Expand Down
30 changes: 23 additions & 7 deletions src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,22 @@ async fn detect(
let ctx = ctx.clone();
let detector_id = detector_id.clone();
let detector_params = detector_params.clone();
let chunker_id =
// Get the detector config
let detector_config =
ctx.config
.get_chunker_id(&detector_id)
.detectors
.get(&detector_id)
.ok_or_else(|| Error::DetectorNotFound {
detector_id: detector_id.clone(),
})?;
let chunks = chunks.get(&chunker_id).unwrap().clone();
// Get the default threshold to use if threshold is not provided by the user
let default_threshold = detector_config.default_threshold;
// Get chunker for detector
let chunker_id = detector_config.chunker_id.as_str();
let chunks = chunks.get(chunker_id).unwrap().clone();
Ok(tokio::spawn(async move {
handle_detection_task(ctx, detector_id, detector_params, chunks).await
handle_detection_task(ctx, detector_id, default_threshold, detector_params, chunks)
.await
}))
})
.collect::<Result<Vec<_>, Error>>()?;
Expand Down Expand Up @@ -316,6 +323,7 @@ async fn handle_chunk_task(
async fn handle_detection_task(
ctx: Arc<Context>,
detector_id: String,
default_threshold: f32,
detector_params: DetectorParams,
chunks: Vec<Chunk>,
) -> Result<Vec<TokenClassificationResult>, Error> {
Expand All @@ -325,7 +333,10 @@ async fn handle_detection_task(
let detector_id = detector_id.clone();
let detector_params = detector_params.clone();
async move {
let request = DetectorRequest::new(chunk.text.clone(), detector_params);
// NOTE: The detector request is expected to change and not actually
// take parameters. Any parameters will be ignored for now
// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone());
debug!(
%detector_id,
?request,
Expand All @@ -344,14 +355,19 @@ async fn handle_detection_task(
?response,
"received detector response"
);
// Filter results based on threshold (if applicable) here
let results = response
.detections
.into_iter()
.map(|detection| {
.filter_map(|detection| {
let mut result: TokenClassificationResult = detection.into();
result.start += chunk.offset as u32;
result.end += chunk.offset as u32;
result
let threshold = detector_params
.get("threshold")
.and_then(|v| v.as_f64())
.unwrap_or(default_threshold as f64);
(result.score >= threshold).then_some(result)
})
.collect::<Vec<_>>();
Ok::<Vec<TokenClassificationResult>, Error>(results)
Expand Down
6 changes: 6 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,9 @@ impl IntoResponse for Error {
(code, Json(error)).into_response()
}
}

impl From<models::ValidationError> for Error {
fn from(value: models::ValidationError) -> Self {
Self::Validation(value.to_string())
}
}

0 comments on commit d29f6c4

Please sign in to comment.