Skip to content

Commit

Permalink
Add error handling for unknown fields for chat completion detection r…
Browse files Browse the repository at this point in the history
…equest (#296)

* added serde deny_unknown_fields attribute to relevant chat completion request

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>

* added validation unit tests

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>

* modidified json_data to remove actual model and detector ids in unit tests

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>

* removed async from chat completions unit test

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>

---------

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>
Co-authored-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>
  • Loading branch information
swith004 and Shonda-Adena-Witherspoon authored Feb 6, 2025
1 parent d82b00d commit c202774
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/clients/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ impl From<ChatCompletion> for ChatCompletionsResponse {
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ChatCompletionsRequest {
/// A list of messages comprising the conversation so far.
pub messages: Vec<Message>,
Expand Down Expand Up @@ -290,6 +291,7 @@ pub struct ChatCompletionsRequest {

/// Structure to contain parameters for detectors.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DetectorConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub input: Option<HashMap<String, DetectorParams>>,
Expand Down Expand Up @@ -398,6 +400,7 @@ pub enum Role {
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Message {
/// The role of the author of this message.
pub role: Role,
Expand Down
91 changes: 87 additions & 4 deletions src/orchestrator/chat_completions_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ mod tests {

// Test to verify preprocess_chat_messages works correctly for multiple content type detectors
// with single message in chat request
#[tokio::test]
async fn pretest_process_chat_messages_multiple_content_detector() {
#[test]
fn pretest_process_chat_messages_multiple_content_detector() {
// Test setup
let clients = ClientMap::new();
let detector_1_id = "detector1";
Expand Down Expand Up @@ -436,8 +436,8 @@ mod tests {

// Test preprocess_chat_messages returns error correctly for multiple content type detectors
// with incorrect message requirements
#[tokio::test]
async fn pretest_process_chat_messages_error_handling() {
#[test]
fn pretest_process_chat_messages_error_handling() {
// Test setup
let clients = ClientMap::new();
let detector_1_id = "detector1";
Expand Down Expand Up @@ -473,4 +473,87 @@ mod tests {
"validation error: Last message role must be user, assistant, or system"
);
}
// validate chat completions request with invalid fields
// (nonexistant fields or typos)
#[test]
fn test_validate() {
// Additional unknown field (additional_field)
let json_data = r#"
{
"messages": [
{
"content": "this is a nice sentence",
"role": "user",
"name": "string"
}
],
"model": "my_model",
"additional_field": "test",
"n": 1,
"temperature": 1,
"top_p": 1,
"user": "user-1234",
"detectors": {
"input": {}
}
}
"#;
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error
.to_string()
.contains("unknown field `additional_field"));

// Additional unknown field (additional_message")
let json_data = r#"
{
"messages": [
{
"content": "this is a nice sentence",
"role": "user",
"name": "string",
"additional_msg: "test"
}
],
"model": "my_model",
"n": 1,
"temperature": 1,
"top_p": 1,
"user": "user-1234",
"detectors": {
"input": {}
}
}
"#;
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error.to_string().contains("unknown field `additional_msg"));

// Additional unknown field (typo for input field in detectors)
let json_data = r#"
{
"messages": [
{
"content": "this is a nice sentence",
"role": "user",
"name": "string"
}
],
"model": "my_model",
"n": 1,
"temperature": 1,
"top_p": 1,
"user": "user-1234",
"detectors": {
"inputs": {}
}
}
"#;
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error.to_string().contains("unknown field `inputs"));
}
}

0 comments on commit c202774

Please sign in to comment.