From d886984031f41873dfaa98aa27541b076053b4d4 Mon Sep 17 00:00:00 2001
From: declark1 <44146800+declark1@users.noreply.github.com>
Date: Tue, 4 Mar 2025 11:21:11 -0800
Subject: [PATCH] Add test_detection_batch_stream

Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com>
---
NOTE(review): the generic argument of the two `mpsc::channel` calls in
test_detection_batch_stream was missing in the copy under review
(`mpsc::channel::>(4)`, which does not compile). It is restored below as
`Result<(InputId, DetectorId, Chunk, Detections), Error>` to match the
(input_id, detector_id, chunk, detections) tuples the test sends. This
assumes `InputId` and `DetectorId` are in scope via `use super::*` --
TODO confirm against the tree before applying.

 .../detection_batcher/max_processed_index.rs  | 142 +++++++++++++++++-
 1 file changed, 137 insertions(+), 5 deletions(-)

diff --git a/src/orchestrator/types/detection_batcher/max_processed_index.rs b/src/orchestrator/types/detection_batcher/max_processed_index.rs
index 823b39dc..813e7e6c 100644
--- a/src/orchestrator/types/detection_batcher/max_processed_index.rs
+++ b/src/orchestrator/types/detection_batcher/max_processed_index.rs
@@ -72,11 +72,20 @@ impl DetectionBatcher for MaxProcessedIndexBatcher {
 
 #[cfg(test)]
 mod test {
+    use std::task::Poll;
+
+    use futures::StreamExt;
+    use tokio::sync::mpsc;
+    use tokio_stream::wrappers::ReceiverStream;
+
     use super::*;
-    use crate::orchestrator::types::Detection;
+    use crate::orchestrator::{
+        types::{Detection, DetectionBatchStream},
+        Error,
+    };
 
     #[test]
-    fn test_single_chunk_multiple_detectors() {
+    fn test_batcher_single_chunk_multiple_detectors() {
         let input_id = 0;
         let chunk = Chunk {
             input_start_index: 0,
@@ -145,7 +154,7 @@ mod test {
     }
 
     #[test]
-    fn test_out_of_order_chunks() {
+    fn test_batcher_out_of_order_chunks() {
         let input_id = 0;
         let chunks = [
             Chunk {
@@ -194,7 +203,7 @@ mod test {
             Detections::default(), // no detections
         );
 
-        // We have detections for chunk-2, but not chunk-1
+        // We have all detections for chunk-2, but not chunk-1
         // pop_batch() should return None
         assert!(batcher.pop_batch().is_none());
 
@@ -214,7 +223,7 @@ mod test {
             .into(),
         );
 
-        // We have detections for chunk-1 and chunk-2
+        // We have all detections for chunk-1 and chunk-2
         // pop_batch() should return chunk-1 with 1 pii detection
         let batch = batcher.pop_batch();
         assert!(batch
@@ -228,4 +237,127 @@ mod test {
         // batcher state should be empty as all batches have been returned
         assert!(batcher.state.is_empty());
     }
+
+    #[tokio::test]
+    async fn test_detection_batch_stream() -> Result<(), Error> {
+        let input_id = 0;
+        let chunks = [
+            Chunk {
+                input_start_index: 0,
+                input_end_index: 10,
+                start: 0,
+                end: 56,
+                text: " a powerful tool for the development \
+                    of complex systems."
+                    .into(),
+            },
+            Chunk {
+                input_start_index: 11,
+                input_end_index: 26,
+                start: 56,
+                end: 135,
+                text: " It has been used in many fields, such as \
+                    computer vision and image processing."
+                    .into(),
+            },
+        ];
+
+        // Create detection channels and streams
+        let (pii_detections_tx, pii_detections_rx) =
+            mpsc::channel::<Result<(InputId, DetectorId, Chunk, Detections), Error>>(4);
+        let pii_detections_stream = ReceiverStream::new(pii_detections_rx).boxed();
+        let (hap_detections_tx, hap_detections_rx) =
+            mpsc::channel::<Result<(InputId, DetectorId, Chunk, Detections), Error>>(4);
+        let hap_detections_stream = ReceiverStream::new(hap_detections_rx).boxed();
+
+        // Create a batcher that will process batches for 2 detectors
+        let n = 2;
+        let batcher = MaxProcessedIndexBatcher::new(n);
+
+        // Create detection batch stream
+        let streams = vec![pii_detections_stream, hap_detections_stream];
+        let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams);
+
+        // Send chunk-2 detections for pii detector
+        let _ = pii_detections_tx
+            .send(Ok((
+                input_id,
+                "pii".into(),
+                chunks[1].clone(),
+                Detections::default(), // no detections
+            )))
+            .await;
+
+        // Send chunk-1 detections for hap detector
+        let _ = hap_detections_tx
+            .send(Ok((
+                input_id,
+                "hap".into(),
+                chunks[0].clone(),
+                Detections::default(), // no detections
+            )))
+            .await;
+
+        // Send chunk-2 detections for hap detector
+        let _ = hap_detections_tx
+            .send(Ok((
+                input_id,
+                "hap".into(),
+                chunks[1].clone(),
+                Detections::default(), // no detections
+            )))
+            .await;
+
+        // We have all detections for chunk-2, but not chunk-1
+        // detection_batch_stream.next() future should not be ready
+        assert!(matches!(
+            futures::poll!(detection_batch_stream.next()),
+            Poll::Pending
+        ));
+
+        // Send chunk-1 detections for pii detector
+        let _ = pii_detections_tx
+            .send(Ok((
+                input_id,
+                "pii".into(),
+                chunks[0].clone(),
+                vec![Detection {
+                    start: Some(10),
+                    end: Some(20),
+                    detector_id: Some("pii".into()),
+                    detection_type: "pii".into(),
+                    score: 0.4,
+                    ..Default::default()
+                }]
+                .into(),
+            )))
+            .await;
+
+        // We have all detections for chunk-1 and chunk-2
+        // detection_batch_stream.next() should be ready and return chunk-1 with 1 pii detection
+        let batch = detection_batch_stream.next().await;
+        assert!(batch.is_some_and(|result| result
+            .is_ok_and(|(chunk, detections)| chunk == chunks[0] && detections.len() == 1)));
+
+        // detection_batch_stream.next() should be ready and return chunk-2 with no detections
+        let batch = detection_batch_stream.next().await;
+        assert!(batch.is_some_and(|result| result
+            .is_ok_and(|(chunk, detections)| chunk == chunks[1] && detections.is_empty())));
+
+        // detection_batch_stream.next() future should not be ready
+        // as detection senders have not been closed
+        assert!(matches!(
+            futures::poll!(detection_batch_stream.next()),
+            Poll::Pending
+        ));
+
+        // Drop detection senders
+        drop(pii_detections_tx);
+        drop(hap_detections_tx);
+
+        // detection_batch_stream.next() should return None
+        assert!(detection_batch_stream.next().await.is_none());
+
+        Ok(())
+    }
 }