From 34f04ff5ff8507965a18cd4c55d02660794f4aec Mon Sep 17 00:00:00 2001
From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Date: Fri, 24 May 2024 10:00:54 -0600
Subject: [PATCH] :safety_vest: Initial request validation (#42)

* :heavy_plus_sign: Switch validator for garde

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :safety_vest: Add initial input request validation with garde

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :construction: Attempt custom validation

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :goal_net::construction: Input masks validation

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :white_check_mark: Initial request validation tests

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :white_check_mark: Garde validation case

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :wrench: Pass along input text preservation

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :white_check_mark: Error string checking

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :goal_net: Use error handling objects

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :memo::safety_vest: Update errors in API

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :pushpin: Switch API validation dependencies

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :art: Lint

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :recycle: Validation without garde

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :white_check_mark::label: Update tests and types

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* :art: Linting

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

---------

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
---
 Cargo.toml                               |   1 -
 docs/api/orchestrator_openapi_0_1_0.yaml |  56 ++---
 src/clients/detector.rs                  |   4 +-
 src/models.rs                            | 281 +++++++++++++++--------
 src/orchestrator.rs                      |  18 +-
 src/server.rs                            |   2 +
 6 files changed, 221 insertions(+), 141 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index a19013ed..56321242 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -34,7 +34,6 @@ tracing = "0.1.40"
 tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
 url = "2.5.0"
 uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
-validator = { version = "0.18.1", features = ["derive"] } # For API validation
 
 [build-dependencies]
 tonic-build = "0.11.0"
diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml
index 3f2449bc..06dcf269 100644
--- a/docs/api/orchestrator_openapi_0_1_0.yaml
+++ b/docs/api/orchestrator_openapi_0_1_0.yaml
@@ -23,12 +23,18 @@ paths:
         application/json:
           schema:
             $ref: '#/components/schemas/ClassifiedGeneratedTextResult'
+      '404':
+        description: Resource Not Found
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/Error'
       '422':
         description: Validation Error
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/HTTPValidationError'
+              $ref: '#/components/schemas/Error'
   /api/v1/task/server-streaming-classification-with-text-generation:
     post:
       tags:
@@ -49,12 +55,18 @@ paths:
         application/json:
           schema:
             $ref: '#/components/schemas/ClassifiedGeneratedTextStreamResult'
+      '404':
+        description: Resource Not Found
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/Error'
       '422':
         description: Validation Error
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/HTTPValidationError'
+              $ref: '#/components/schemas/Error'
 components:
   schemas:
     ClassifiedGeneratedTextResult:
@@ -156,6 +168,16 @@ components:
       required: ["input_token_count", "token_classification_results", "start_index"]
       type: object
       title: ClassifiedGeneratedTextStreamResult
+    Error:
+      type: object
+      properties:
+        code:
+          type: string
+        details:
+          type: string
+      required:
+        - code
+        - details
     ExponentialDecayLengthPenalty:
       properties:
         start_index:
@@ -299,15 +321,6 @@ components:
       additionalProperties: false
       type: object
       title: GuardrailsTextGenerationParameters
-    HTTPValidationError:
-      properties:
-        detail:
-          items:
-            $ref: '#/components/schemas/ValidationError'
-          type: array
-          title: Detail
-      type: object
-      title: HTTPValidationError
     InputWarning:
       properties:
         id:
@@ -368,24 +381,3 @@ components:
       required: ["start", "end", "word", "entity", "entity_group", "score"]
       type: object
       title: TokenClassificationResult
-    ValidationError:
-      properties:
-        loc:
-          items:
-            anyOf:
-              - type: string
-              - type: integer
-          type: array
-          title: Location
-        msg:
-          type: string
-          title: Message
-        type:
-          type: string
-          title: Error Type
-      type: object
-      required:
-        - loc
-        - msg
-        - type
-      title: ValidationError
diff --git a/src/clients/detector.rs b/src/clients/detector.rs
index 6d684e70..183f6c3c 100644
--- a/src/clients/detector.rs
+++ b/src/clients/detector.rs
@@ -82,8 +82,8 @@ pub struct DetectorResponse {
 impl From<Detection> for crate::models::TokenClassificationResult {
     fn from(value: Detection) -> Self {
         Self {
-            start: value.start as i32,
-            end: value.end as i32,
+            start: value.start as u32,
+            end: value.end as u32,
             word: value.text,
             entity: value.detection,
             entity_group: value.detection_type,
diff --git a/src/models.rs b/src/models.rs
index 3a8d6131..276cb143 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -1,13 +1,12 @@
 #![allow(unused_qualifications)]
 
+use crate::{pb, server};
 use std::collections::HashMap;
 
-use crate::pb;
-
 pub type DetectorParams = HashMap<String, serde_json::Value>;
 
 /// User request to orchestrator
-#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
 pub struct GuardrailsHttpRequest {
     /// Text generation model ID
     #[serde(rename = "model_id")]
@@ -29,11 +28,36 @@ pub struct GuardrailsHttpRequest {
     pub text_gen_parameters: Option<GuardrailsTextGenerationParameters>,
 }
 
+impl GuardrailsHttpRequest {
+    /// Upfront validation of user request
+    pub fn validate(&self) -> Result<(), server::Error> {
+        // Validate required parameters
+        if self.model_id.is_empty() {
+            return Err(server::Error::Validation("`model_id` is required".into()));
+        }
+        if self.inputs.is_empty() {
+            return Err(server::Error::Validation("`inputs` is required".into()));
+        }
+        // Validate masks
+        let input_range = 0..self.inputs.len();
+        let input_masks = self
+            .guardrail_config
+            .as_ref()
+            .and_then(|config| config.input.as_ref().and_then(|input| input.masks.as_ref()));
+        if let Some(input_masks) = input_masks {
+            if !input_masks.iter().all(|(start, end)| {
+                input_range.contains(start) && input_range.contains(end) && start < end
+            }) {
+                return Err(server::Error::Validation("invalid masks".into()));
+            }
+        }
+        Ok(())
+    }
+}
+
 /// Configuration of guardrails models for either or both input to a text generation model
 /// (e.g. user prompt) and output of a text generation model
-#[derive(
-    Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate,
-)]
+#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 pub struct GuardrailsConfig {
     /// Configuration for detection on input to a text generation model (e.g. user prompt)
     #[serde(rename = "input")]
@@ -52,24 +76,20 @@ impl GuardrailsConfig {
     }
 
     pub fn input_detectors(&self) -> Option<&HashMap<String, DetectorParams>> {
-        self.input.as_ref().and_then(|input| input.models.as_ref())
+        self.input.as_ref().map(|input| &input.models)
     }
 
     pub fn output_detectors(&self) -> Option<&HashMap<String, DetectorParams>> {
-        self.output
-            .as_ref()
-            .and_then(|output| output.models.as_ref())
+        self.output.as_ref().map(|output| &output.models)
     }
 }
 
 /// Configuration for detection on input to a text generation model (e.g. user prompt)
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 pub struct GuardrailsConfigInput {
     /// Map of model name to model specific parameters
     #[serde(rename = "models")]
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub models: Option<HashMap<String, DetectorParams>>,
-    //Option<HashMap<String, HashMap<String, serde_json::Value>>>,
+    pub models: HashMap<String, DetectorParams>,
     /// Vector of spans are in the form of (span_start, span_end) corresponding
     /// to spans of input text on which to run input detection
     #[serde(rename = "masks")]
@@ -78,32 +98,31 @@ pub struct GuardrailsConfigInput {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub masks: Option<Vec<(usize, usize)>>,
 }
 
 /// Configuration for detection on output of a text generation model
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 pub struct GuardrailsConfigOutput {
     /// Map of model name to model specific parameters
     #[serde(rename = "models")]
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub models: Option<HashMap<String, DetectorParams>>,
-    //Option<HashMap<String, HashMap<String, serde_json::Value>>>,
+    pub models: HashMap<String, DetectorParams>,
 }
 
 /// Parameters for text generation, ref.
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 pub struct GuardrailsTextGenerationParameters {
+    // Leave most validation of parameters to downstream text generation servers
     /// Maximum number of new tokens to generate
     #[serde(rename = "max_new_tokens")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_new_tokens: Option<i32>,
+    pub max_new_tokens: Option<u32>,
 
     /// Minimum number of new tokens to generate
     #[serde(rename = "min_new_tokens")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub min_new_tokens: Option<i32>,
+    pub min_new_tokens: Option<u32>,
 
     /// Truncate to this many input tokens for generation
     #[serde(rename = "truncate_input_tokens")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub truncate_input_tokens: Option<i32>,
+    pub truncate_input_tokens: Option<u32>,
 
     /// The high level decoding strategy for picking
     /// tokens during text generation
@@ -116,7 +135,7 @@ pub struct GuardrailsTextGenerationParameters {
     /// only the top_k most likely tokens are considered as candidates for the next generated token.
     #[serde(rename = "top_k")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_k: Option<i32>,
+    pub top_k: Option<u32>,
 
     /// Similar to top_k except the candidates to generate the next token are the
     /// most likely tokens with probabilities that add up to at least top_p.
@@ -167,7 +186,7 @@ pub struct GuardrailsTextGenerationParameters {
     /// Random seed used for text generation
     #[serde(rename = "seed")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub seed: Option<i32>,
+    pub seed: Option<u32>,
 
     /// Whether or not to include input text
     #[serde(rename = "preserve_input_text")]
@@ -199,12 +218,12 @@ pub struct GuardrailsTextGenerationParameters {
 
 /// Parameters to exponentially increase the likelihood of the text generation
 /// terminating once a specified number of tokens have been generated.
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 pub struct ExponentialDecayLengthPenalty {
     /// Start the decay after this number of tokens have been generated
     #[serde(rename = "start_index")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub start_index: Option<i32>,
+    pub start_index: Option<u32>,
 
     /// Factor of exponential decay
     #[serde(rename = "decay_factor")]
@@ -215,9 +234,7 @@ pub struct ExponentialDecayLengthPenalty {
 /// Classification result on text produced by a text generation model, containing
 /// information from the original text generation output as well as the result of
 /// classification on the generated text.
-#[derive(
-    Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate,
-)]
+#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct ClassifiedGeneratedTextResult {
     /// Generated text
@@ -238,16 +255,16 @@ pub struct ClassifiedGeneratedTextResult {
     /// Length of sequence of generated tokens
     #[serde(rename = "generated_token_count")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub generated_token_count: Option<i32>,
+    pub generated_token_count: Option<u32>,
 
     /// Random seed used for text generation
     #[serde(rename = "seed")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub seed: Option<i32>,
+    pub seed: Option<u32>,
 
     /// Length of input
     #[serde(rename = "input_token_count")]
-    pub input_token_count: i32,
+    pub input_token_count: u32,
 
     /// Vector of warnings on input detection
     #[serde(rename = "warnings")]
@@ -268,7 +285,7 @@ pub struct ClassifiedGeneratedTextResult {
 /// Streaming classification result on text produced by a text generation model, containing
 /// information from the original text generation output as well as the result of
 /// classification on the generated text. Also indicates where in stream is processed.
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct ClassifiedGeneratedTextStreamResult {
     #[serde(rename = "generated_text")]
@@ -288,16 +305,16 @@ pub struct ClassifiedGeneratedTextStreamResult {
     /// Length of sequence of generated tokens
     #[serde(rename = "generated_token_count")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub generated_token_count: Option<i32>,
+    pub generated_token_count: Option<u32>,
 
     /// Random seed used for text generation
     #[serde(rename = "seed")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub seed: Option<i32>,
+    pub seed: Option<u32>,
 
     /// Length of input
     #[serde(rename = "input_token_count")]
-    pub input_token_count: i32,
+    pub input_token_count: u32,
 
     /// Vector of warnings on input detection
     #[serde(rename = "warnings")]
@@ -317,18 +334,16 @@ pub struct ClassifiedGeneratedTextStreamResult {
     /// Result index up to which text is processed
     #[serde(rename = "processed_index")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub processed_index: Option<i32>,
+    pub processed_index: Option<u32>,
 
     /// Result start index for processed text
     #[serde(rename = "start_index")]
-    pub start_index: i32,
+    pub start_index: u32,
 }
 
 /// Results of classification on input to a text generation model (e.g. user prompt)
 /// or output of a text generation model
-#[derive(
-    Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate,
-)]
+#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct TextGenTokenClassificationResults {
     /// Classification results on input to a text generation model
@@ -348,15 +363,15 @@ pub struct TextGenTokenClassificationResults {
 /// The field `word` does not necessarily correspond to a single "word",
 /// and `entity` may not always be applicable beyond "entity" in the NER
 /// (named entity recognition) sense
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 pub struct TokenClassificationResult {
     /// Beginning/start offset of token
     #[serde(rename = "start")]
-    pub start: i32,
+    pub start: u32,
 
     /// End offset of token
     #[serde(rename = "end")]
-    pub end: i32,
+    pub end: u32,
 
     /// Text referenced by token
     #[serde(rename = "word")]
@@ -377,7 +392,7 @@ pub struct TokenClassificationResult {
     /// Length of tokens in the text
     #[serde(rename = "token_count")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub token_count: Option<i32>,
+    pub token_count: Option<u32>,
 }
 
 /// Enumeration of reasons why text generation stopped
@@ -409,7 +424,7 @@ pub enum FinishReason {
 }
 
 /// Warning reason and message on input detection
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct InputWarning {
     /// Warning reason
@@ -439,7 +454,7 @@ pub enum InputWarningReason {
 }
 
 /// Generated token information
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct GeneratedToken {
     /// Token text
@@ -454,11 +469,11 @@ pub struct GeneratedToken {
     /// One-based rank relative to other tokens
     #[serde(rename = "rank")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub rank: Option<i32>,
+    pub rank: Option<u32>,
 }
 
 /// Result of a text generation model
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct GeneratedTextResult {
     /// Generated text
@@ -477,12 +492,12 @@ pub struct GeneratedTextResult {
 
     /// Length of input
     #[serde(rename = "input_token_count")]
-    pub input_token_count: i32,
+    pub input_token_count: u32,
 
     /// Random seed used for text generation
     #[serde(rename = "seed")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub seed: Option<i32>,
+    pub seed: Option<u32>,
 
     /// Individual generated tokens and associated details, if requested
     #[serde(rename = "tokens")]
@@ -496,7 +511,7 @@ pub struct GeneratedTextResult {
 }
 
 /// Details on the streaming result of a text generation model
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct TokenStreamDetails {
     /// Why text generation stopped
@@ -507,20 +522,20 @@ pub struct TokenStreamDetails {
     /// Length of sequence of generated tokens
     #[serde(rename = "generated_tokens")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub generated_tokens: Option<i32>,
+    pub generated_tokens: Option<u32>,
 
     /// Random seed used for text generation
     #[serde(rename = "seed")]
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub seed: Option<i32>,
+    pub seed: Option<u32>,
 
     /// Length of input
     #[serde(rename = "input_token_count")]
-    pub input_token_count: i32,
+    pub input_token_count: u32,
 }
 
 /// Streaming result of a text generation model
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
 pub struct GeneratedTextStreamResult {
     /// Generated text
@@ -543,43 +558,10 @@ pub struct GeneratedTextStreamResult {
     pub input_tokens: Option<Vec<GeneratedToken>>,
 }
 
-// TODO: The below errors follow FastAPI concepts esp. for loc
-// It may be worth revisiting if the orchestrator without FastAPI
-// should be using these error types
-
-/// HTTP validation error
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
-#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
-pub struct HttpValidationError {
-    #[serde(rename = "detail")]
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub detail: Option<Vec<ValidationError>>,
-}
-
-/// Validation error
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
-#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
-pub struct ValidationError {
-    #[serde(rename = "loc")]
-    pub loc: Vec<LocationInner>,
-
-    /// Error message
-    #[serde(rename = "msg")]
-    pub msg: String,
-
-    /// Error type
-    #[serde(rename = "type")]
-    pub r#type: String,
-}
-
-#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
-#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
-pub struct LocationInner {}
-
 impl From<ExponentialDecayLengthPenalty> for pb::fmaas::decoding_parameters::LengthPenalty {
     fn from(value: ExponentialDecayLengthPenalty) -> Self {
         Self {
-            start_index: value.start_index.unwrap_or_default() as u32,
+            start_index: value.start_index.unwrap_or_default(),
             decay_factor: value.decay_factor.unwrap_or_default() as f32,
         }
     }
@@ -592,20 +574,20 @@ impl From<GuardrailsTextGenerationParameters> for pb::fmaas::Parameters {
         let method = pb::fmaas::DecodingMethod::from_str_name(&decoding_method).unwrap_or_default();
         let sampling = pb::fmaas::SamplingParameters {
             temperature: value.temperature.unwrap_or_default() as f32,
-            top_k: value.top_k.unwrap_or_default() as u32,
+            top_k: value.top_k.unwrap_or_default(),
             top_p: value.top_p.unwrap_or_default() as f32,
             typical_p: value.typical_p.unwrap_or_default() as f32,
             seed: value.seed.map(|v| v as u64),
         };
         let stopping = pb::fmaas::StoppingCriteria {
-            max_new_tokens: value.max_new_tokens.unwrap_or_default() as u32,
-            min_new_tokens: value.min_new_tokens.unwrap_or_default() as u32,
+            max_new_tokens: value.max_new_tokens.unwrap_or_default(),
+            min_new_tokens: value.min_new_tokens.unwrap_or_default(),
             time_limit_millis: value.max_time.unwrap_or_default() as u32,
             stop_sequences: value.stop_sequences.unwrap_or_default(),
             include_stop_sequence: None,
         };
         let response = pb::fmaas::ResponseOptions {
-            input_text: false, // missing?
+            input_text: value.preserve_input_text.unwrap_or_default(),
             generated_tokens: value.generated_tokens.unwrap_or_default(),
             input_tokens: value.input_tokens.unwrap_or_default(),
             token_logprobs: value.token_logprobs.unwrap_or_default(),
@@ -616,7 +598,7 @@ impl From<GuardrailsTextGenerationParameters> for pb::fmaas::Parameters {
             repetition_penalty: value.repetition_penalty.unwrap_or_default() as f32,
             length_penalty: value.exponential_decay_length_penalty.map(Into::into),
         };
-        let truncate_input_tokens = value.truncate_input_tokens.unwrap_or_default() as u32;
+        let truncate_input_tokens = value.truncate_input_tokens.unwrap_or_default();
         Self {
             method: method as i32,
             sampling: Some(sampling),
@@ -650,7 +632,7 @@ impl From<pb::fmaas::GeneratedToken> for GeneratedToken {
         Self {
             text: value.text,
             logprob: Some(value.logprob as f64),
-            rank: Some(value.rank as i32),
+            rank: Some(value.rank),
         }
     }
 }
@@ -660,7 +642,7 @@ impl From<pb::caikit_data_model::nlp::GeneratedToken> for GeneratedToken {
         Self {
             text: value.text,
             logprob: Some(value.logprob),
-            rank: Some(value.rank as i32),
+            rank: Some(value.rank as u32),
         }
     }
 }
@@ -691,3 +673,108 @@ impl From
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_validate() {
+        // Expected OK case
+        let request = GuardrailsHttpRequest {
+            model_id: "model".to_string(),
+            inputs: "The cow jumped over the moon!".to_string(),
+            guardrail_config: Some(GuardrailsConfig {
+                input: Some(GuardrailsConfigInput {
+                    masks: Some(vec![(5, 8)]),
+                    models: HashMap::new(),
+                }),
+                output: Some(GuardrailsConfigOutput {
+                    models: HashMap::new(),
+                }),
+            }),
+            text_gen_parameters: None,
+        };
+        assert!(request.validate().is_ok());
+
+        // No model ID
+        let request = GuardrailsHttpRequest {
+            model_id: "".to_string(),
+            inputs: "short".to_string(),
+            guardrail_config: Some(GuardrailsConfig {
+                input: Some(GuardrailsConfigInput {
+                    masks: Some(vec![]),
+                    models: HashMap::new(),
+                }),
+                output: Some(GuardrailsConfigOutput {
+                    models: HashMap::new(),
+                }),
+            }),
+            text_gen_parameters: None,
+        };
+        let result = request.validate();
+        assert!(result.is_err());
+        let error = result.unwrap_err().to_string();
+        assert!(error.contains("`model_id` is required"));
+
+        // No inputs
+        let request = GuardrailsHttpRequest {
+            model_id: "model".to_string(),
+            inputs: "".to_string(),
+            guardrail_config: Some(GuardrailsConfig {
+                input: Some(GuardrailsConfigInput {
+                    masks: None,
+                    models: HashMap::new(),
+                }),
+                output: Some(GuardrailsConfigOutput {
+                    models: HashMap::new(),
+                }),
+            }),
+            text_gen_parameters: None,
+        };
+        let result = request.validate();
+        assert!(result.is_err());
+        let error = result.unwrap_err().to_string();
+        assert!(error.contains("`inputs` is required"));
+
+        // Mask span beyond inputs
+        let request = GuardrailsHttpRequest {
+            model_id: "model".to_string(),
+            inputs: "short".to_string(),
+            guardrail_config: Some(GuardrailsConfig {
+                input: Some(GuardrailsConfigInput {
+                    masks: Some(vec![(0, 12)]),
+                    models: HashMap::new(),
+                }),
+                output: Some(GuardrailsConfigOutput {
+                    models: HashMap::new(),
+                }),
+            }),
+            text_gen_parameters: None,
+        };
+        let result = request.validate();
+        assert!(result.is_err());
+        let error = result.unwrap_err().to_string();
+        assert!(error.contains("invalid masks"));
+
+        // Mask span end less than span start
+        let request = GuardrailsHttpRequest {
+            model_id: "model".to_string(),
+            inputs: "This is ignored anyway!".to_string(),
+            guardrail_config: Some(GuardrailsConfig {
+                input: Some(GuardrailsConfigInput {
+                    masks: Some(vec![(12, 8)]),
+                    models: HashMap::new(),
+                }),
+                output: Some(GuardrailsConfigOutput {
+                    models: HashMap::new(),
+                }),
+            }),
+            text_gen_parameters: None,
+        };
+        let result = request.validate();
+        assert!(result.is_err());
+        let error = result.unwrap_err().to_string();
+        assert!(error.contains("invalid masks"));
+    }
+}
diff --git a/src/orchestrator.rs b/src/orchestrator.rs
index 3625ead2..e6bb1dff 100644
--- a/src/orchestrator.rs
+++ b/src/orchestrator.rs
@@ -96,7 +96,7 @@ impl Orchestrator {
             tokenize(ctx.clone(), task.model_id.clone(), task.inputs.clone()).await?;
         // Send result with input detections
         Ok(ClassifiedGeneratedTextResult {
-            input_token_count: input_token_count as i32,
+            input_token_count,
             token_classification_results: TextGenTokenClassificationResults {
                 input: input_detections,
                 output: None,
@@ -349,8 +349,8 @@ async fn handle_detection_task(
                 .into_iter()
                 .map(|detection| {
                     let mut result: TokenClassificationResult = detection.into();
-                    result.start += chunk.offset as i32;
-                    result.end += chunk.offset as i32;
+                    result.start += chunk.offset as u32;
+                    result.end += chunk.offset as u32;
                     result
                 })
                 .collect::<Vec<_>>();
@@ -479,9 +479,9 @@ async fn generate(
             Ok(ClassifiedGeneratedTextResult {
                 generated_text: Some(response.text.clone()),
                 finish_reason: Some(response.stop_reason().into()),
-                generated_token_count: Some(response.generated_token_count as i32),
-                seed: Some(response.seed as i32),
-                input_token_count: response.input_token_count as i32,
+                generated_token_count: Some(response.generated_token_count),
+                seed: Some(response.seed as u32),
+                input_token_count: response.input_token_count,
                 warnings: None,
                 tokens: if response.tokens.is_empty() {
                     None
@@ -552,9 +552,9 @@ async fn generate(
             Ok(ClassifiedGeneratedTextResult {
                 generated_text: Some(response.generated_text.clone()),
                 finish_reason: Some(response.finish_reason().into()),
-                generated_token_count: Some(response.generated_tokens as i32),
-                seed: Some(response.seed as i32),
-                input_token_count: response.input_token_count as i32,
+                generated_token_count: Some(response.generated_tokens as u32),
+                seed: Some(response.seed as u32),
+                input_token_count: response.input_token_count as u32,
                 warnings: None,
                 tokens: if response.tokens.is_empty() {
                     None
diff --git a/src/server.rs b/src/server.rs
index 33ce5c90..61468815 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -82,6 +82,8 @@ async fn classification_with_gen(
     Json(request): Json<GuardrailsHttpRequest>,
 ) -> Result<impl IntoResponse, Error> {
     let request_id = Uuid::new_v4();
+    // Upfront request validation
+    request.validate()?;
     let task = ClassificationWithGenTask::new(request_id, request);
     match state
         .orchestrator