Skip to content

Commit

Permalink
Update DetectionBatchStream to wrap batch_rx directly, update comments
Browse files Browse the repository at this point in the history
Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com>
  • Loading branch information
declark1 committed Mar 4, 2025
1 parent 4e20565 commit c24215e
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions src/orchestrator/types/detection_batch_stream.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use futures::{stream, Stream, StreamExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use super::{BoxStream, Chunk, DetectionBatcher, DetectionStream, Detections, DetectorId, InputId};
use super::{Chunk, DetectionBatcher, DetectionStream, Detections, DetectorId, InputId};
use crate::orchestrator::Error;

/// Wraps detection streams and produces a stream
/// of batches using a [`DetectionBatcher`].
pub struct DetectionBatchStream<B: DetectionBatcher> {
inner: BoxStream<Result<B::Batch, Error>>,
batch_rx: mpsc::Receiver<Result<B::Batch, Error>>,
}

impl<B> DetectionBatchStream<B>
Expand All @@ -23,33 +22,30 @@ where
mpsc::channel::<Result<(InputId, DetectorId, Chunk, Detections), Error>>(32);

// Spawn batcher task
// This task receives new detections, pushes them to a batcher, and sends
// batches to the batch (output) stream as they become ready.
// This task receives new detections, pushes them to the batcher,
// and sends batches to the batch (output) channel as they become ready.
tokio::spawn(async move {
loop {
tokio::select! {
result = batcher_rx.recv() => {
match result {
Some(Ok((input_id, detector_id, chunk, detections))) => {
// Received new detections
// Push detections to batcher
batcher.push(input_id, detector_id, chunk, detections);

// Check if we have any batches ready
// Check if the next batch is ready
if let Some(batch) = batcher.pop_batch() {
// Send batch to batch channel
let _ = batch_tx.send(Ok(batch)).await;
}
},
Some(Err(error)) => {
// Received error
// Send error to batch channel
let _ = batch_tx.send(Err(error)).await;
break;
},
None => {
// Batcher stream closed
// Terminate task
// Batcher channel closed
break;
},
}
Expand All @@ -58,22 +54,19 @@ where
}
});

// Create a stream set (a single stream) from multiple detection streams
// Create a stream set (single stream) from multiple detection streams
let mut stream_set = stream::select_all(streams);

// Spawn detection consumer task
// This task consumes new detections and sends them to the batcher task.
tokio::spawn(async move {
while let Some(result) = stream_set.next().await {
// Received new detections
// Send to batcher task
// Send new detections to batcher task
let _ = batcher_tx.send(result).await;
}
});

Self {
inner: ReceiverStream::new(batch_rx).boxed(),
}
Self { batch_rx }
}
}

Expand All @@ -84,6 +77,6 @@ impl<T: DetectionBatcher> Stream for DetectionBatchStream<T> {
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.inner.as_mut().poll_next(cx)
self.batch_rx.poll_recv(cx)
}
}

0 comments on commit c24215e

Please sign in to comment.