pre-refactor: Add orchestrator.types module with common types #326

Merged 14 commits on Mar 6, 2025
Changes from 10 commits
49 changes: 49 additions & 0 deletions docs/architecture/adrs/010-detection-batcher.md
@@ -0,0 +1,49 @@
# ADR 010: DetectionBatcher & DetectionBatchStream

This ADR documents the addition of two new abstractions to handle batching (formerly known as "aggregation") of streaming detection results.

1. `DetectionBatcher`
A trait to implement pluggable batching logic for a `DetectionBatchStream`. It includes an associated `Batch` type, enabling implementations to return different types of batches.

2. `DetectionBatchStream`
A stream adapter that wraps multiple detection streams and produces a stream of batches using a `DetectionBatcher`.
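
As a rough illustration (not the exact trait introduced by this PR), a `DetectionBatcher` along these lines might look like the sketch below; the method names `push` and `pop_batch` and the exact signatures are assumptions for illustration only.

```rust
use crate::orchestrator::types::{Chunk, Detections, DetectorId, InputId};

/// Hypothetical sketch of a pluggable batcher; not the actual trait from this PR.
pub trait DetectionBatcher: Send + 'static {
    /// The batch type produced by this batcher.
    type Batch: Send + 'static;

    /// Pushes new detections for a chunk into the batcher's internal state.
    fn push(
        &mut self,
        input_id: InputId,
        detector_id: DetectorId,
        chunk: Chunk,
        detections: Detections,
    );

    /// Removes and returns the next batch, if one is ready.
    fn pop_batch(&mut self) -> Option<Self::Batch>;
}
```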

## Motivation

To support initial streaming requirements outlined in ADR 002, we implemented the `Aggregator` and `Tracker` components.

1. `Aggregator` handles batching detections and building results. Internally, it is implemented as 3 actors:

- `AggregationActor`
Aggregates detections and sends them to the `ResultActor`

- `GenerationActor`
Consumes generations from the generation stream, accumulates them, and provides them on-demand to the `ResultActor` to build responses

- `ResultActor`
Builds results from detection batches and generations and sends them to the result channel

2. `Tracker` wraps a `BTreeMap` and contains batching logic. It is used internally by the `AggregationActor`.
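
For readers less familiar with the actor pattern referenced above, a minimal sketch of its general shape (a spawned task that owns its state and is driven by messages over an `mpsc` channel) is shown below; the names and message types are illustrative and do not reflect the actual actor implementations.

```rust
use tokio::sync::mpsc;

/// Illustrative message type; the real actors use their own domain messages.
enum Msg {
    Add(u64),
}

/// Illustrative actor: owns its state and mutates it only in response to messages.
struct CounterActor {
    rx: mpsc::Receiver<Msg>,
    total: u64,
}

impl CounterActor {
    async fn run(mut self) {
        // The actor loop: receive a message, update owned state, repeat.
        while let Some(msg) = self.rx.recv().await {
            match msg {
                Msg::Add(n) => self.total += n,
            }
        }
    }
}

async fn example() {
    // Other components hold the sender and communicate with the actor through it.
    let (tx, rx) = mpsc::channel(32);
    tokio::spawn(CounterActor { rx, total: 0 }.run());
    let _ = tx.send(Msg::Add(1)).await;
}
```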

The primary issue with these components is that they were designed specifically for the *Streaming Classification With Generation* task and lack the flexibility to be extended to additional streaming use cases that require batching detections, e.g.
- A use case may require different batching logic
- A use case may need to use different containers to implement its batching logic
- A use case may need to return a different batch type
- A use case may need to build a different result type

Additionally, actors are not used in other areas of this codebase, and they introduce concepts that may be unfamiliar to new contributors, further increasing the learning curve.

## Decisions

1. The `DetectionBatcher` trait replaces the `Tracker`, enabling flexible and pluggable batching logic tailored to different use cases.

2. The `DetectionBatchStream`, a stream adapter, replaces the `Aggregator`, enabling more flexibility, as it is generic over `DetectionBatcher`.

3. The task of building results is decoupled and delegated to the task handler as a post-batching task. Instead of using an actor to accumulate and own generation/chat completion message state, a task handler can use a shared vec, e.g. `Arc<RwLock<Vec<T>>>`, or another approach suited to its use case requirements, as sketched below.
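
As a hedged illustration of decision 3 (not code from this PR), a task handler might accumulate stream items in shared state roughly as follows; the function and variable names are hypothetical.

```rust
use std::sync::Arc;
use tokio::sync::RwLock;

async fn example() {
    // Shared vec that both the consuming task and the result-building step can access.
    let chunks: Arc<RwLock<Vec<String>>> = Arc::new(RwLock::new(Vec::new()));

    // A task consuming a stream pushes each item into the shared vec...
    let writer = Arc::clone(&chunks);
    let consumer = tokio::spawn(async move {
        for item in ["hello", " world"] {
            writer.write().await.push(item.to_string());
        }
    });
    consumer.await.unwrap();

    // ...and the post-batching step reads the accumulated items to build a result.
    let result: String = chunks.read().await.concat();
    assert_eq!(result, "hello world");
}
```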

## Notes
1. The existing *Streaming Classification With Generation* batching logic has been re-implemented in `MaxProcessedIndexBatcher`, a `DetectionBatcher` implementation.
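
A toy `DetectionBatcher` implementation, written against the hypothetical trait sketch above, might batch detections per chunk once all expected detectors have reported; this is illustrative only and does not reproduce the `MaxProcessedIndexBatcher` logic.

```rust
use std::collections::BTreeMap;

use crate::orchestrator::types::{Chunk, Detections, DetectorId, InputId};

/// Toy batcher (not MaxProcessedIndexBatcher): buffers detections per chunk and
/// releases the earliest chunk once all expected detectors have reported for it.
pub struct PerChunkBatcher {
    n_detectors: usize,
    state: BTreeMap<Chunk, Vec<(DetectorId, Detections)>>,
}

impl DetectionBatcher for PerChunkBatcher {
    type Batch = (Chunk, Vec<(DetectorId, Detections)>);

    fn push(
        &mut self,
        _input_id: InputId,
        detector_id: DetectorId,
        chunk: Chunk,
        detections: Detections,
    ) {
        self.state
            .entry(chunk)
            .or_default()
            .push((detector_id, detections));
    }

    fn pop_batch(&mut self) -> Option<Self::Batch> {
        // Chunks are ordered, so the first entry is the earliest chunk.
        let ready = self
            .state
            .first_key_value()
            .is_some_and(|(_, entries)| entries.len() == self.n_detectors);
        ready.then(|| self.state.pop_first()).flatten()
    }
}
```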

## Status

Pending
1 change: 1 addition & 0 deletions src/orchestrator.rs
@@ -23,6 +23,7 @@ pub mod common;
pub mod detector_processing;
pub mod streaming;
pub mod streaming_content_detection;
pub mod types;
pub mod unary;

use std::{collections::HashMap, pin::Pin, sync::Arc};
16 changes: 16 additions & 0 deletions src/orchestrator/common.rs
@@ -1,2 +1,18 @@
/*
Copyright FMS Guardrails Orchestrator Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
pub mod utils;
pub use utils::*;
16 changes: 16 additions & 0 deletions src/orchestrator/common/utils.rs
@@ -1,3 +1,19 @@
/*
Copyright FMS Guardrails Orchestrator Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
use std::{collections::HashMap, sync::Arc};

use crate::{
47 changes: 47 additions & 0 deletions src/orchestrator/types.rs
@@ -0,0 +1,47 @@
/*
Copyright FMS Guardrails Orchestrator Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
use std::pin::Pin;

use futures::Stream;
use tokio::sync::mpsc;

pub mod chat_message;
pub use chat_message::*;
pub mod chunk;
pub mod detection;
pub use chunk::*;
pub use detection::*;
pub mod detection_batcher;
pub use detection_batcher::*;
pub mod detection_batch_stream;
pub use detection_batch_stream::*;

use super::Error;
use crate::{clients::openai::ChatCompletionChunk, models::ClassifiedGeneratedTextStreamResult};

pub type ChunkerId = String;
pub type DetectorId = String;
pub type InputId = u32;

pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
pub type ChunkStream = BoxStream<Result<Chunk, Error>>;
pub type InputStream = BoxStream<Result<(usize, String), Error>>;
pub type InputSender = mpsc::Sender<Result<(usize, String), Error>>;
pub type InputReceiver = mpsc::Receiver<Result<(usize, String), Error>>;
pub type DetectionStream = BoxStream<Result<(InputId, DetectorId, Chunk, Detections), Error>>;
pub type GenerationStream = BoxStream<(usize, Result<ClassifiedGeneratedTextStreamResult, Error>)>;
pub type ChatCompletionStream = BoxStream<(usize, Result<Option<ChatCompletionChunk>, Error>)>;
72 changes: 72 additions & 0 deletions src/orchestrator/types/chat_message.rs
@@ -0,0 +1,72 @@
/*
Copyright FMS Guardrails Orchestrator Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

*/
use crate::clients::openai;

/// A chat message.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ChatMessage<'a> {
/// Message index
/// Corresponds to choice index for chat completions.
pub index: u32,
/// The role of the author of this message.
pub role: Option<&'a openai::Role>,
/// The text contents of the message.
pub text: Option<&'a str>,
}

/// An iterator over chat messages.
pub trait ChatMessageIterator {
/// Returns an iterator of [`ChatMessage`]s.
fn messages(&self) -> impl Iterator<Item = ChatMessage>;
}

impl ChatMessageIterator for openai::ChatCompletionsRequest {
fn messages(&self) -> impl Iterator<Item = ChatMessage> {
self.messages.iter().enumerate().map(|(index, message)| {
let text = if let Some(openai::Content::Text(text)) = &message.content {
Some(text.as_str())
} else {
None
};
ChatMessage {
index: index as u32,
role: Some(&message.role),
text,
}
})
}
}

impl ChatMessageIterator for openai::ChatCompletion {
fn messages(&self) -> impl Iterator<Item = ChatMessage> {
self.choices.iter().map(|choice| ChatMessage {
index: choice.index,
role: Some(&choice.message.role),
text: choice.message.content.as_deref(),
})
}
}

impl ChatMessageIterator for openai::ChatCompletionChunk {
fn messages(&self) -> impl Iterator<Item = ChatMessage> {
self.choices.iter().map(|choice| ChatMessage {
index: choice.index,
role: choice.delta.role.as_ref(),
text: choice.delta.content.as_deref(),
})
}
}
165 changes: 165 additions & 0 deletions src/orchestrator/types/chunk.rs
@@ -0,0 +1,165 @@
/*
Copyright FMS Guardrails Orchestrator Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
use crate::pb::caikit_data_model::nlp as pb;

/// A chunk.
#[derive(Default, Debug, Clone)]
pub struct Chunk {
/// Index of message where chunk begins
pub input_start_index: usize,
/// Index of message where chunk ends
pub input_end_index: usize,
/// Index of char where chunk begins
pub start: usize,
/// Index of char where chunk ends
pub end: usize,
/// Text
pub text: String,
}

impl PartialOrd for Chunk {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for Chunk {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
(
self.input_start_index,
self.input_end_index,
self.start,
self.end,
)
.cmp(&(
other.input_start_index,
other.input_end_index,
other.start,
other.end,
))
}
}

impl PartialEq for Chunk {
fn eq(&self, other: &Self) -> bool {
(
self.input_start_index,
self.input_end_index,
self.start,
self.end,
) == (
other.input_start_index,
other.input_end_index,
other.start,
other.end,
)
}
}

impl Eq for Chunk {}

impl std::hash::Hash for Chunk {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.input_start_index.hash(state);
self.input_end_index.hash(state);
self.start.hash(state);
self.end.hash(state);
}
}

/// An array of chunks.
#[derive(Default, Debug, Clone)]
pub struct Chunks(Vec<Chunk>);

impl Chunks {
pub fn new() -> Self {
Self::default()
}
}

impl std::ops::Deref for Chunks {
type Target = Vec<Chunk>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl std::ops::DerefMut for Chunks {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl IntoIterator for Chunks {
type Item = Chunk;
type IntoIter = <Vec<Chunk> as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}

impl FromIterator<Chunk> for Chunks {
fn from_iter<T: IntoIterator<Item = Chunk>>(iter: T) -> Self {
let mut chunks = Chunks::new();
for value in iter {
chunks.push(value);
}
chunks
}
}

impl From<Vec<Chunk>> for Chunks {
fn from(value: Vec<Chunk>) -> Self {
Self(value)
}
}

// Conversions

impl From<pb::ChunkerTokenizationStreamResult> for Chunk {
fn from(value: pb::ChunkerTokenizationStreamResult) -> Self {
let text = value
.results
.into_iter()
.map(|token| token.text)
.collect::<String>();
Chunk {
input_start_index: value.input_start_index as usize,
input_end_index: value.input_end_index as usize,
start: value.start_index as usize,
end: value.processed_index as usize,
text,
}
}
}

impl From<pb::TokenizationResults> for Chunks {
fn from(value: pb::TokenizationResults) -> Self {
value
.results
.into_iter()
.map(|token| Chunk {
start: token.start as usize,
end: token.end as usize,
text: token.text,
..Default::default()
})
.collect()
}
}