diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 4a660a5d..3bea007e 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -386,10 +386,21 @@ pub struct JsonSchemaObject { pub required: Option>, } +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + #[default] + User, + Developer, + Assistant, + System, + Tool, +} + #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct Message { - /// The role of the messages author. - pub role: String, + /// The role of the author of this message. + pub role: Role, /// The contents of the message. #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, @@ -552,7 +563,7 @@ pub struct ChatCompletionChoice { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionMessage { /// The role of the author of this message. - pub role: String, + pub role: Role, /// The contents of the message. pub content: Option, /// The tool calls generated by the model, such as function calls. @@ -635,7 +646,7 @@ pub struct ChatCompletionChunkChoice { pub struct ChatCompletionDelta { /// The role of the author of this message. #[serde(skip_serializing_if = "Option::is_none")] - pub role: Option, + pub role: Option, /// The contents of the message. #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, diff --git a/src/orchestrator/chat_completions_detection.rs b/src/orchestrator/chat_completions_detection.rs index 1b40c9eb..13519326 100644 --- a/src/orchestrator/chat_completions_detection.rs +++ b/src/orchestrator/chat_completions_detection.rs @@ -32,7 +32,7 @@ use crate::{ detector::{ChatDetectionRequest, ContentAnalysisRequest}, openai::{ ChatCompletion, ChatCompletionChoice, ChatCompletionsRequest, ChatCompletionsResponse, - ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning, + ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning, Role, }, }, config::DetectorType, @@ -51,7 +51,7 @@ pub struct ChatMessageInternal { /// Index of the message pub message_index: usize, /// The role of the messages author. - pub role: String, + pub role: Role, /// The contents of the message. #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, @@ -425,7 +425,7 @@ mod tests { let messages = vec![ChatMessageInternal { message_index: 0, content: Some(Content::Text("hello".to_string())), - role: "assistant".to_string(), + role: Role::Assistant, ..Default::default() }]; let processed_messages = preprocess_chat_messages(&ctx, &detectors, messages).unwrap(); @@ -458,7 +458,7 @@ mod tests { message_index: 0, content: Some(Content::Text("hello".to_string())), // Invalid role will return error used for testing - role: "foo".to_string(), + role: Role::Tool, ..Default::default() }]; diff --git a/src/orchestrator/detector_processing/content.rs b/src/orchestrator/detector_processing/content.rs index 96cca597..74f7bc9a 100644 --- a/src/orchestrator/detector_processing/content.rs +++ b/src/orchestrator/detector_processing/content.rs @@ -15,7 +15,8 @@ */ use crate::{ - clients::openai::Content, models::ValidationError, + clients::openai::{Content, Role}, + models::ValidationError, orchestrator::chat_completions_detection::ChatMessageInternal, }; @@ -38,7 +39,7 @@ pub fn filter_chat_messages( )); } // 2. Role is user | assistant | system - if !matches!(message.role.as_str(), "user" | "assistant" | "system") { + if !matches!(message.role, Role::User | Role::Assistant | Role::System) { return Err(ValidationError::Invalid( "Last message role must be user, assistant, or system".into(), )); @@ -62,7 +63,7 @@ mod tests { let message = vec![ChatMessageInternal { message_index: 0, content: Some(Content::Text("hello".to_string())), - role: "assistant".to_string(), + role: Role::Assistant, ..Default::default() }]; @@ -79,13 +80,13 @@ mod tests { ChatMessageInternal { message_index: 0, content: Some(Content::Text("hello".to_string())), - role: "assistant".to_string(), + role: Role::Assistant, ..Default::default() }, ChatMessageInternal { message_index: 1, content: Some(Content::Text("bot".to_string())), - role: "assistant".to_string(), + role: Role::Assistant, ..Default::default() }, ]; @@ -102,7 +103,7 @@ mod tests { let message = vec![ChatMessageInternal { message_index: 0, content: Some(Content::Text("hello".to_string())), - role: "invalid_role".to_string(), + role: Role::Tool, ..Default::default() }];