moved output detections for chat completions and sort detections logic into own function

Signed-off-by: resoluteCoder <resolutecoder@gmail.com>
resoluteCoder committed Feb 5, 2025
1 parent b2d0b2c commit 3df2093
Showing 1 changed file with 127 additions and 120 deletions.
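
For context, the extracted sort_detections helper applies a two-level ordering: detections are sorted by message index, and within each detection the results are sorted by their starting span, with span-less results pushed toward the end of the message. Below is a minimal, self-contained sketch of that ordering. The Detection and Span types are simplified stand-ins for illustration only, not the repository's actual DetectionResult and GuardrailDetection definitions.

#[derive(Debug)]
enum Span {
    // Result kind that carries a start offset (stand-in for a content analysis response).
    Content { start: usize },
    // Any other result kind, with no span information.
    Other,
}

#[derive(Debug)]
struct Detection {
    index: usize, // message or choice index the detection belongs to
    results: Vec<Span>,
}

fn sort_detections(mut detections: Vec<Detection>) -> Vec<Detection> {
    // First level: order detections by message/choice index.
    detections.sort_by_key(|d| d.index);

    // Second level: order each detection's results by starting span;
    // results without a span get the key results.len() so they sort last here.
    detections
        .into_iter()
        .map(|mut detection| {
            let last_idx = detection.results.len();
            detection.results.sort_by_key(|r| match r {
                Span::Content { start } => *start,
                Span::Other => last_idx,
            });
            detection
        })
        .collect()
}

fn main() {
    let detections = vec![
        Detection { index: 1, results: vec![Span::Other, Span::Content { start: 0 }] },
        Detection { index: 0, results: vec![Span::Content { start: 42 }] },
    ];
    // Prints the detection for message 0 first, then message 1 with the
    // span-less result after the span-bearing one.
    println!("{:?}", sort_detections(detections));
}

The commit also moves the unary output-detection path into a handle_output_detections helper; as the updated call site shows, it returns Some(response) when output detections were produced and None when the original chat completion should be returned unchanged.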
247 changes: 127 additions & 120 deletions src/orchestrator/chat_completions_detection.rs
@@ -15,7 +15,9 @@
*/
use std::{
arch::x86_64::_MM_FROUND_CUR_DIRECTION,
collections::HashMap,
hash::RandomState,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
@@ -91,9 +93,8 @@ impl From<&Box<ChatCompletion>> for Vec<ChatMessageInternal> {
value
.choices
.iter()
.enumerate()
.map(|(index, choice)| ChatMessageInternal {
message_index: index,
.map(|choice| ChatMessageInternal {
message_index: choice.index,
role: choice.message.role.clone(),
content: Some(Content::Text(
choice.message.content.clone().unwrap_or_default(),
@@ -142,27 +143,8 @@ impl Orchestrator {

debug!(?input_detections);

if let Some(mut input_detections) = input_detections {
// Sort input detections by message_index
// input_detections.sort_by_key(|value| value.message_index);
input_detections.sort_by_key(|value| value.index);

let detections = input_detections
.into_iter()
.map(|mut detection| {
let last_idx = detection.results.len();
// sort detection by starting span, if span is not present then move to the end of the message
detection.results.sort_by_key(|r| match r {
GuardrailDetection::ContentAnalysisResponse(value) => value.start,
_ => last_idx,
});
// detection
InputDetectionResult {
message_index: detection.index,
results: detection.results,
}
})
.collect::<Vec<_>>();
if let Some(input_detections) = input_detections {
let detections = sort_detections(input_detections);

Ok(ChatCompletionsResponse::Unary(Box::new(ChatCompletion {
id: Uuid::new_v4().simple().to_string(),
@@ -173,7 +155,13 @@ impl Orchestrator {
.unwrap()
.as_secs() as i64,
detections: Some(ChatDetections {
input: detections,
input: detections
.into_iter()
.map(|detection_result| InputDetectionResult {
message_index: detection_result.index,
results: detection_result.results,
})
.collect(),
output: vec![],
}),
warnings: vec![OrchestratorWarning::new(
@@ -200,102 +188,18 @@ impl Orchestrator {
error,
})?;

if let ChatCompletionsResponse::Unary(ref chat_completion) = chat_completions {
let choices = Vec::<ChatMessageInternal>::from(chat_completion);

let output_detections = match detectors.output {
Some(detectors) if !detectors.is_empty() => {
let tasks = choices.into_iter().map(|choice| {
tokio::spawn({
let ctx = ctx.clone();
let detectors = detectors.clone();
let headers = task.headers.clone();
async move {
let result = message_detection(
&ctx,
&detectors,
vec![choice],
&headers,
)
.await;

if let Ok(Some(detection_results)) = result {
return detection_results;
}

vec![]
}
})
});

let detections = try_join_all(tasks).await;

match detections {
Ok(d) => Some(
d.iter()
.flatten()
.cloned()
.collect::<Vec<DetectionResult>>(),
),
Err(_) => None,
}
}
_ => None,
};

debug!(?output_detections);

match output_detections {
Some(mut output_detections) if !output_detections.is_empty() => {
output_detections.sort_by_key(|value| value.index);

let detections = output_detections
.into_iter()
.map(|mut detection| {
let last_idx = detection.results.len();
// sort detection by starting span, if span is not present then move to the end of the message
detection.results.sort_by_key(|r| match r {
GuardrailDetection::ContentAnalysisResponse(value) => {
value.start
}
_ => last_idx,
});
detection
})
.collect::<Vec<_>>();

return Ok(ChatCompletionsResponse::Unary(Box::new(ChatCompletion {
id: Uuid::new_v4().simple().to_string(),
object: chat_completion.object.clone(),
created: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
model: model_id.to_string(),
choices: chat_completion.choices.clone(),
usage: chat_completion.usage.clone(),
system_fingerprint: chat_completion.system_fingerprint.clone(),
service_tier: chat_completion.service_tier.clone(),
detections: Some(ChatDetections {
input: vec![],
output: detections
.into_iter()
.map(|detection_result| OutputDetectionResult {
choice_index: detection_result.index,
results: detection_result.results,
})
.collect(),
}),
warnings: vec![OrchestratorWarning::new(
DetectionWarningReason::UnsuitableOutput,
UNSUITABLE_OUTPUT_MESSAGE,
)],
})));
}
_ => {}
}
match handle_output_detections(
&chat_completions,
detectors.output,
ctx,
&task.headers,
model_id,
)
.await
{
Some(chat_completion_detections) => Ok(chat_completion_detections),
None => Ok(chat_completions),
}
Ok(chat_completions)
}
});

@@ -511,6 +415,109 @@ async fn detector_chunk_task(
Ok(chunks)
}

fn sort_detections(mut detections: Vec<DetectionResult>) -> Vec<DetectionResult> {
// Sort input detections by message_index
detections.sort_by_key(|value| value.index);

detections
.into_iter()
.map(|mut detection| {
let last_idx = detection.results.len();
// sort detection by starting span, if span is not present then move to the end of the message
detection.results.sort_by_key(|r| match r {
GuardrailDetection::ContentAnalysisResponse(value) => value.start,
_ => last_idx,
});
detection
})
.collect::<Vec<_>>()
}

async fn handle_output_detections(
chat_completions: &ChatCompletionsResponse,
detector_output: Option<HashMap<String, DetectorParams>>,
ctx: Arc<Context>,
headers: &HeaderMap,
model_id: String,
) -> Option<ChatCompletionsResponse> {
if let ChatCompletionsResponse::Unary(ref chat_completion) = chat_completions {
let choices = Vec::<ChatMessageInternal>::from(chat_completion);

let output_detections = match detector_output {
Some(detectors) if !detectors.is_empty() => {
let tasks = choices.into_iter().map(|choice| {
tokio::spawn({
let ctx = ctx.clone();
let detectors = detectors.clone();
let headers = headers.clone();
async move {
let result =
message_detection(&ctx, &detectors, vec![choice], &headers).await;

if let Ok(Some(detection_results)) = result {
return detection_results;
}

vec![]
}
})
});

let detections = try_join_all(tasks).await;

match detections {
Ok(d) => Some(
d.iter()
.flatten()
.cloned()
.collect::<Vec<DetectionResult>>(),
),
Err(_) => None,
}
}
_ => None,
};

debug!(?output_detections);

match output_detections {
Some(output_detections) if !output_detections.is_empty() => {
let detections = sort_detections(output_detections);

return Some(ChatCompletionsResponse::Unary(Box::new(ChatCompletion {
id: Uuid::new_v4().simple().to_string(),
object: chat_completion.object.clone(),
created: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
model: model_id.to_string(),
choices: chat_completion.choices.clone(),
usage: chat_completion.usage.clone(),
system_fingerprint: chat_completion.system_fingerprint.clone(),
service_tier: chat_completion.service_tier.clone(),
detections: Some(ChatDetections {
input: vec![],
output: detections
.into_iter()
.map(|detection_result| OutputDetectionResult {
choice_index: detection_result.index,
results: detection_result.results,
})
.collect(),
}),
warnings: vec![OrchestratorWarning::new(
DetectionWarningReason::UnsuitableOutput,
UNSUITABLE_OUTPUT_MESSAGE,
)],
})));
}
_ => {}
}
}
None
}

#[cfg(test)]
mod tests {
use std::any::{Any, TypeId};
