From 00ee9a98c337c69ee952dbee9789c409d41278be Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 16 May 2024 16:01:30 -0700 Subject: [PATCH 01/15] :heavy_plus_sign: Switch validator for garde Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index a19013ed..45a52113 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ anyhow = "1.0.83" axum = { version = "0.7.5", features = ["json"] } clap = { version = "4.5.3", features = ["derive", "env"] } futures = "0.3.30" +garde = { version = "0.18.0", features = ["full"] } ginepro = "0.7.2" mio = "0.8.11" prost = "0.12.3" From 0a0cb6c7acda20c6bde0443f3168ed43874cc0c6 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 16 May 2024 16:06:16 -0700 Subject: [PATCH 02/15] :safety_vest: Add initial input request validation with garde Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 72 ++++++++++++++++++++++++++++++++++----------------- src/server.rs | 5 ++++ 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/src/models.rs b/src/models.rs index 3a8d6131..d0d26ce0 100644 --- a/src/models.rs +++ b/src/models.rs @@ -7,42 +7,48 @@ use crate::pb; pub type DetectorParams = HashMap; /// User request to orchestrator -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, garde::Validate)] pub struct GuardrailsHttpRequest { /// Text generation model ID #[serde(rename = "model_id")] + #[garde(length(min = 1))] pub model_id: String, /// User prompt/input text to a text generation model #[serde(rename = "inputs")] + #[garde(length(min = 1))] pub inputs: String, /// Configuration of guardrails models for either or both input to a text generation model /// (e.g. user prompt) and output of a text generation model #[serde(rename = "guardrail_config")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(dive)] pub guardrail_config: Option, /// Parameters for text generation #[serde(rename = "text_gen_parameters")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(dive)] pub text_gen_parameters: Option, } /// Configuration of guardrails models for either or both input to a text generation model /// (e.g. user prompt) and output of a text generation model #[derive( - Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate, + Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate, )] pub struct GuardrailsConfig { /// Configuration for detection on input to a text generation model (e.g. user prompt) #[serde(rename = "input")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(dive)] pub input: Option, /// Configuration for detection on output of a text generation model #[serde(rename = "output")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(dive)] pub output: Option, } @@ -63,52 +69,58 @@ impl GuardrailsConfig { } /// Configuration for detection on input to a text generation model (e.g. user prompt) -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] pub struct GuardrailsConfigInput { /// Map of model name to model specific parameters #[serde(rename = "models")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(required)] // input field must have `models` pub models: Option>, - //Option>>, /// Vector of spans are in the form of (span_start, span_end) corresponding /// to spans of input text on which to run input detection #[serde(rename = "masks")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] // TODO: custom pub masks: Option>, } /// Configuration for detection on output of a text generation model -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] pub struct GuardrailsConfigOutput { /// Map of model name to model specific parameters #[serde(rename = "models")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(required)] // output field must have `models` pub models: Option>, - //Option>>, } /// Parameters for text generation, ref. -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] pub struct GuardrailsTextGenerationParameters { + // Leave most validation of parameters to downstream text generation servers /// Maximum number of new tokens to generate #[serde(rename = "max_new_tokens")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(range(min = 0))] pub max_new_tokens: Option, /// Minimum number of new tokens to generate #[serde(rename = "min_new_tokens")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(range(min = 0))] pub min_new_tokens: Option, /// Truncate to this many input tokens for generation #[serde(rename = "truncate_input_tokens")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(range(min = 0))] pub truncate_input_tokens: Option, /// The high level decoding strategy for picking /// tokens during text generation #[serde(rename = "decoding_method")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub decoding_method: Option, /// Number of highest probability vocabulary tokens to keep for top-k-filtering. @@ -116,6 +128,7 @@ pub struct GuardrailsTextGenerationParameters { /// only the top_k most likely tokens are considered as candidates for the next generated token. #[serde(rename = "top_k")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub top_k: Option, /// Similar to top_k except the candidates to generate the next token are the @@ -123,6 +136,7 @@ pub struct GuardrailsTextGenerationParameters { /// Also known as nucleus sampling. A value of 1.0 is equivalent to disabled. #[serde(rename = "top_p")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub top_p: Option, /// Local typicality measures how similar the conditional probability of @@ -131,6 +145,7 @@ pub struct GuardrailsTextGenerationParameters { /// already generated #[serde(rename = "typical_p")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub typical_p: Option, /// A value used to modify the next-token probabilities in sampling mode. @@ -139,85 +154,96 @@ pub struct GuardrailsTextGenerationParameters { /// resulting in "more random" output. A value of 1.0 has no effect. #[serde(rename = "temperature")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(range(min = 0.0))] pub temperature: Option, /// Represents the penalty for penalizing tokens that have already been generated /// or belong to the context. The value 1.0 means that there is no penalty. #[serde(rename = "repetition_penalty")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub repetition_penalty: Option, /// Time limit in milliseconds for text generation to complete #[serde(rename = "max_time")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub max_time: Option, /// Parameters to exponentially increase the likelihood of the text generation /// terminating once a specified number of tokens have been generated. #[serde(rename = "exponential_decay_length_penalty")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub exponential_decay_length_penalty: Option, /// One or more strings which will cause the text generation to stop if/when /// they are produced as part of the output. #[serde(rename = "stop_sequences")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub stop_sequences: Option>, /// Random seed used for text generation #[serde(rename = "seed")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub seed: Option, /// Whether or not to include input text #[serde(rename = "preserve_input_text")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub preserve_input_text: Option, /// Whether or not to include input text #[serde(rename = "input_tokens")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub input_tokens: Option, /// Whether or not to include list of individual generated tokens #[serde(rename = "generated_tokens")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub generated_tokens: Option, /// Whether or not to include logprob for each returned token /// Applicable only if generated_tokens == true and/or input_tokens == true #[serde(rename = "token_logprobs")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub token_logprobs: Option, /// Whether or not to include rank of each returned token /// Applicable only if generated_tokens == true and/or input_tokens == true #[serde(rename = "token_ranks")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub token_ranks: Option, } /// Parameters to exponentially increase the likelihood of the text generation /// terminating once a specified number of tokens have been generated. -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] pub struct ExponentialDecayLengthPenalty { /// Start the decay after this number of tokens have been generated #[serde(rename = "start_index")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(range(min = 0))] pub start_index: Option, /// Factor of exponential decay #[serde(rename = "decay_factor")] #[serde(skip_serializing_if = "Option::is_none")] + #[garde(skip)] pub decay_factor: Option, } /// Classification result on text produced by a text generation model, containing /// information from the original text generation output as well as the result of /// classification on the generated text. -#[derive( - Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate, -)] +#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct ClassifiedGeneratedTextResult { /// Generated text @@ -268,7 +294,7 @@ pub struct ClassifiedGeneratedTextResult { /// Streaming classification result on text produced by a text generation model, containing /// information from the original text generation output as well as the result of /// classification on the generated text. Also indicates where in stream is processed. -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct ClassifiedGeneratedTextStreamResult { #[serde(rename = "generated_text")] @@ -326,9 +352,7 @@ pub struct ClassifiedGeneratedTextStreamResult { /// Results of classification on input to a text generation model (e.g. user prompt) /// or output of a text generation model -#[derive( - Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate, -)] +#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct TextGenTokenClassificationResults { /// Classification results on input to a text generation model @@ -348,7 +372,7 @@ pub struct TextGenTokenClassificationResults { /// The field `word` does not necessarily correspond to a single "word", /// and `entity` may not always be applicable beyond "entity" in the NER /// (named entity recognition) sense -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct TokenClassificationResult { /// Beginning/start offset of token #[serde(rename = "start")] @@ -409,7 +433,7 @@ pub enum FinishReason { } /// Warning reason and message on input detection -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct InputWarning { /// Warning reason @@ -439,7 +463,7 @@ pub enum InputWarningReason { } /// Generated token information -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct GeneratedToken { /// Token text @@ -458,7 +482,7 @@ pub struct GeneratedToken { } /// Result of a text generation model -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct GeneratedTextResult { /// Generated text @@ -496,7 +520,7 @@ pub struct GeneratedTextResult { } /// Details on the streaming result of a text generation model -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct TokenStreamDetails { /// Why text generation stopped @@ -520,7 +544,7 @@ pub struct TokenStreamDetails { } /// Streaming result of a text generation model -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct GeneratedTextStreamResult { /// Generated text @@ -548,7 +572,7 @@ pub struct GeneratedTextStreamResult { // should be using these error types /// HTTP validation error -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct HttpValidationError { #[serde(rename = "detail")] @@ -557,7 +581,7 @@ pub struct HttpValidationError { } /// Validation error -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct ValidationError { #[serde(rename = "loc")] @@ -572,7 +596,7 @@ pub struct ValidationError { pub r#type: String, } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] pub struct LocationInner {} diff --git a/src/server.rs b/src/server.rs index 33ce5c90..2d2fcfda 100644 --- a/src/server.rs +++ b/src/server.rs @@ -11,6 +11,7 @@ use axum::{ Json, Router, }; use futures::StreamExt; +use garde::Validate; use tokio::{net::TcpListener, signal}; use tracing::{error, info}; use uuid::Uuid; @@ -82,6 +83,10 @@ async fn classification_with_gen( Json(request): Json, ) -> Result { let request_id = Uuid::new_v4(); + // Upfront request validation + if let Err(e) = request.validate(&()) { + return Err((StatusCode::BAD_REQUEST, Json(e.to_string()))); + }; let task = ClassificationWithGenTask::new(request_id, request); match state .orchestrator From 050140cd3d67a9a9391d6761ef23e2d0cc0e29fb Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Mon, 20 May 2024 15:54:56 -0700 Subject: [PATCH 03/15] :construction: Attempt custom validation Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 48 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/models.rs b/src/models.rs index d0d26ce0..80918c83 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,8 +1,8 @@ #![allow(unused_qualifications)] -use std::collections::HashMap; - use crate::pb; +use std::collections::HashMap; +use std::ops::Range; pub type DetectorParams = HashMap; @@ -68,8 +68,14 @@ impl GuardrailsConfig { } } +#[derive(Debug, Clone)] +pub struct SpanContext { + pub(crate) span_range: Range, +} + /// Configuration for detection on input to a text generation model (e.g. user prompt) #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] +#[garde(context(SpanContext))] pub struct GuardrailsConfigInput { /// Map of model name to model specific parameters #[serde(rename = "models")] @@ -80,10 +86,46 @@ pub struct GuardrailsConfigInput { /// to spans of input text on which to run input detection #[serde(rename = "masks")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] // TODO: custom + #[garde(inner(inner(custom(custom_validate_fn))))] // evaluate span portion pub masks: Option>, } +fn custom_validate_fn(v: &(usize, usize), args: &SpanContext) -> garde::Result { + let span_start = &v.0; + let span_end = &v.1; + let contains_span_start = args + .span_range + .contains(span_start) + .then_some(()) + .ok_or_else(|| garde::Error::new(format!("span start {span_start} is out of range"))); + + let contains_span_end = args + .span_range + .contains(span_end) + .then_some(()) + .ok_or_else(|| garde::Error::new(format!("span end {span_end} is out of range"))); + + match (contains_span_start, contains_span_end) { + (Ok(()), Ok(())) => Ok(()), + (Err(e1), Ok(())) => Err(e1), + (Ok(()), Err(e2)) => Err(e2), + (Err(e1), Err(e2)) => Err(e1), // Show at least one error + _ => Err(garde::Error::new("Error validating spans")), + } +} + +// fn custom_validate_fn(value: &Vec<(usize, usize)>, context: &Context) -> garde::Result { +// // if value.is_none() { +// // return Ok(()); +// // } +// for (span_start, span_end) in value.into_iter() { +// if (*span_start < 0) || (*span_end < 0) { +// return Err(garde::Error::new("span_start {} or span_end {} < 0")); +// } +// } +// Ok(()) +// } + /// Configuration for detection on output of a text generation model #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] pub struct GuardrailsConfigOutput { From 898c5859435aa21be27c8a1f3b1bcbb77d571861 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 21 May 2024 13:40:12 -0700 Subject: [PATCH 04/15] :goal_net::construction: Input masks validation Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 73 ++++++++++++++++++++------------------------------- src/server.rs | 3 +-- 2 files changed, 30 insertions(+), 46 deletions(-) diff --git a/src/models.rs b/src/models.rs index 80918c83..e866154d 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,8 +1,9 @@ #![allow(unused_qualifications)] +use garde::Validate; + use crate::pb; use std::collections::HashMap; -use std::ops::Range; pub type DetectorParams = HashMap; @@ -33,6 +34,32 @@ pub struct GuardrailsHttpRequest { pub text_gen_parameters: Option, } +impl GuardrailsHttpRequest { + /// Upfront validation of user request + // TODO: Change to validation error when present + pub fn upfront_validate(&self) -> Result<(), crate::Error> { + // Invoke garde validation for various fields + if let Err(e) = self.validate(&()) { + return Err(crate::Error::ValidationError(e.to_string())); // TODO: update on presence of validation error + }; + // Validate masks + let input_range = 0..self.inputs.len(); + let input_masks = self + .guardrail_config + .as_ref() + .and_then(|config| config.input.as_ref().and_then(|input| input.masks.as_ref())); + if let Some(input_masks) = input_masks { + if !input_masks.iter().all(|(start, end)| { + input_range.contains(start) && input_range.contains(end) && start < end + }) { + return Err(crate::Error::ValidationError("invalid masks".into())); + // TODO: update on presence of validation error + } + } + Ok(()) + } +} + /// Configuration of guardrails models for either or both input to a text generation model /// (e.g. user prompt) and output of a text generation model #[derive( @@ -68,14 +95,8 @@ impl GuardrailsConfig { } } -#[derive(Debug, Clone)] -pub struct SpanContext { - pub(crate) span_range: Range, -} - /// Configuration for detection on input to a text generation model (e.g. user prompt) #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] -#[garde(context(SpanContext))] pub struct GuardrailsConfigInput { /// Map of model name to model specific parameters #[serde(rename = "models")] @@ -86,46 +107,10 @@ pub struct GuardrailsConfigInput { /// to spans of input text on which to run input detection #[serde(rename = "masks")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(inner(inner(custom(custom_validate_fn))))] // evaluate span portion + #[garde(skip)] // Separate validation will happen pub masks: Option>, } -fn custom_validate_fn(v: &(usize, usize), args: &SpanContext) -> garde::Result { - let span_start = &v.0; - let span_end = &v.1; - let contains_span_start = args - .span_range - .contains(span_start) - .then_some(()) - .ok_or_else(|| garde::Error::new(format!("span start {span_start} is out of range"))); - - let contains_span_end = args - .span_range - .contains(span_end) - .then_some(()) - .ok_or_else(|| garde::Error::new(format!("span end {span_end} is out of range"))); - - match (contains_span_start, contains_span_end) { - (Ok(()), Ok(())) => Ok(()), - (Err(e1), Ok(())) => Err(e1), - (Ok(()), Err(e2)) => Err(e2), - (Err(e1), Err(e2)) => Err(e1), // Show at least one error - _ => Err(garde::Error::new("Error validating spans")), - } -} - -// fn custom_validate_fn(value: &Vec<(usize, usize)>, context: &Context) -> garde::Result { -// // if value.is_none() { -// // return Ok(()); -// // } -// for (span_start, span_end) in value.into_iter() { -// if (*span_start < 0) || (*span_end < 0) { -// return Err(garde::Error::new("span_start {} or span_end {} < 0")); -// } -// } -// Ok(()) -// } - /// Configuration for detection on output of a text generation model #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] pub struct GuardrailsConfigOutput { diff --git a/src/server.rs b/src/server.rs index 2d2fcfda..3d1bebb7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -11,7 +11,6 @@ use axum::{ Json, Router, }; use futures::StreamExt; -use garde::Validate; use tokio::{net::TcpListener, signal}; use tracing::{error, info}; use uuid::Uuid; @@ -84,7 +83,7 @@ async fn classification_with_gen( ) -> Result { let request_id = Uuid::new_v4(); // Upfront request validation - if let Err(e) = request.validate(&()) { + if let Err(e) = request.upfront_validate() { return Err((StatusCode::BAD_REQUEST, Json(e.to_string()))); }; let task = ClassificationWithGenTask::new(request_id, request); From 2a8998c282e6e4fa484f3b4e6d19fbc4bfbf1c8b Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 21 May 2024 13:58:52 -0700 Subject: [PATCH 05/15] :white_check_mark: Initial request validation tests Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/models.rs b/src/models.rs index e866154d..274a7ebc 100644 --- a/src/models.rs +++ b/src/models.rs @@ -742,3 +742,62 @@ impl From } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_upfront_validate() { + // Expected OK case + let request = GuardrailsHttpRequest { + model_id: "model".to_string(), + inputs: "The cow jumped over the moon!".to_string(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + masks: Some(vec![(5, 8)]), + models: Some(HashMap::new()), + }), + output: Some(GuardrailsConfigOutput { + models: Some(HashMap::new()), + }), + }), + text_gen_parameters: None, + }; + assert!(request.upfront_validate().is_ok()); + + // Mask span beyond inputs + let request = GuardrailsHttpRequest { + model_id: "model".to_string(), + inputs: "short".to_string(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + masks: Some(vec![(0, 12)]), + models: Some(HashMap::new()), + }), + output: Some(GuardrailsConfigOutput { + models: Some(HashMap::new()), + }), + }), + text_gen_parameters: None, + }; + assert!(request.upfront_validate().is_err()); + + // Mask span end less than span start + let request = GuardrailsHttpRequest { + model_id: "model".to_string(), + inputs: "This is ignored anyway!".to_string(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + masks: Some(vec![(12, 8)]), + models: Some(HashMap::new()), + }), + output: Some(GuardrailsConfigOutput { + models: Some(HashMap::new()), + }), + }), + text_gen_parameters: None, + }; + assert!(request.upfront_validate().is_err()); + } +} From a128a910506ae7734728d527d2a021dd6d900b87 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 13:29:09 -0600 Subject: [PATCH 06/15] :white_check_mark: Garde validation case Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/models.rs b/src/models.rs index 274a7ebc..11890c19 100644 --- a/src/models.rs +++ b/src/models.rs @@ -766,6 +766,23 @@ mod tests { }; assert!(request.upfront_validate().is_ok()); + // No model ID - garde validation case + let request = GuardrailsHttpRequest { + model_id: "".to_string(), + inputs: "short".to_string(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + masks: Some(vec![]), + models: Some(HashMap::new()), + }), + output: Some(GuardrailsConfigOutput { + models: Some(HashMap::new()), + }), + }), + text_gen_parameters: None, + }; + assert!(request.upfront_validate().is_err()); + // Mask span beyond inputs let request = GuardrailsHttpRequest { model_id: "model".to_string(), From 494ad459b2eba6da2ab4b5473cd9498d1664017a Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 13:56:28 -0600 Subject: [PATCH 07/15] :wrench: Pass along input text preservation Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models.rs b/src/models.rs index 11890c19..a53e26b3 100644 --- a/src/models.rs +++ b/src/models.rs @@ -656,7 +656,7 @@ impl From for pb::fmaas::Parameters { include_stop_sequence: None, }; let response = pb::fmaas::ResponseOptions { - input_text: false, // missing? + input_text: value.preserve_input_text.unwrap_or_default(), generated_tokens: value.generated_tokens.unwrap_or_default(), input_tokens: value.input_tokens.unwrap_or_default(), token_logprobs: value.token_logprobs.unwrap_or_default(), From 37faea0f948928abb64fd37221bca8230be4c21a Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 14:44:29 -0600 Subject: [PATCH 08/15] :white_check_mark: Error string checking Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/models.rs b/src/models.rs index a53e26b3..a7d4c6aa 100644 --- a/src/models.rs +++ b/src/models.rs @@ -781,7 +781,30 @@ mod tests { }), text_gen_parameters: None, }; - assert!(request.upfront_validate().is_err()); + let result = request.upfront_validate(); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("model_id: length is lower than 1")); + + // No models on input - garde validation case + let request = GuardrailsHttpRequest { + model_id: "model".to_string(), + inputs: "short".to_string(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + masks: Some(vec![]), + models: None, + }), + output: Some(GuardrailsConfigOutput { + models: Some(HashMap::new()), + }), + }), + text_gen_parameters: None, + }; + let result = request.upfront_validate(); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("guardrail_config.input.models: not set")); // Mask span beyond inputs let request = GuardrailsHttpRequest { @@ -798,7 +821,10 @@ mod tests { }), text_gen_parameters: None, }; - assert!(request.upfront_validate().is_err()); + let result = request.upfront_validate(); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("invalid masks")); // Mask span end less than span start let request = GuardrailsHttpRequest { @@ -815,6 +841,9 @@ mod tests { }), text_gen_parameters: None, }; - assert!(request.upfront_validate().is_err()); + let result = request.upfront_validate(); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("invalid masks")); } } From d2bbba733f757a391983414c89d693297fe13085 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 15:08:26 -0600 Subject: [PATCH 09/15] :goal_net: Use error handling objects Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 43 ++++--------------------------------------- src/server.rs | 2 +- 2 files changed, 5 insertions(+), 40 deletions(-) diff --git a/src/models.rs b/src/models.rs index a7d4c6aa..12c686ac 100644 --- a/src/models.rs +++ b/src/models.rs @@ -2,7 +2,7 @@ use garde::Validate; -use crate::pb; +use crate::{pb, server}; use std::collections::HashMap; pub type DetectorParams = HashMap; @@ -36,11 +36,10 @@ pub struct GuardrailsHttpRequest { impl GuardrailsHttpRequest { /// Upfront validation of user request - // TODO: Change to validation error when present - pub fn upfront_validate(&self) -> Result<(), crate::Error> { + pub fn upfront_validate(&self) -> Result<(), server::Error> { // Invoke garde validation for various fields if let Err(e) = self.validate(&()) { - return Err(crate::Error::ValidationError(e.to_string())); // TODO: update on presence of validation error + return Err(server::Error::Validation(e.to_string())); }; // Validate masks let input_range = 0..self.inputs.len(); @@ -52,8 +51,7 @@ impl GuardrailsHttpRequest { if !input_masks.iter().all(|(start, end)| { input_range.contains(start) && input_range.contains(end) && start < end }) { - return Err(crate::Error::ValidationError("invalid masks".into())); - // TODO: update on presence of validation error + return Err(server::Error::Validation("invalid masks".into())); } } Ok(()) @@ -594,39 +592,6 @@ pub struct GeneratedTextStreamResult { pub input_tokens: Option>, } -// TODO: The below errors follow FastAPI concepts esp. for loc -// It may be worth revisiting if the orchestrator without FastAPI -// should be using these error types - -/// HTTP validation error -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] -pub struct HttpValidationError { - #[serde(rename = "detail")] - #[serde(skip_serializing_if = "Option::is_none")] - pub detail: Option>, -} - -/// Validation error -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] -pub struct ValidationError { - #[serde(rename = "loc")] - pub loc: Vec, - - /// Error message - #[serde(rename = "msg")] - pub msg: String, - - /// Error type - #[serde(rename = "type")] - pub r#type: String, -} - -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))] -pub struct LocationInner {} - impl From for pb::fmaas::decoding_parameters::LengthPenalty { fn from(value: ExponentialDecayLengthPenalty) -> Self { Self { diff --git a/src/server.rs b/src/server.rs index 3d1bebb7..d9df747f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -84,7 +84,7 @@ async fn classification_with_gen( let request_id = Uuid::new_v4(); // Upfront request validation if let Err(e) = request.upfront_validate() { - return Err((StatusCode::BAD_REQUEST, Json(e.to_string()))); + return Err(e.into()); }; let task = ClassificationWithGenTask::new(request_id, request); match state From cf98f7450f241959921015c66319dbf6d85344ee Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 15:10:47 -0600 Subject: [PATCH 10/15] :memo::safety_vest: Update errors in API Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- docs/api/orchestrator_openapi_0_1_0.yaml | 56 ++++++++++-------------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml index 3f2449bc..06dcf269 100644 --- a/docs/api/orchestrator_openapi_0_1_0.yaml +++ b/docs/api/orchestrator_openapi_0_1_0.yaml @@ -23,12 +23,18 @@ paths: application/json: schema: $ref: '#/components/schemas/ClassifiedGeneratedTextResult' + '404': + description: Resource Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/Error' '422': description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/HTTPValidationError' + $ref: '#/components/schemas/Error' /api/v1/task/server-streaming-classification-with-text-generation: post: tags: @@ -49,12 +55,18 @@ paths: application/json: schema: $ref: '#/components/schemas/ClassifiedGeneratedTextStreamResult' + '404': + description: Resource Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/Error' '422': description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/HTTPValidationError' + $ref: '#/components/schemas/Error' components: schemas: ClassifiedGeneratedTextResult: @@ -156,6 +168,16 @@ components: required: ["input_token_count", "token_classification_results", "start_index"] type: object title: ClassifiedGeneratedTextStreamResult + Error: + type: object + properties: + code: + type: string + details: + type: string + required: + - code + - details ExponentialDecayLengthPenalty: properties: start_index: @@ -299,15 +321,6 @@ components: additionalProperties: false type: object title: GuardrailsTextGenerationParameters - HTTPValidationError: - properties: - detail: - items: - $ref: '#/components/schemas/ValidationError' - type: array - title: Detail - type: object - title: HTTPValidationError InputWarning: properties: id: @@ -368,24 +381,3 @@ components: required: ["start", "end", "word", "entity", "entity_group", "score"] type: object title: TokenClassificationResult - ValidationError: - properties: - loc: - items: - anyOf: - - type: string - - type: integer - type: array - title: Location - msg: - type: string - title: Message - type: - type: string - title: Error Type - type: object - required: - - loc - - msg - - type - title: ValidationError From a55478b2d89484b5c9e307577cb11ef58ea804ae Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 15:14:13 -0600 Subject: [PATCH 11/15] :pushpin: Switch API validation dependencies Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- Cargo.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 45a52113..89597bce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ anyhow = "1.0.83" axum = { version = "0.7.5", features = ["json"] } clap = { version = "4.5.3", features = ["derive", "env"] } futures = "0.3.30" -garde = { version = "0.18.0", features = ["full"] } +garde = { version = "0.18.0", features = ["full"] } # For API validation ginepro = "0.7.2" mio = "0.8.11" prost = "0.12.3" @@ -35,7 +35,6 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } url = "2.5.0" uuid = { version = "1.8.0", features = ["v4", "fast-rng"] } -validator = { version = "0.18.1", features = ["derive"] } # For API validation [build-dependencies] tonic-build = "0.11.0" From 2801725e5f7274b499775c52f5140486c489ffeb Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 15:36:32 -0600 Subject: [PATCH 12/15] :art: Lint Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/server.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/server.rs b/src/server.rs index d9df747f..b32453c9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -83,9 +83,7 @@ async fn classification_with_gen( ) -> Result { let request_id = Uuid::new_v4(); // Upfront request validation - if let Err(e) = request.upfront_validate() { - return Err(e.into()); - }; + request.upfront_validate()?; let task = ClassificationWithGenTask::new(request_id, request); match state .orchestrator From 45881d466b996b3dbde14c1c0f179d2b38f3a401 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 17:14:01 -0600 Subject: [PATCH 13/15] :recycle: Validation without garde Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- Cargo.toml | 1 - src/clients/detector.rs | 4 +- src/models.rs | 140 ++++++++++++++++------------------------ src/orchestrator.rs | 18 +++--- src/server.rs | 2 +- 5 files changed, 67 insertions(+), 98 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 89597bce..56321242 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,6 @@ anyhow = "1.0.83" axum = { version = "0.7.5", features = ["json"] } clap = { version = "4.5.3", features = ["derive", "env"] } futures = "0.3.30" -garde = { version = "0.18.0", features = ["full"] } # For API validation ginepro = "0.7.2" mio = "0.8.11" prost = "0.12.3" diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 6d684e70..183f6c3c 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -82,8 +82,8 @@ pub struct DetectorResponse { impl From for crate::models::TokenClassificationResult { fn from(value: Detection) -> Self { Self { - start: value.start as i32, - end: value.end as i32, + start: value.start as u32, + end: value.end as u32, word: value.text, entity: value.detection, entity_group: value.detection_type, diff --git a/src/models.rs b/src/models.rs index 12c686ac..fe57e552 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,46 +1,43 @@ #![allow(unused_qualifications)] -use garde::Validate; - use crate::{pb, server}; use std::collections::HashMap; pub type DetectorParams = HashMap; /// User request to orchestrator -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, garde::Validate)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct GuardrailsHttpRequest { /// Text generation model ID #[serde(rename = "model_id")] - #[garde(length(min = 1))] pub model_id: String, /// User prompt/input text to a text generation model #[serde(rename = "inputs")] - #[garde(length(min = 1))] pub inputs: String, /// Configuration of guardrails models for either or both input to a text generation model /// (e.g. user prompt) and output of a text generation model #[serde(rename = "guardrail_config")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(dive)] pub guardrail_config: Option, /// Parameters for text generation #[serde(rename = "text_gen_parameters")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(dive)] pub text_gen_parameters: Option, } impl GuardrailsHttpRequest { /// Upfront validation of user request - pub fn upfront_validate(&self) -> Result<(), server::Error> { - // Invoke garde validation for various fields - if let Err(e) = self.validate(&()) { - return Err(server::Error::Validation(e.to_string())); - }; + pub fn validate(&self) -> Result<(), server::Error> { + // Validate required parameters + if self.model_id.is_empty() { + return Err(server::Error::Validation("`model_id` is required".into())); + } + if self.inputs.is_empty() { + return Err(server::Error::Validation("`inputs` is required".into())); + } // Validate masks let input_range = 0..self.inputs.len(); let input_masks = self @@ -60,20 +57,16 @@ impl GuardrailsHttpRequest { /// Configuration of guardrails models for either or both input to a text generation model /// (e.g. user prompt) and output of a text generation model -#[derive( - Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate, -)] +#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct GuardrailsConfig { /// Configuration for detection on input to a text generation model (e.g. user prompt) #[serde(rename = "input")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(dive)] pub input: Option, /// Configuration for detection on output of a text generation model #[serde(rename = "output")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(dive)] pub output: Option, } @@ -83,69 +76,64 @@ impl GuardrailsConfig { } pub fn input_detectors(&self) -> Option<&HashMap> { - self.input.as_ref().and_then(|input| input.models.as_ref()) + self.input + .as_ref() + .and_then(|input| Some(input.models)) + .as_ref() } pub fn output_detectors(&self) -> Option<&HashMap> { self.output .as_ref() - .and_then(|output| output.models.as_ref()) + .and_then(|output| Some(output.models)) + .as_ref() } } /// Configuration for detection on input to a text generation model (e.g. user prompt) -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct GuardrailsConfigInput { /// Map of model name to model specific parameters #[serde(rename = "models")] - #[serde(skip_serializing_if = "Option::is_none")] - #[garde(required)] // input field must have `models` - pub models: Option>, + pub models: HashMap, /// Vector of spans are in the form of (span_start, span_end) corresponding /// to spans of input text on which to run input detection #[serde(rename = "masks")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] // Separate validation will happen pub masks: Option>, } /// Configuration for detection on output of a text generation model -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct GuardrailsConfigOutput { /// Map of model name to model specific parameters #[serde(rename = "models")] - #[serde(skip_serializing_if = "Option::is_none")] - #[garde(required)] // output field must have `models` - pub models: Option>, + pub models: HashMap, } /// Parameters for text generation, ref. -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct GuardrailsTextGenerationParameters { // Leave most validation of parameters to downstream text generation servers /// Maximum number of new tokens to generate #[serde(rename = "max_new_tokens")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(range(min = 0))] - pub max_new_tokens: Option, + pub max_new_tokens: Option, /// Minimum number of new tokens to generate #[serde(rename = "min_new_tokens")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(range(min = 0))] - pub min_new_tokens: Option, + pub min_new_tokens: Option, /// Truncate to this many input tokens for generation #[serde(rename = "truncate_input_tokens")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(range(min = 0))] - pub truncate_input_tokens: Option, + pub truncate_input_tokens: Option, /// The high level decoding strategy for picking /// tokens during text generation #[serde(rename = "decoding_method")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub decoding_method: Option, /// Number of highest probability vocabulary tokens to keep for top-k-filtering. @@ -153,15 +141,13 @@ pub struct GuardrailsTextGenerationParameters { /// only the top_k most likely tokens are considered as candidates for the next generated token. #[serde(rename = "top_k")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] - pub top_k: Option, + pub top_k: Option, /// Similar to top_k except the candidates to generate the next token are the /// most likely tokens with probabilities that add up to at least top_p. /// Also known as nucleus sampling. A value of 1.0 is equivalent to disabled. #[serde(rename = "top_p")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub top_p: Option, /// Local typicality measures how similar the conditional probability of @@ -170,7 +156,6 @@ pub struct GuardrailsTextGenerationParameters { /// already generated #[serde(rename = "typical_p")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub typical_p: Option, /// A value used to modify the next-token probabilities in sampling mode. @@ -179,89 +164,76 @@ pub struct GuardrailsTextGenerationParameters { /// resulting in "more random" output. A value of 1.0 has no effect. #[serde(rename = "temperature")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(range(min = 0.0))] pub temperature: Option, /// Represents the penalty for penalizing tokens that have already been generated /// or belong to the context. The value 1.0 means that there is no penalty. #[serde(rename = "repetition_penalty")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub repetition_penalty: Option, /// Time limit in milliseconds for text generation to complete #[serde(rename = "max_time")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub max_time: Option, /// Parameters to exponentially increase the likelihood of the text generation /// terminating once a specified number of tokens have been generated. #[serde(rename = "exponential_decay_length_penalty")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub exponential_decay_length_penalty: Option, /// One or more strings which will cause the text generation to stop if/when /// they are produced as part of the output. #[serde(rename = "stop_sequences")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub stop_sequences: Option>, /// Random seed used for text generation #[serde(rename = "seed")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] - pub seed: Option, + pub seed: Option, /// Whether or not to include input text #[serde(rename = "preserve_input_text")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub preserve_input_text: Option, /// Whether or not to include input text #[serde(rename = "input_tokens")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub input_tokens: Option, /// Whether or not to include list of individual generated tokens #[serde(rename = "generated_tokens")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub generated_tokens: Option, /// Whether or not to include logprob for each returned token /// Applicable only if generated_tokens == true and/or input_tokens == true #[serde(rename = "token_logprobs")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub token_logprobs: Option, /// Whether or not to include rank of each returned token /// Applicable only if generated_tokens == true and/or input_tokens == true #[serde(rename = "token_ranks")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub token_ranks: Option, } /// Parameters to exponentially increase the likelihood of the text generation /// terminating once a specified number of tokens have been generated. -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, garde::Validate)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct ExponentialDecayLengthPenalty { /// Start the decay after this number of tokens have been generated #[serde(rename = "start_index")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(range(min = 0))] - pub start_index: Option, + pub start_index: Option, /// Factor of exponential decay #[serde(rename = "decay_factor")] #[serde(skip_serializing_if = "Option::is_none")] - #[garde(skip)] pub decay_factor: Option, } @@ -289,16 +261,16 @@ pub struct ClassifiedGeneratedTextResult { /// Length of sequence of generated tokens #[serde(rename = "generated_token_count")] #[serde(skip_serializing_if = "Option::is_none")] - pub generated_token_count: Option, + pub generated_token_count: Option, /// Random seed used for text generation #[serde(rename = "seed")] #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, + pub seed: Option, /// Length of input #[serde(rename = "input_token_count")] - pub input_token_count: i32, + pub input_token_count: u32, /// Vector of warnings on input detection #[serde(rename = "warnings")] @@ -339,16 +311,16 @@ pub struct ClassifiedGeneratedTextStreamResult { /// Length of sequence of generated tokens #[serde(rename = "generated_token_count")] #[serde(skip_serializing_if = "Option::is_none")] - pub generated_token_count: Option, + pub generated_token_count: Option, /// Random seed used for text generation #[serde(rename = "seed")] #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, + pub seed: Option, /// Length of input #[serde(rename = "input_token_count")] - pub input_token_count: i32, + pub input_token_count: u32, /// Vector of warnings on input detection #[serde(rename = "warnings")] @@ -368,11 +340,11 @@ pub struct ClassifiedGeneratedTextStreamResult { /// Result index up to which text is processed #[serde(rename = "processed_index")] #[serde(skip_serializing_if = "Option::is_none")] - pub processed_index: Option, + pub processed_index: Option, /// Result start index for processed text #[serde(rename = "start_index")] - pub start_index: i32, + pub start_index: u32, } /// Results of classification on input to a text generation model (e.g. user prompt) @@ -401,11 +373,11 @@ pub struct TextGenTokenClassificationResults { pub struct TokenClassificationResult { /// Beginning/start offset of token #[serde(rename = "start")] - pub start: i32, + pub start: u32, /// End offset of token #[serde(rename = "end")] - pub end: i32, + pub end: u32, /// Text referenced by token #[serde(rename = "word")] @@ -426,7 +398,7 @@ pub struct TokenClassificationResult { /// Length of tokens in the text #[serde(rename = "token_count")] #[serde(skip_serializing_if = "Option::is_none")] - pub token_count: Option, + pub token_count: Option, } /// Enumeration of reasons why text generation stopped @@ -503,7 +475,7 @@ pub struct GeneratedToken { /// One-based rank relative to other tokens #[serde(rename = "rank")] #[serde(skip_serializing_if = "Option::is_none")] - pub rank: Option, + pub rank: Option, } /// Result of a text generation model @@ -526,12 +498,12 @@ pub struct GeneratedTextResult { /// Length of input #[serde(rename = "input_token_count")] - pub input_token_count: i32, + pub input_token_count: u32, /// Random seed used for text generation #[serde(rename = "seed")] #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, + pub seed: Option, /// Individual generated tokens and associated details, if requested #[serde(rename = "tokens")] @@ -556,16 +528,16 @@ pub struct TokenStreamDetails { /// Length of sequence of generated tokens #[serde(rename = "generated_tokens")] #[serde(skip_serializing_if = "Option::is_none")] - pub generated_tokens: Option, + pub generated_tokens: Option, /// Random seed used for text generation #[serde(rename = "seed")] #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, + pub seed: Option, /// Length of input #[serde(rename = "input_token_count")] - pub input_token_count: i32, + pub input_token_count: u32, } /// Streaming result of a text generation model @@ -595,7 +567,7 @@ pub struct GeneratedTextStreamResult { impl From for pb::fmaas::decoding_parameters::LengthPenalty { fn from(value: ExponentialDecayLengthPenalty) -> Self { Self { - start_index: value.start_index.unwrap_or_default() as u32, + start_index: value.start_index.unwrap_or_default(), decay_factor: value.decay_factor.unwrap_or_default() as f32, } } @@ -608,14 +580,14 @@ impl From for pb::fmaas::Parameters { let method = pb::fmaas::DecodingMethod::from_str_name(&decoding_method).unwrap_or_default(); let sampling = pb::fmaas::SamplingParameters { temperature: value.temperature.unwrap_or_default() as f32, - top_k: value.top_k.unwrap_or_default() as u32, + top_k: value.top_k.unwrap_or_default(), top_p: value.top_p.unwrap_or_default() as f32, typical_p: value.typical_p.unwrap_or_default() as f32, seed: value.seed.map(|v| v as u64), }; let stopping = pb::fmaas::StoppingCriteria { - max_new_tokens: value.max_new_tokens.unwrap_or_default() as u32, - min_new_tokens: value.min_new_tokens.unwrap_or_default() as u32, + max_new_tokens: value.max_new_tokens.unwrap_or_default(), + min_new_tokens: value.min_new_tokens.unwrap_or_default(), time_limit_millis: value.max_time.unwrap_or_default() as u32, stop_sequences: value.stop_sequences.unwrap_or_default(), include_stop_sequence: None, @@ -632,7 +604,7 @@ impl From for pb::fmaas::Parameters { repetition_penalty: value.repetition_penalty.unwrap_or_default() as f32, length_penalty: value.exponential_decay_length_penalty.map(Into::into), }; - let truncate_input_tokens = value.truncate_input_tokens.unwrap_or_default() as u32; + let truncate_input_tokens = value.truncate_input_tokens.unwrap_or_default(); Self { method: method as i32, sampling: Some(sampling), @@ -666,7 +638,7 @@ impl From for GeneratedToken { Self { text: value.text, logprob: Some(value.logprob as f64), - rank: Some(value.rank as i32), + rank: Some(value.rank as u32), } } } @@ -676,7 +648,7 @@ impl From for GeneratedToken { Self { text: value.text, logprob: Some(value.logprob), - rank: Some(value.rank as i32), + rank: Some(value.rank as u32), } } } @@ -731,7 +703,7 @@ mod tests { }; assert!(request.upfront_validate().is_ok()); - // No model ID - garde validation case + // No model ID let request = GuardrailsHttpRequest { model_id: "".to_string(), inputs: "short".to_string(), @@ -749,9 +721,9 @@ mod tests { let result = request.upfront_validate(); assert!(result.is_err()); let error = result.unwrap_err().to_string(); - assert!(error.contains("model_id: length is lower than 1")); + assert!(error.contains("`model_id` is required")); - // No models on input - garde validation case + // No models on input let request = GuardrailsHttpRequest { model_id: "model".to_string(), inputs: "short".to_string(), @@ -768,8 +740,6 @@ mod tests { }; let result = request.upfront_validate(); assert!(result.is_err()); - let error = result.unwrap_err().to_string(); - assert!(error.contains("guardrail_config.input.models: not set")); // Mask span beyond inputs let request = GuardrailsHttpRequest { diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 4f6ec473..6dcb627e 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -96,7 +96,7 @@ impl Orchestrator { tokenize(ctx.clone(), task.model_id.clone(), task.inputs.clone()).await?; // Send result with input detections Ok(ClassifiedGeneratedTextResult { - input_token_count: input_token_count as i32, + input_token_count: input_token_count as u32, token_classification_results: TextGenTokenClassificationResults { input: input_detections, output: None, @@ -348,8 +348,8 @@ async fn handle_detection_task( .into_iter() .map(|detection| { let mut result: TokenClassificationResult = detection.into(); - result.start += chunk.offset as i32; - result.end += chunk.offset as i32; + result.start += chunk.offset as u32; + result.end += chunk.offset as u32; result }) .collect::>(); @@ -478,9 +478,9 @@ async fn generate( Ok(ClassifiedGeneratedTextResult { generated_text: Some(response.text.clone()), finish_reason: Some(response.stop_reason().into()), - generated_token_count: Some(response.generated_token_count as i32), - seed: Some(response.seed as i32), - input_token_count: response.input_token_count as i32, + generated_token_count: Some(response.generated_token_count as u32), + seed: Some(response.seed as u32), + input_token_count: response.input_token_count as u32, warnings: None, tokens: if response.tokens.is_empty() { None @@ -551,9 +551,9 @@ async fn generate( Ok(ClassifiedGeneratedTextResult { generated_text: Some(response.generated_text.clone()), finish_reason: Some(response.finish_reason().into()), - generated_token_count: Some(response.generated_tokens as i32), - seed: Some(response.seed as i32), - input_token_count: response.input_token_count as i32, + generated_token_count: Some(response.generated_tokens as u32), + seed: Some(response.seed as u32), + input_token_count: response.input_token_count as u32, warnings: None, tokens: if response.tokens.is_empty() { None diff --git a/src/server.rs b/src/server.rs index b32453c9..61468815 100644 --- a/src/server.rs +++ b/src/server.rs @@ -83,7 +83,7 @@ async fn classification_with_gen( ) -> Result { let request_id = Uuid::new_v4(); // Upfront request validation - request.upfront_validate()?; + request.validate()?; let task = ClassificationWithGenTask::new(request_id, request); match state .orchestrator From 5f5145f0d1cd5dbef98204aa325dddfa49a6a766 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 17:42:37 -0600 Subject: [PATCH 14/15] :white_check_mark::label: Update tests and types Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/models.rs | 52 +++++++++++++++++++++------------------------ src/orchestrator.rs | 2 +- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/models.rs b/src/models.rs index fe57e552..276cb143 100644 --- a/src/models.rs +++ b/src/models.rs @@ -76,17 +76,11 @@ impl GuardrailsConfig { } pub fn input_detectors(&self) -> Option<&HashMap> { - self.input - .as_ref() - .and_then(|input| Some(input.models)) - .as_ref() + self.input.as_ref().map(|input| &input.models) } pub fn output_detectors(&self) -> Option<&HashMap> { - self.output - .as_ref() - .and_then(|output| Some(output.models)) - .as_ref() + self.output.as_ref().map(|output| &output.models) } } @@ -638,7 +632,7 @@ impl From for GeneratedToken { Self { text: value.text, logprob: Some(value.logprob as f64), - rank: Some(value.rank as u32), + rank: Some(value.rank), } } } @@ -685,7 +679,7 @@ mod tests { use super::*; #[test] - fn test_upfront_validate() { + fn test_validate() { // Expected OK case let request = GuardrailsHttpRequest { model_id: "model".to_string(), @@ -693,15 +687,15 @@ mod tests { guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { masks: Some(vec![(5, 8)]), - models: Some(HashMap::new()), + models: HashMap::new(), }), output: Some(GuardrailsConfigOutput { - models: Some(HashMap::new()), + models: HashMap::new(), }), }), text_gen_parameters: None, }; - assert!(request.upfront_validate().is_ok()); + assert!(request.validate().is_ok()); // No model ID let request = GuardrailsHttpRequest { @@ -710,36 +704,38 @@ mod tests { guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { masks: Some(vec![]), - models: Some(HashMap::new()), + models: HashMap::new(), }), output: Some(GuardrailsConfigOutput { - models: Some(HashMap::new()), + models: HashMap::new(), }), }), text_gen_parameters: None, }; - let result = request.upfront_validate(); + let result = request.validate(); assert!(result.is_err()); let error = result.unwrap_err().to_string(); assert!(error.contains("`model_id` is required")); - // No models on input + // No inputs let request = GuardrailsHttpRequest { model_id: "model".to_string(), - inputs: "short".to_string(), + inputs: "".to_string(), guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { - masks: Some(vec![]), - models: None, + masks: None, + models: HashMap::new(), }), output: Some(GuardrailsConfigOutput { - models: Some(HashMap::new()), + models: HashMap::new(), }), }), text_gen_parameters: None, }; - let result = request.upfront_validate(); + let result = request.validate(); assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("`inputs` is required")); // Mask span beyond inputs let request = GuardrailsHttpRequest { @@ -748,15 +744,15 @@ mod tests { guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { masks: Some(vec![(0, 12)]), - models: Some(HashMap::new()), + models: HashMap::new(), }), output: Some(GuardrailsConfigOutput { - models: Some(HashMap::new()), + models: HashMap::new(), }), }), text_gen_parameters: None, }; - let result = request.upfront_validate(); + let result = request.validate(); assert!(result.is_err()); let error = result.unwrap_err().to_string(); assert!(error.contains("invalid masks")); @@ -768,15 +764,15 @@ mod tests { guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { masks: Some(vec![(12, 8)]), - models: Some(HashMap::new()), + models: HashMap::new(), }), output: Some(GuardrailsConfigOutput { - models: Some(HashMap::new()), + models: HashMap::new(), }), }), text_gen_parameters: None, }; - let result = request.upfront_validate(); + let result = request.validate(); assert!(result.is_err()); let error = result.unwrap_err().to_string(); assert!(error.contains("invalid masks")); diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 6dcb627e..79c9d8d4 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -96,7 +96,7 @@ impl Orchestrator { tokenize(ctx.clone(), task.model_id.clone(), task.inputs.clone()).await?; // Send result with input detections Ok(ClassifiedGeneratedTextResult { - input_token_count: input_token_count as u32, + input_token_count, token_classification_results: TextGenTokenClassificationResults { input: input_detections, output: None, From 2c60b42f3e544ec71f2fa2130f7c70420cff3dfa Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 22 May 2024 20:32:20 -0600 Subject: [PATCH 15/15] :art: Linting Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/orchestrator.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 79c9d8d4..0cf29ab1 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -478,9 +478,9 @@ async fn generate( Ok(ClassifiedGeneratedTextResult { generated_text: Some(response.text.clone()), finish_reason: Some(response.stop_reason().into()), - generated_token_count: Some(response.generated_token_count as u32), + generated_token_count: Some(response.generated_token_count), seed: Some(response.seed as u32), - input_token_count: response.input_token_count as u32, + input_token_count: response.input_token_count, warnings: None, tokens: if response.tokens.is_empty() { None