Skip to content

Commit

Permalink
Add Role enum to openai module (#287)
Browse files Browse the repository at this point in the history
* updated role from string to enum for chat messages

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>

* corrected role in chat completion delta  to optional

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>

---------

Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>
  • Loading branch information
swith004 authored Jan 29, 2025
1 parent 408badf commit 175bf51
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
19 changes: 15 additions & 4 deletions src/clients/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,21 @@ pub struct JsonSchemaObject {
pub required: Option<Vec<String>>,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
#[default]
User,
Developer,
Assistant,
System,
Tool,
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct Message {
/// The role of the messages author.
pub role: String,
/// The role of the author of this message.
pub role: Role,
/// The contents of the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Content>,
Expand Down Expand Up @@ -552,7 +563,7 @@ pub struct ChatCompletionChoice {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionMessage {
/// The role of the author of this message.
pub role: String,
pub role: Role,
/// The contents of the message.
pub content: Option<String>,
/// The tool calls generated by the model, such as function calls.
Expand Down Expand Up @@ -635,7 +646,7 @@ pub struct ChatCompletionChunkChoice {
pub struct ChatCompletionDelta {
/// The role of the author of this message.
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
pub role: Option<Role>,
/// The contents of the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
Expand Down
8 changes: 4 additions & 4 deletions src/orchestrator/chat_completions_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::{
detector::{ChatDetectionRequest, ContentAnalysisRequest},
openai::{
ChatCompletion, ChatCompletionChoice, ChatCompletionsRequest, ChatCompletionsResponse,
ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning,
ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning, Role,
},
},
config::DetectorType,
Expand All @@ -51,7 +51,7 @@ pub struct ChatMessageInternal {
/// Index of the message
pub message_index: usize,
/// The role of the messages author.
pub role: String,
pub role: Role,
/// The contents of the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Content>,
Expand Down Expand Up @@ -425,7 +425,7 @@ mod tests {
let messages = vec![ChatMessageInternal {
message_index: 0,
content: Some(Content::Text("hello".to_string())),
role: "assistant".to_string(),
role: Role::Assistant,
..Default::default()
}];
let processed_messages = preprocess_chat_messages(&ctx, &detectors, messages).unwrap();
Expand Down Expand Up @@ -458,7 +458,7 @@ mod tests {
message_index: 0,
content: Some(Content::Text("hello".to_string())),
// Invalid role will return error used for testing
role: "foo".to_string(),
role: Role::Tool,
..Default::default()
}];

Expand Down
13 changes: 7 additions & 6 deletions src/orchestrator/detector_processing/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
*/
use crate::{
clients::openai::Content, models::ValidationError,
clients::openai::{Content, Role},
models::ValidationError,
orchestrator::chat_completions_detection::ChatMessageInternal,
};

Expand All @@ -38,7 +39,7 @@ pub fn filter_chat_messages(
));
}
// 2. Role is user | assistant | system
if !matches!(message.role.as_str(), "user" | "assistant" | "system") {
if !matches!(message.role, Role::User | Role::Assistant | Role::System) {
return Err(ValidationError::Invalid(
"Last message role must be user, assistant, or system".into(),
));
Expand All @@ -62,7 +63,7 @@ mod tests {
let message = vec![ChatMessageInternal {
message_index: 0,
content: Some(Content::Text("hello".to_string())),
role: "assistant".to_string(),
role: Role::Assistant,
..Default::default()
}];

Expand All @@ -79,13 +80,13 @@ mod tests {
ChatMessageInternal {
message_index: 0,
content: Some(Content::Text("hello".to_string())),
role: "assistant".to_string(),
role: Role::Assistant,
..Default::default()
},
ChatMessageInternal {
message_index: 1,
content: Some(Content::Text("bot".to_string())),
role: "assistant".to_string(),
role: Role::Assistant,
..Default::default()
},
];
Expand All @@ -102,7 +103,7 @@ mod tests {
let message = vec![ChatMessageInternal {
message_index: 0,
content: Some(Content::Text("hello".to_string())),
role: "invalid_role".to_string(),
role: Role::Tool,
..Default::default()
}];

Expand Down

0 comments on commit 175bf51

Please sign in to comment.