Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error handling for unknown fields for chat completion detection request #296

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/clients/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ impl From<ChatCompletion> for ChatCompletionsResponse {
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ChatCompletionsRequest {
/// A list of messages comprising the conversation so far.
pub messages: Vec<Message>,
Expand Down Expand Up @@ -290,6 +291,7 @@ pub struct ChatCompletionsRequest {

/// Structure to contain parameters for detectors.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DetectorConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub input: Option<HashMap<String, DetectorParams>>,
Expand Down Expand Up @@ -398,6 +400,7 @@ pub enum Role {
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Message {
/// The role of the author of this message.
pub role: Role,
Expand Down
83 changes: 83 additions & 0 deletions src/orchestrator/chat_completions_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,4 +473,87 @@ mod tests {
"validation error: Last message role must be user, assistant, or system"
);
}
// validate chat completions request with invalid fields
// (nonexistant fields or typos)
#[tokio::test]
async fn test_validate() {
// Additional unknown field (additional_field)
let json_data = r#"
{
"messages": [
{
"content": "this is a nice sentence",
"role": "user",
"name": "string"
}
],
"model": "my_model",
"additional_field": "test",
"n": 1,
"temperature": 1,
"top_p": 1,
"user": "user-1234",
"detectors": {
"input": {}
}
}
"#;
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error
.to_string()
.contains("unknown field `additional_field"));

// Additional unknown field (additional_message")
let json_data = r#"
{
"messages": [
{
"content": "this is a nice sentence",
"role": "user",
"name": "string",
"additional_msg: "test"
}
],
"model": "my_model",
"n": 1,
"temperature": 1,
"top_p": 1,
"user": "user-1234",
"detectors": {
"input": {}
}
}
"#;
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error.to_string().contains("unknown field `additional_msg"));

// Additional unknown field (typo for input field in detectors)
let json_data = r#"
{
"messages": [
{
"content": "this is a nice sentence",
"role": "user",
"name": "string"
}
],
"model": "my_model",
"n": 1,
"temperature": 1,
"top_p": 1,
"user": "user-1234",
"detectors": {
"inputs": {}
}
}
"#;
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
assert!(error.to_string().contains("unknown field `inputs"));
}
}