diff --git a/crates/polars-stream/src/nodes/io_sinks/csv.rs b/crates/polars-stream/src/nodes/io_sinks/csv.rs
index e52b0244b317..3176f4003aaf 100644
--- a/crates/polars-stream/src/nodes/io_sinks/csv.rs
+++ b/crates/polars-stream/src/nodes/io_sinks/csv.rs
@@ -10,13 +10,12 @@ use polars_io::SerWriter;
 use polars_plan::dsl::SinkOptions;
 use polars_utils::priority::Priority;
 
-use super::{SinkInputPort, SinkNode, SinkRecvPort};
+use super::{SinkInputPort, SinkNode};
 use crate::async_executor::spawn;
-use crate::async_primitives::linearizer::Linearizer;
-use crate::nodes::io_sinks::{tokio_sync_on_close, DEFAULT_SINK_LINEARIZER_BUFFER_SIZE};
-use crate::nodes::{JoinHandle, MorselSeq, TaskPriority};
+use crate::async_primitives::connector::Receiver;
+use crate::nodes::io_sinks::{parallelize_receive_task, tokio_sync_on_close};
+use crate::nodes::{JoinHandle, PhaseOutcome, TaskPriority};
 
-type Linearized = Priority<Reverse<MorselSeq>, Vec<u8>>;
 pub struct CsvSinkNode {
     path: PathBuf,
     schema: SchemaRef,
@@ -47,39 +46,18 @@ impl SinkNode for CsvSinkNode {
     fn is_sink_input_parallel(&self) -> bool {
         true
     }
-    fn do_maintain_order(&self) -> bool {
-        self.sink_options.maintain_order
-    }
 
     fn spawn_sink(
         &mut self,
         num_pipelines: usize,
-        recv_ports_recv: SinkRecvPort,
+        recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
     ) {
-        let rxs = recv_ports_recv.parallel(join_handles);
-        self.spawn_sink_once(
-            num_pipelines,
-            SinkInputPort::Parallel(rxs),
-            _state,
+        let (pass_rxs, mut io_rx) = parallelize_receive_task(
             join_handles,
-        );
-    }
-
-    fn spawn_sink_once(
-        &mut self,
-        num_pipelines: usize,
-        recv_port: SinkInputPort,
-        _state: &ExecutionState,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    ) {
-        // .. -> Encode task
-        let rxs = recv_port.parallel();
-        // Encode tasks -> IO task
-        let (mut lin_rx, lin_txs) = Linearizer::<Linearized>::new_with_maintain_order(
+            recv_port_rx,
             num_pipelines,
-            DEFAULT_SINK_LINEARIZER_BUFFER_SIZE,
             self.sink_options.maintain_order,
         );
@@ -89,7 +67,7 @@ impl SinkNode for CsvSinkNode {
         // Encode task.
         //
         // Task encodes the columns into their corresponding CSV encoding.
-        join_handles.extend(rxs.into_iter().zip(lin_txs).map(|(mut rx, mut lin_tx)| {
+        join_handles.extend(pass_rxs.into_iter().map(|mut pass_rx| {
             let schema = self.schema.clone();
             let options = self.write_options.clone();
 
@@ -99,34 +77,36 @@
                 let mut allocation_size = DEFAULT_ALLOCATION_SIZE;
                 let options = options.clone();
 
-                while let Ok(morsel) = rx.recv().await {
-                    let (df, seq, _, consume_token) = morsel.into_inner();
-
-                    let mut buffer = Vec::with_capacity(allocation_size);
-                    let mut writer = CsvWriter::new(&mut buffer)
-                        .include_bom(false) // Handled once in the IO task.
-                        .include_header(false) // Handled once in the IO task.
-                        .with_separator(options.serialize_options.separator)
-                        .with_line_terminator(options.serialize_options.line_terminator.clone())
-                        .with_quote_char(options.serialize_options.quote_char)
-                        .with_datetime_format(options.serialize_options.datetime_format.clone())
-                        .with_date_format(options.serialize_options.date_format.clone())
-                        .with_time_format(options.serialize_options.time_format.clone())
-                        .with_float_scientific(options.serialize_options.float_scientific)
-                        .with_float_precision(options.serialize_options.float_precision)
-                        .with_null_value(options.serialize_options.null.clone())
-                        .with_quote_style(options.serialize_options.quote_style)
-                        .n_threads(1) // Disable rayon parallelism
-                        .batched(&schema)?;
-
-                    writer.write_batch(&df)?;
-
-                    allocation_size = allocation_size.max(buffer.len());
-                    if lin_tx.insert(Priority(Reverse(seq), buffer)).await.is_err() {
-                        return Ok(());
+                while let Ok((mut rx, mut lin_tx)) = pass_rx.recv().await {
+                    while let Ok(morsel) = rx.recv().await {
+                        let (df, seq, _, consume_token) = morsel.into_inner();
+
+                        let mut buffer = Vec::with_capacity(allocation_size);
+                        let mut writer = CsvWriter::new(&mut buffer)
+                            .include_bom(false) // Handled once in the IO task.
+                            .include_header(false) // Handled once in the IO task.
+                            .with_separator(options.serialize_options.separator)
+                            .with_line_terminator(options.serialize_options.line_terminator.clone())
+                            .with_quote_char(options.serialize_options.quote_char)
+                            .with_datetime_format(options.serialize_options.datetime_format.clone())
+                            .with_date_format(options.serialize_options.date_format.clone())
+                            .with_time_format(options.serialize_options.time_format.clone())
+                            .with_float_scientific(options.serialize_options.float_scientific)
+                            .with_float_precision(options.serialize_options.float_precision)
+                            .with_null_value(options.serialize_options.null.clone())
+                            .with_quote_style(options.serialize_options.quote_style)
+                            .n_threads(1) // Disable rayon parallelism
+                            .batched(&schema)?;
+
+                        writer.write_batch(&df)?;
+
+                        allocation_size = allocation_size.max(buffer.len());
+                        if lin_tx.insert(Priority(Reverse(seq), buffer)).await.is_err() {
+                            return Ok(());
+                        }
+                        drop(consume_token); // Keep the consume_token until here to increase the
+                                             // backpressure.
                     }
-                    drop(consume_token); // Keep the consume_token until here to increase the
-                                         // backpressure.
                 }
 
                 PolarsResult::Ok(())
@@ -165,8 +145,10 @@ impl SinkNode for CsvSinkNode {
                 file = tokio::fs::File::from_std(std_file);
             }
 
-            while let Some(Priority(_, buffer)) = lin_rx.get().await {
-                file.write_all(&buffer).await?;
+            while let Ok(mut lin_rx) = io_rx.recv().await {
+                while let Some(Priority(_, buffer)) = lin_rx.get().await {
+                    file.write_all(&buffer).await?;
+                }
             }
 
             tokio_sync_on_close(sink_options.sync_on_close, &mut file).await?;
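Note on the csv.rs change above: the encode tasks may now finish buffers out of order within a phase, and the IO task restores morsel order by draining the per-phase linearizer in sequence order, which is what keeps `maintain_order=true` byte-identical. The ordering trick is plain min-heap-on-sequence linearization; below is a minimal std-only sketch of that idea, assuming nothing from polars-stream (the heap stands in for the `Linearizer`, and all names are illustrative):

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::mpsc;
use std::thread;

fn main() {
    // "Encode tasks": two workers encode sequence-tagged morsels, possibly out of order.
    let (tx, rx) = mpsc::channel::<(u64, Vec<u8>)>();
    let inputs: Vec<(u64, String)> = (0..8).map(|seq| (seq, format!("row-{seq}"))).collect();
    let handles: Vec<_> = inputs
        .chunks(4)
        .map(|chunk| {
            let (tx, chunk) = (tx.clone(), chunk.to_vec());
            thread::spawn(move || {
                for (seq, payload) in chunk {
                    // Stand-in for encoding a batch into a fresh buffer.
                    tx.send((seq, format!("{payload}\n").into_bytes())).unwrap();
                }
            })
        })
        .collect();
    drop(tx);

    // "IO task": release buffers strictly in sequence order via a min-heap,
    // the same role the Priority(Reverse(seq), buffer) keys play above.
    let (mut heap, mut next_seq, mut out) = (BinaryHeap::new(), 0u64, Vec::new());
    for item in rx {
        heap.push(Reverse(item));
        while heap.peek().is_some_and(|Reverse((seq, _))| *seq == next_seq) {
            out.extend_from_slice(&heap.pop().unwrap().0 .1);
            next_seq += 1;
        }
    }
    handles.into_iter().for_each(|h| h.join().unwrap());
    assert_eq!(out, b"row-0\nrow-1\nrow-2\nrow-3\nrow-4\nrow-5\nrow-6\nrow-7\n");
}
```

With `maintain_order=false` the heap step can be skipped entirely and buffers written as they arrive, which is the cheaper path the linearizer's `new_with_maintain_order(..., false)` mode enables.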
diff --git a/crates/polars-stream/src/nodes/io_sinks/ipc.rs b/crates/polars-stream/src/nodes/io_sinks/ipc.rs
index c660f07a0868..bce0d4f5e783 100644
--- a/crates/polars-stream/src/nodes/io_sinks/ipc.rs
+++ b/crates/polars-stream/src/nodes/io_sinks/ipc.rs
@@ -18,16 +18,16 @@ use polars_plan::dsl::SinkOptions;
 use polars_utils::priority::Priority;
 
 use super::{
-    buffer_and_distribute_columns_task, SinkInputPort, SinkNode, SinkRecvPort,
+    buffer_and_distribute_columns_task, SinkInputPort, SinkNode,
     DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE, DEFAULT_SINK_LINEARIZER_BUFFER_SIZE,
 };
 use crate::async_executor::spawn;
-use crate::async_primitives::connector::connector;
+use crate::async_primitives::connector::{connector, Receiver};
 use crate::async_primitives::distributor_channel::distributor_channel;
 use crate::async_primitives::linearizer::Linearizer;
 use crate::morsel::get_ideal_morsel_size;
 use crate::nodes::io_sinks::sync_on_close;
-use crate::nodes::{JoinHandle, TaskPriority};
+use crate::nodes::{JoinHandle, PhaseOutcome, TaskPriority};
 
 pub struct IpcSinkNode {
     path: PathBuf,
@@ -77,28 +77,10 @@ impl SinkNode for IpcSinkNode {
     fn spawn_sink(
         &mut self,
         num_pipelines: usize,
-        recv_ports_recv: SinkRecvPort,
+        recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
     ) {
-        let rx = recv_ports_recv.serial(join_handles);
-        self.spawn_sink_once(
-            num_pipelines,
-            SinkInputPort::Serial(rx),
-            _state,
-            join_handles,
-        );
-    }
-
-    fn spawn_sink_once(
-        &mut self,
-        num_pipelines: usize,
-        recv_port: SinkInputPort,
-        _state: &ExecutionState,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    ) {
-        // .. -> Buffer task
-        let buffer_rx = recv_port.serial();
         // Buffer task -> Encode tasks
         let (dist_tx, dist_rxs) =
             distributor_channel(num_pipelines, DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE);
@@ -124,7 +106,7 @@ impl SinkNode for IpcSinkNode {
 
         // Buffer task.
         join_handles.push(buffer_and_distribute_columns_task(
-            buffer_rx,
+            recv_port_rx,
             dist_tx,
             chunk_size,
             self.input_schema.clone(),
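For the two column-parallel sinks (ipc here, parquet below), `buffer_and_distribute_columns_task` keeps its old contract even though it now receives the phase channel directly: vstack incoming frames into a buffer and peel off exact `chunk_size` row groups, tagging each with a running `seq` before fanning the columns out. A std-only sketch of that accumulate/split rule, with a `Vec<i32>` standing in for the DataFrame buffer (names are illustrative):

```rust
fn main() {
    let chunk_size = 4usize;
    let morsels: Vec<Vec<i32>> = vec![vec![1, 2], vec![3, 4, 5, 6, 7], vec![8, 9, 10]];

    let mut buffer: Vec<i32> = Vec::new();
    let mut seq = 0usize;
    let mut chunks: Vec<(usize, Vec<i32>)> = Vec::new();

    for morsel in morsels {
        // Equivalent of `buffer.vstack_mut(&df)?` on the DataFrame buffer.
        buffer.extend(morsel);

        // Peel off exact chunk_size chunks; the running seq tag is what lets the
        // downstream collect task put encoded chunks back into file order.
        while buffer.len() >= chunk_size {
            let rest = buffer.split_off(chunk_size);
            chunks.push((seq, std::mem::replace(&mut buffer, rest)));
            seq += 1;
        }
    }
    // The remainder is flushed only once the input is exhausted.
    if !buffer.is_empty() {
        chunks.push((seq, buffer));
    }

    let sizes: Vec<usize> = chunks.iter().map(|(_, c)| c.len()).collect();
    assert_eq!(sizes, [4, 4, 2]); // 10 rows -> two full chunks + remainder
}
```

Because the remainder is held back until the end, every emitted chunk except possibly the last has exactly `chunk_size` rows, which keeps the downstream row-group/record-batch sizes deterministic across phase boundaries.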
diff --git a/crates/polars-stream/src/nodes/io_sinks/json.rs b/crates/polars-stream/src/nodes/io_sinks/json.rs
index 85c5653dc075..107884c8ff00 100644
--- a/crates/polars-stream/src/nodes/io_sinks/json.rs
+++ b/crates/polars-stream/src/nodes/io_sinks/json.rs
@@ -7,13 +7,12 @@ use polars_io::json::BatchedWriter;
 use polars_plan::dsl::SinkOptions;
 use polars_utils::priority::Priority;
 
-use super::{SinkInputPort, SinkNode, SinkRecvPort};
+use super::{SinkInputPort, SinkNode};
 use crate::async_executor::spawn;
-use crate::async_primitives::linearizer::Linearizer;
-use crate::nodes::io_sinks::{tokio_sync_on_close, DEFAULT_SINK_LINEARIZER_BUFFER_SIZE};
-use crate::nodes::{JoinHandle, MorselSeq, TaskPriority};
+use crate::async_primitives::connector::Receiver;
+use crate::nodes::io_sinks::{parallelize_receive_task, tokio_sync_on_close};
+use crate::nodes::{JoinHandle, PhaseOutcome, TaskPriority};
 
-type Linearized = Priority<Reverse<MorselSeq>, Vec<u8>>;
 pub struct NDJsonSinkNode {
     path: PathBuf,
     sink_options: SinkOptions,
@@ -39,32 +38,14 @@ impl SinkNode for NDJsonSinkNode {
     fn spawn_sink(
         &mut self,
         num_pipelines: usize,
-        recv_ports_recv: SinkRecvPort,
+        recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
     ) {
-        let rxs = recv_ports_recv.parallel(join_handles);
-        self.spawn_sink_once(
-            num_pipelines,
-            SinkInputPort::Parallel(rxs),
-            _state,
+        let (pass_rxs, mut io_rx) = parallelize_receive_task(
             join_handles,
-        );
-    }
-
-    fn spawn_sink_once(
-        &mut self,
-        num_pipelines: usize,
-        recv_port: SinkInputPort,
-        _state: &ExecutionState,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    ) {
-        // .. -> Encode task
-        let rxs = recv_port.parallel();
-        // Encode tasks -> IO task
-        let (mut lin_rx, lin_txs) = Linearizer::<Linearized>::new_with_maintain_order(
+            recv_port_rx,
             num_pipelines,
-            DEFAULT_SINK_LINEARIZER_BUFFER_SIZE,
             self.sink_options.maintain_order,
         );
@@ -74,26 +55,28 @@ impl SinkNode for NDJsonSinkNode {
         // Encode task.
         //
         // Task encodes the columns into their corresponding JSON encoding.
-        join_handles.extend(rxs.into_iter().zip(lin_txs).map(|(mut rx, mut lin_tx)| {
+        join_handles.extend(pass_rxs.into_iter().map(|mut pass_rx| {
             spawn(TaskPriority::High, async move {
                 // Amortize the allocations over time. If we see that we need to do way larger
                 // allocations, we adjust to that over time.
                 let mut allocation_size = DEFAULT_ALLOCATION_SIZE;
 
-                while let Ok(morsel) = rx.recv().await {
-                    let (df, seq, _, consume_token) = morsel.into_inner();
-
-                    let mut buffer = Vec::with_capacity(allocation_size);
-                    let mut writer = BatchedWriter::new(&mut buffer);
-
-                    writer.write_batch(&df)?;
-
-                    allocation_size = allocation_size.max(buffer.len());
-                    if lin_tx.insert(Priority(Reverse(seq), buffer)).await.is_err() {
-                        return Ok(());
+                while let Ok((mut rx, mut lin_tx)) = pass_rx.recv().await {
+                    while let Ok(morsel) = rx.recv().await {
+                        let (df, seq, _, consume_token) = morsel.into_inner();
+
+                        let mut buffer = Vec::with_capacity(allocation_size);
+                        let mut writer = BatchedWriter::new(&mut buffer);
+
+                        writer.write_batch(&df)?;
+
+                        allocation_size = allocation_size.max(buffer.len());
+                        if lin_tx.insert(Priority(Reverse(seq), buffer)).await.is_err() {
+                            return Ok(());
+                        }
+                        drop(consume_token); // Keep the consume_token until here to increase the
+                                             // backpressure.
                     }
-                    drop(consume_token); // Keep the consume_token until here to increase the
-                                         // backpressure.
                 }
 
                 PolarsResult::Ok(())
@@ -117,8 +100,10 @@ impl SinkNode for NDJsonSinkNode {
                 .await
                 .map_err(|err| polars_utils::_limit_path_len_io_err(path.as_path(), err))?;
 
-            while let Some(Priority(_, buffer)) = lin_rx.get().await {
-                file.write_all(&buffer).await?;
+            while let Ok(mut lin_rx) = io_rx.recv().await {
+                while let Some(Priority(_, buffer)) = lin_rx.get().await {
+                    file.write_all(&buffer).await?;
+                }
             }
 
             tokio_sync_on_close(sink_options.sync_on_close, &mut file).await?;
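Both byte-oriented sinks (csv, ndjson) carry the allocation-amortization pattern through the refactor unchanged: each morsel gets a fresh buffer (it is handed off to the IO task, so it cannot be reused), but its capacity is pre-sized to the largest encoding observed so far, so steady-state batches encode without mid-write reallocation. The pattern in isolation, as a small self-contained sketch:

```rust
fn main() {
    const DEFAULT_ALLOCATION_SIZE: usize = 64;
    let batches = [
        "short",
        "a much longer ndjson-ish line that comfortably exceeds the default allocation size",
        "mid-sized line",
    ];

    let mut allocation_size = DEFAULT_ALLOCATION_SIZE;
    for batch in batches {
        // Fresh buffer per batch, pre-sized from history; it is shipped off, not reused.
        let mut buffer: Vec<u8> = Vec::with_capacity(allocation_size);
        buffer.extend_from_slice(batch.as_bytes()); // stand-in for writer.write_batch(&df)?
        buffer.push(b'\n');

        // Grow the estimate monotonically; a small batch never shrinks it.
        allocation_size = allocation_size.max(buffer.len());
    }

    // After the long second batch, later buffers start out large enough.
    assert!(allocation_size > DEFAULT_ALLOCATION_SIZE);
}
```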
diff --git a/crates/polars-stream/src/nodes/io_sinks/mod.rs b/crates/polars-stream/src/nodes/io_sinks/mod.rs
index ffa888268582..0ab1a93458c0 100644
--- a/crates/polars-stream/src/nodes/io_sinks/mod.rs
+++ b/crates/polars-stream/src/nodes/io_sinks/mod.rs
@@ -10,13 +10,13 @@ use polars_error::PolarsResult;
 use polars_expr::state::ExecutionState;
 use polars_plan::dsl::SyncOnCloseType;
 
-use super::io_sources::PhaseOutcomeToken;
 use super::{
     ComputeNode, JoinHandle, Morsel, PhaseOutcome, PortState, RecvPort, SendPort, TaskScope,
 };
 use crate::async_executor::{spawn, AbortOnDropHandle};
 use crate::async_primitives::connector::{connector, Receiver, Sender};
 use crate::async_primitives::distributor_channel;
+use crate::async_primitives::linearizer::{Inserter, Linearizer};
 use crate::async_primitives::wait_group::WaitGroup;
 use crate::nodes::TaskPriority;
 
@@ -39,11 +39,6 @@ pub enum SinkInputPort {
     Parallel(Vec<Receiver<Morsel>>),
 }
 
-pub struct SinkRecvPort {
-    num_pipelines: usize,
-    recv: Receiver<(PhaseOutcome, SinkInputPort)>,
-}
-
 impl SinkInputPort {
     pub fn serial(self) -> Receiver<Morsel> {
         match self {
@@ -60,86 +55,10 @@ impl SinkInputPort {
     }
 }
 
-impl SinkRecvPort {
-    pub fn parallel(
-        mut self,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    ) -> Vec<Receiver<Morsel>> {
-        let (txs, rxs) = (0..self.num_pipelines)
-            .map(|_| connector())
-            .collect::<(Vec<_>, Vec<_>)>();
-        let (mut pass_txs, pass_rxs) = (0..self.num_pipelines)
-            .map(|_| connector())
-            .collect::<(Vec<_>, Vec<_>)>();
-        let mut outcomes = Vec::<PhaseOutcomeToken>::with_capacity(self.num_pipelines);
-        let wg = WaitGroup::default();
-
-        join_handles.push(spawn(TaskPriority::High, async move {
-            while let Ok((outcome, port_rxs)) = self.recv.recv().await {
-                let port_rxs = port_rxs.parallel();
-                for (pass_tx, port_rx) in pass_txs.iter_mut().zip(port_rxs) {
-                    let (token, outcome) = PhaseOutcome::new_shared_wait(wg.token());
-                    if pass_tx.send((outcome, port_rx)).await.is_err() {
-                        return Ok(());
-                    }
-                    outcomes.push(token);
-                }
-
-                wg.wait().await;
-                for outcome_token in &outcomes {
-                    if outcome_token.did_finish() {
-                        return Ok(());
-                    }
-                }
-                outcomes.clear();
-                outcome.stopped();
-            }
-
-            Ok(())
-        }));
-        join_handles.extend(pass_rxs.into_iter().zip(txs).map(|(mut pass_rx, mut tx)| {
-            spawn(TaskPriority::High, async move {
-                while let Ok((outcome, mut rx)) = pass_rx.recv().await {
-                    while let Ok(morsel) = rx.recv().await {
-                        if tx.send(morsel).await.is_err() {
-                            return Ok(());
-                        }
-                    }
-                    outcome.stopped();
-                }
-                Ok(())
-            })
-        }));
-
-        rxs
-    }
-
-    /// Serialize the input and allow for long lived lasts to listen to a constant channel.
-    pub fn serial(
-        mut self,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    ) -> Receiver<Morsel> {
-        let (mut tx, rx) = connector();
-        join_handles.push(spawn(TaskPriority::High, async move {
-            while let Ok((outcome, port_rx)) = self.recv.recv().await {
-                let mut port_rx = port_rx.serial();
-                while let Ok(morsel) = port_rx.recv().await {
-                    if tx.send(morsel).await.is_err() {
-                        return Ok(());
-                    }
-                }
-                outcome.stopped();
-            }
-            Ok(())
-        }));
-        rx
-    }
-}
-
 /// Spawn a task that linearizes and buffers morsels until a given a maximum chunk size is reached
 /// and then distributes the columns amongst worker tasks.
 fn buffer_and_distribute_columns_task(
-    mut rx: Receiver<Morsel>,
+    mut recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
     mut dist_tx: distributor_channel::Sender<(usize, usize, Column)>,
     chunk_size: usize,
     schema: SchemaRef,
@@ -148,26 +67,31 @@
         let mut seq = 0usize;
         let mut buffer = DataFrame::empty_with_schema(schema.as_ref());
 
-        while let Ok(morsel) = rx.recv().await {
-            let (df, _, _, consume_token) = morsel.into_inner();
-            // @NOTE: This also performs schema validation.
-            buffer.vstack_mut(&df)?;
+        while let Ok((outcome, rx)) = recv_port_rx.recv().await {
+            let mut rx = rx.serial();
+            while let Ok(morsel) = rx.recv().await {
+                let (df, _, _, consume_token) = morsel.into_inner();
+                // @NOTE: This also performs schema validation.
+                buffer.vstack_mut(&df)?;
 
-            while buffer.height() >= chunk_size {
-                let df;
-                (df, buffer) = buffer.split_at(buffer.height().min(chunk_size) as i64);
+                while buffer.height() >= chunk_size {
+                    let df;
+                    (df, buffer) = buffer.split_at(buffer.height().min(chunk_size) as i64);
 
-                for (i, column) in df.take_columns().into_iter().enumerate() {
-                    if dist_tx.send((seq, i, column)).await.is_err() {
-                        return Ok(());
+                    for (i, column) in df.take_columns().into_iter().enumerate() {
+                        if dist_tx.send((seq, i, column)).await.is_err() {
+                            return Ok(());
+                        }
                     }
+                    seq += 1;
                 }
-                seq += 1;
+                drop(consume_token); // Increase the backpressure. Only free up a pipeline when the
+                                     // morsel has started encoding in its entirety. This still
+                                     // allows for parallelism of Morsels, but prevents large
+                                     // bunches of Morsels from stacking up here.
             }
-            drop(consume_token); // Increase the backpressure. Only free up a pipeline when the
-                                 // morsel has started encoding in its entirety. This still
-                                 // allows for parallelism of Morsels, but prevents large
-                                 // bunches of Morsels from stacking up here.
+
+            outcome.stopped();
         }
 
         // Flush the remaining rows.
@@ -182,22 +106,61 @@
     })
 }
 
+#[allow(clippy::type_complexity)]
+pub fn parallelize_receive_task<T: Ord + Send + 'static>(
+    join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
+    mut recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
+    num_pipelines: usize,
+    maintain_order: bool,
+) -> (
+    Vec<Receiver<(Receiver<Morsel>, Inserter<T>)>>,
+    Receiver<Linearizer<T>>,
+) {
+    // Phase Handling Task -> Encode Tasks.
+    let (mut pass_txs, pass_rxs) = (0..num_pipelines)
+        .map(|_| connector())
+        .collect::<(Vec<_>, Vec<_>)>();
+    let (mut io_tx, io_rx) = connector();
+
+    join_handles.push(spawn(TaskPriority::High, async move {
+        while let Ok((outcome, port_rxs)) = recv_port_rx.recv().await {
+            let port_rxs = port_rxs.parallel();
+            let (lin_rx, lin_txs) = Linearizer::<T>::new_with_maintain_order(
+                num_pipelines,
+                DEFAULT_SINK_LINEARIZER_BUFFER_SIZE,
+                maintain_order,
+            );
+
+            for ((pass_tx, port_rx), lin_tx) in pass_txs.iter_mut().zip(port_rxs).zip(lin_txs) {
+                if pass_tx.send((port_rx, lin_tx)).await.is_err() {
+                    return Ok(());
+                }
+            }
+            if io_tx.send(lin_rx).await.is_err() {
+                return Ok(());
+            }
+
+            outcome.stopped();
+        }
+
+        Ok(())
+    }));
+
+    (pass_rxs, io_rx)
+}
+
 pub trait SinkNode {
     fn name(&self) -> &str;
+
     fn is_sink_input_parallel(&self) -> bool;
-    fn do_maintain_order(&self) -> bool;
+    fn do_maintain_order(&self) -> bool {
+        true
+    }
 
     fn spawn_sink(
         &mut self,
         num_pipelines: usize,
-        recv_ports_recv: SinkRecvPort,
-        state: &ExecutionState,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    );
-    fn spawn_sink_once(
-        &mut self,
-        num_pipelines: usize,
-        recv_ports_recv: SinkInputPort,
+        recv_ports_recv: Receiver<(PhaseOutcome, SinkInputPort)>,
         state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
     );
@@ -280,15 +243,8 @@ impl ComputeNode for SinkComputeNode {
                 let (tx, rx) = connector();
 
                 let mut join_handles = Vec::new();
-                self.sink.spawn_sink(
-                    self.num_pipelines,
-                    SinkRecvPort {
-                        num_pipelines: self.num_pipelines,
-                        recv: rx,
-                    },
-                    state,
-                    &mut join_handles,
-                );
+                self.sink
+                    .spawn_sink(self.num_pipelines, rx, state, &mut join_handles);
                 // One of the tasks might throw an error. In which case, we need to cancel all
                 // handles and find the error.
                 let join_handles: FuturesUnordered<_> =
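The shape that replaces `SinkRecvPort` is a channel of channels: every phase delivers a fresh input port (plus, for parallel sinks, a fresh linearizer inserter) together with a `PhaseOutcome`, and the sink acknowledges the phase (`outcome.stopped()`) only after fully draining that port. Below is a std-mpsc sketch of the protocol, assuming nothing from polars-stream; the `()` ack channel plays the role of `PhaseOutcome`:

```rust
use std::sync::mpsc;
use std::thread;

fn main() {
    // Outer channel: one message per phase, carrying that phase's morsel receiver
    // plus a one-shot "phase done" acknowledgement (stand-in for PhaseOutcome).
    type Phase = (mpsc::Sender<()>, mpsc::Receiver<i32>);
    let (phase_tx, phase_rx) = mpsc::channel::<Phase>();

    let sink = thread::spawn(move || {
        let mut total = 0;
        // The nested-loop shape the reworked sinks use: outer = phases, inner = morsels.
        while let Ok((done_tx, morsel_rx)) = phase_rx.recv() {
            while let Ok(morsel) = morsel_rx.recv() {
                total += morsel;
            }
            let _ = done_tx.send(()); // like outcome.stopped(): phase drained, sink not finished
        }
        total
    });

    for phase in 0..3 {
        let (done_tx, done_rx) = mpsc::channel();
        let (morsel_tx, morsel_rx) = mpsc::channel();
        phase_tx.send((done_tx, morsel_rx)).unwrap();
        for morsel in 0..=phase {
            morsel_tx.send(morsel).unwrap();
        }
        drop(morsel_tx); // close this phase's port
        done_rx.recv().unwrap(); // wait at the phase boundary until the sink caught up
    }
    drop(phase_tx); // no more phases: the sink's outer loop ends

    assert_eq!(sink.join().unwrap(), 4); // 0 + (0 + 1) + (0 + 1 + 2)
}
```

The acknowledgement is what lets the producer pause at a phase boundary (for example between the inputs of an ordered union) knowing the sink has consumed everything sent so far, rather than tearing the sink down and respawning it per phase as `spawn_sink_once` effectively required.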
diff --git a/crates/polars-stream/src/nodes/io_sinks/parquet.rs b/crates/polars-stream/src/nodes/io_sinks/parquet.rs
index f7cd8aba3fd9..ad9c97818e85 100644
--- a/crates/polars-stream/src/nodes/io_sinks/parquet.rs
+++ b/crates/polars-stream/src/nodes/io_sinks/parquet.rs
@@ -20,15 +20,15 @@ use polars_plan::dsl::SinkOptions;
 use polars_utils::priority::Priority;
 
 use super::{
-    buffer_and_distribute_columns_task, SinkInputPort, SinkNode, SinkRecvPort,
+    buffer_and_distribute_columns_task, SinkInputPort, SinkNode,
     DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE, DEFAULT_SINK_LINEARIZER_BUFFER_SIZE,
 };
 use crate::async_executor::spawn;
-use crate::async_primitives::connector::connector;
+use crate::async_primitives::connector::{connector, Receiver};
 use crate::async_primitives::distributor_channel::distributor_channel;
 use crate::async_primitives::linearizer::Linearizer;
 use crate::nodes::io_sinks::sync_on_close;
-use crate::nodes::{JoinHandle, TaskPriority};
+use crate::nodes::{JoinHandle, PhaseOutcome, TaskPriority};
 
 pub struct ParquetSinkNode {
     path: PathBuf,
@@ -85,28 +85,10 @@ impl SinkNode for ParquetSinkNode {
     fn spawn_sink(
         &mut self,
         num_pipelines: usize,
-        recv_ports_recv: SinkRecvPort,
+        recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
     ) {
-        let rx = recv_ports_recv.serial(join_handles);
-        self.spawn_sink_once(
-            num_pipelines,
-            SinkInputPort::Serial(rx),
-            _state,
-            join_handles,
-        );
-    }
-
-    fn spawn_sink_once(
-        &mut self,
-        num_pipelines: usize,
-        recv_port: SinkInputPort,
-        _state: &ExecutionState,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    ) {
-        // .. -> Buffer task
-        let buffer_rx = recv_port.serial();
         // Buffer task -> Encode tasks
         let (dist_tx, dist_rxs) =
             distributor_channel(num_pipelines, DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE);
@@ -127,7 +109,7 @@ impl SinkNode for ParquetSinkNode {
 
         // Buffer task.
        join_handles.push(buffer_and_distribute_columns_task(
-            buffer_rx,
+            recv_port_rx,
             dist_tx,
             write_options
                 .row_group_size
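Like the ipc sink, parquet distributes `(seq, column_index, column)` triples so columns encode in parallel and possibly out of order; the collect side only needs that tag pair to restore row-group-major order. A toy sketch of the reassembly invariant (the real code merges via linearizers as the results stream in, rather than sorting a collected Vec as done here):

```rust
fn main() {
    // Encoded columns arrive tagged (row_group_seq, column_index), in any order.
    let mut encoded = vec![
        ((1usize, 0usize), "rg1-col0"),
        ((0, 1), "rg0-col1"),
        ((1, 1), "rg1-col1"),
        ((0, 0), "rg0-col0"),
    ];

    // Restoring row-group-major order only needs the tag pair; tuple ordering
    // compares the sequence number first, then the column index.
    encoded.sort_by_key(|&(tag, _)| tag);

    let order: Vec<_> = encoded.into_iter().map(|(_, name)| name).collect();
    assert_eq!(order, ["rg0-col0", "rg0-col1", "rg1-col0", "rg1-col1"]);
}
```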
diff --git a/crates/polars-stream/src/nodes/io_sinks/partition/max_size.rs b/crates/polars-stream/src/nodes/io_sinks/partition/max_size.rs
index d743de929800..28b64de89761 100644
--- a/crates/polars-stream/src/nodes/io_sinks/partition/max_size.rs
+++ b/crates/polars-stream/src/nodes/io_sinks/partition/max_size.rs
@@ -14,12 +14,11 @@ use polars_utils::{format_pl_smallstr, IdxSize};
 
 use super::CreateNewSinkFn;
 use crate::async_executor::{spawn, AbortOnDropHandle};
-use crate::async_primitives::connector::{self, connector};
+use crate::async_primitives::connector::{self, connector, Receiver};
 use crate::async_primitives::distributor_channel::{self, distributor_channel};
-use crate::nodes::io_sinks::{
-    SinkInputPort, SinkNode, SinkRecvPort, DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE,
-};
-use crate::nodes::{JoinHandle, Morsel, TaskPriority};
+use crate::async_primitives::wait_group::WaitGroup;
+use crate::nodes::io_sinks::{SinkInputPort, SinkNode, DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE};
+use crate::nodes::{JoinHandle, Morsel, PhaseOutcome, TaskPriority};
 
 pub struct MaxSizePartitionSinkNode {
     input_schema: SchemaRef,
@@ -84,28 +83,10 @@ impl SinkNode for MaxSizePartitionSinkNode {
     fn spawn_sink(
         &mut self,
         num_pipelines: usize,
-        recv_port: SinkRecvPort,
-        state: &ExecutionState,
-        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-    ) {
-        let rx = recv_port.serial(join_handles);
-        self.spawn_sink_once(
-            num_pipelines,
-            SinkInputPort::Serial(rx),
-            state,
-            join_handles,
-        );
-    }
-
-    fn spawn_sink_once(
-        &mut self,
-        num_pipelines: usize,
-        recv_port: SinkInputPort,
+        mut recv_port_recv: Receiver<(PhaseOutcome, SinkInputPort)>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
     ) {
-        // .. -> Main Task
-        let mut recv_port = recv_port.serial();
         // Main Task -> Distributor -> Parallel Input Sink
         let (mut dist_txs, dist_rxs) = (0..num_pipelines)
             .map(|_| connector())
@@ -138,115 +119,127 @@ impl SinkNode for MaxSizePartitionSinkNode {
             let mut part = 0;
             let mut current_sink_opt = None;
 
-            'morsel_loop: while let Ok(mut morsel) = recv_port.recv().await {
-                while morsel.df().height() > 0 {
-                    if retire_error.load(Ordering::Relaxed) {
-                        return Ok(());
-                    }
+            while let Ok((outcome, recv_port)) = recv_port_recv.recv().await {
+                let mut recv_port = recv_port.serial();
+                'morsel_loop: while let Ok(mut morsel) = recv_port.recv().await {
+                    while morsel.df().height() > 0 {
+                        if retire_error.load(Ordering::Relaxed) {
+                            return Ok(());
+                        }
 
-                    let current_sink = match current_sink_opt.as_mut() {
-                        Some(c) => c,
-                        None => {
-                            *args.get_mut(&part_str).unwrap() = format_pl_smallstr!("{part}");
-                            part += 1;
+                        let current_sink = match current_sink_opt.as_mut() {
+                            Some(c) => c,
+                            None => {
+                                *args.get_mut(&part_str).unwrap() = format_pl_smallstr!("{part}");
+                                part += 1;
 
-                            let path;
-                            let mut node;
-                            (path, node, args) = (create_new)(input_schema.clone(), args)?;
+                                let path;
+                                let mut node;
+                                (path, node, args) = (create_new)(input_schema.clone(), args)?;
 
-                            if verbose {
-                                eprintln!(
-                                    "[partition[max_size]]: Start on new file '{}'",
-                                    path.display()
-                                );
-                            }
+                                if verbose {
+                                    eprintln!(
+                                        "[partition[max_size]]: Start on new file '{}'",
+                                        path.display()
+                                    );
+                                }
 
-                            let (sink_input, sender) = if node.is_sink_input_parallel() {
-                                let (tx, dist_rxs) = distributor_channel(
-                                    num_pipelines,
-                                    DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE,
-                                );
-                                let (txs, rxs) = (0..num_pipelines)
-                                    .map(|_| connector())
-                                    .collect::<(Vec<_>, Vec<_>)>();
+                                let (sink_input, sender) = if node.is_sink_input_parallel() {
+                                    let (tx, dist_rxs) = distributor_channel(
+                                        num_pipelines,
+                                        DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE,
+                                    );
+                                    let (txs, rxs) = (0..num_pipelines)
+                                        .map(|_| connector())
+                                        .collect::<(Vec<_>, Vec<_>)>();
 
-                                for (i, channels) in dist_rxs.into_iter().zip(txs).enumerate() {
-                                    if dist_txs[i].send(channels).await.is_err() {
-                                        return Ok(());
-                                    }
-                                }
+                                    for (i, channels) in dist_rxs.into_iter().zip(txs).enumerate() {
+                                        if dist_txs[i].send(channels).await.is_err() {
+                                            return Ok(());
+                                        }
+                                    }
 
-                                (SinkInputPort::Parallel(rxs), SinkSender::Distributor(tx))
-                            } else {
-                                let (tx, rx) = connector();
-                                (SinkInputPort::Serial(rx), SinkSender::Connector(tx))
-                            };
+                                    (SinkInputPort::Parallel(rxs), SinkSender::Distributor(tx))
+                                } else {
+                                    let (tx, rx) = connector();
+                                    (SinkInputPort::Serial(rx), SinkSender::Connector(tx))
+                                };
 
-                            let mut join_handles = Vec::new();
-                            let state = ExecutionState::new();
-                            node.spawn_sink_once(
-                                num_pipelines,
-                                sink_input,
-                                &state,
-                                &mut join_handles,
-                            );
-                            let join_handles = FuturesUnordered::from_iter(
-                                join_handles.into_iter().map(AbortOnDropHandle::new),
-                            );
-                            current_sink_opt.insert(CurrentSink {
-                                sender,
-                                num_remaining: max_size,
-                                join_handles,
-                            })
-                        },
-                    };
+                                let mut join_handles = Vec::new();
+                                let state = ExecutionState::new();
+                                let (mut sink_input_tx, sink_input_rx) = connector();
+                                node.spawn_sink(
+                                    num_pipelines,
+                                    sink_input_rx,
+                                    &state,
+                                    &mut join_handles,
+                                );
+                                let join_handles = FuturesUnordered::from_iter(
+                                    join_handles.into_iter().map(AbortOnDropHandle::new),
+                                );
+
+                                let (_, outcome) =
+                                    PhaseOutcome::new_shared_wait(WaitGroup::default().token());
+                                if sink_input_tx.send((outcome, sink_input)).await.is_err() {
+                                    return Ok(());
+                                }
+                                current_sink_opt.insert(CurrentSink {
+                                    sender,
+                                    num_remaining: max_size,
+                                    join_handles,
+                                })
+                            },
+                        };
-                    // If we can send the whole morsel into sink, do that.
-                    if morsel.df().height() < current_sink.num_remaining as usize {
-                        current_sink.num_remaining -= morsel.df().height() as IdxSize;
+                        // If we can send the whole morsel into sink, do that.
+                        if morsel.df().height() < current_sink.num_remaining as usize {
+                            current_sink.num_remaining -= morsel.df().height() as IdxSize;
 
-                        // This sends the consume token along so that we don't start buffering here
-                        // too much. The sinks are very specific about how they handle consume
-                        // tokens and we want to keep that behavior.
-                        let result = match &mut current_sink.sender {
-                            SinkSender::Connector(s) => s.send(morsel).await.ok(),
-                            SinkSender::Distributor(s) => s.send(morsel).await.ok(),
-                        };
+                            // This sends the consume token along so that we don't start buffering here
+                            // too much. The sinks are very specific about how they handle consume
+                            // tokens and we want to keep that behavior.
+                            let result = match &mut current_sink.sender {
+                                SinkSender::Connector(s) => s.send(morsel).await.ok(),
+                                SinkSender::Distributor(s) => s.send(morsel).await.ok(),
+                            };
 
-                        if result.is_none() {
-                            break 'morsel_loop;
+                            if result.is_none() {
+                                break 'morsel_loop;
+                            }
+                            break;
                         }
-                        break;
-                    }
 
-                    // Else, we need to split up the morsel into what can be sent and what needs to
-                    // be passed to the current sink and what needs to be passed to the next sink.
-                    let (df, seq, source_token, consume_token) = morsel.into_inner();
+                        // Else, we need to split up the morsel into what can be sent and what needs to
+                        // be passed to the current sink and what needs to be passed to the next sink.
+                        let (df, seq, source_token, consume_token) = morsel.into_inner();
 
-                    let (final_sink_df, df) = df.split_at(current_sink.num_remaining as i64);
-                    let final_sink_morsel = Morsel::new(final_sink_df, seq, source_token.clone());
+                        let (final_sink_df, df) = df.split_at(current_sink.num_remaining as i64);
+                        let final_sink_morsel =
+                            Morsel::new(final_sink_df, seq, source_token.clone());
 
-                    let result = match &mut current_sink.sender {
-                        SinkSender::Connector(s) => s.send(final_sink_morsel).await.ok(),
-                        SinkSender::Distributor(s) => s.send(final_sink_morsel).await.ok(),
-                    };
+                        let result = match &mut current_sink.sender {
+                            SinkSender::Connector(s) => s.send(final_sink_morsel).await.ok(),
+                            SinkSender::Distributor(s) => s.send(final_sink_morsel).await.ok(),
+                        };
 
-                    if result.is_none() {
-                        return Ok(());
-                    }
+                        if result.is_none() {
+                            return Ok(());
+                        }
 
-                    let join_handles = std::mem::take(&mut current_sink.join_handles);
-                    drop(current_sink_opt.take());
-                    if retire_tx.send(join_handles).await.is_err() {
-                        return Ok(());
-                    };
+                        let join_handles = std::mem::take(&mut current_sink.join_handles);
+                        drop(current_sink_opt.take());
+                        if retire_tx.send(join_handles).await.is_err() {
+                            return Ok(());
+                        };
 
-                    // We consciously keep the consume token for the last sub-morsel sent.
-                    morsel = Morsel::new(df, seq, source_token);
-                    if let Some(consume_token) = consume_token {
-                        morsel.set_consume_token(consume_token);
+                        // We consciously keep the consume token for the last sub-morsel sent.
+                        morsel = Morsel::new(df, seq, source_token);
+                        if let Some(consume_token) = consume_token {
+                            morsel.set_consume_token(consume_token);
+                        }
                     }
                 }
+                outcome.stopped();
             }
 
             if let Some(mut current_sink) = current_sink_opt.take() {
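The max-size partition logic above is easier to follow with the channels stripped away: keep a countdown of rows remaining in the open file, forward a morsel whole when it fits, otherwise split it, retire the full file, and carry the tail into the next one. A std-only sketch of that counting/splitting rule, with `Vec<u32>` standing in for morsels (assumed row counts, not the real API):

```rust
fn main() {
    let max_size = 5usize; // rows per partition file
    let morsels: Vec<Vec<u32>> = vec![(0..7).collect(), (7..10).collect(), (10..13).collect()];

    let mut files: Vec<Vec<u32>> = Vec::new();
    let mut current: Option<(Vec<u32>, usize)> = None; // (open "sink", rows remaining)

    for mut morsel in morsels {
        // A morsel may straddle several partition files, hence the inner loop.
        while !morsel.is_empty() {
            let (sink, remaining) = current.get_or_insert_with(|| (Vec::new(), max_size));
            if morsel.len() < *remaining {
                // Whole morsel fits: forward it and keep the file open.
                *remaining -= morsel.len();
                sink.append(&mut morsel);
                break;
            }
            // Split: fill the current file exactly, retire it, continue with the tail.
            let rest = morsel.split_off(*remaining);
            sink.extend(morsel);
            morsel = rest;
            files.push(current.take().unwrap().0);
        }
    }
    if let Some((sink, _)) = current.take() {
        files.push(sink); // flush the final, possibly short, file
    }

    let heights: Vec<usize> = files.iter().map(Vec::len).collect();
    assert_eq!(heights, [5, 5, 3]); // 13 rows at max 5 rows per file
}
```

The real task additionally retires each full file's join handles through `retire_tx` so file finalization overlaps with filling the next partition, and it keeps the consume token on the last sub-morsel to preserve the sinks' backpressure behavior.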
diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py
index 85858fc0bc9c..cf100e0a574a 100644
--- a/py-polars/tests/unit/streaming/test_streaming_io.py
+++ b/py-polars/tests/unit/streaming/test_streaming_io.py
@@ -270,6 +270,32 @@ def test_nyi_scan_in_memory(method: str) -> None:
         (getattr(pl, f"scan_{method}"))(f).collect(streaming=True)
 
 
+@pytest.mark.parametrize(
+    "method",
+    ["parquet", "csv", "ipc", "ndjson"],
+)
+@pytest.mark.write_disk
+def test_sink_phases(tmp_path: Path, method: str) -> None:
+    df = pl.DataFrame(
+        {
+            "a": [1, 2, 3, 4, 5, 6, 7],
+            "b": ["some", "text", "over-here-is-very-long", "and", "some", "more", "text"],
+        }
+    )
+
+    # Ordered Unions lead to many phase transitions.
+    ref_df = pl.concat([df] * 100)
+    lf = pl.concat([df.lazy()] * 100)
+
+    (getattr(lf, f"sink_{method}"))(tmp_path / f"t.{method}")
+    df = (getattr(pl, f"scan_{method}"))(tmp_path / f"t.{method}").collect()
+
+    assert_frame_equal(df, ref_df)
+
+    (getattr(lf, f"sink_{method}"))(tmp_path / f"t.{method}", maintain_order=False)
+    height = (
+        (getattr(pl, f"scan_{method}"))(tmp_path / f"t.{method}")
+        .select(pl.len())
+        .collect()[0, 0]
+    )
+    assert height == ref_df.height
+
+
 def test_empty_sink_parquet_join_14863(tmp_path: Path) -> None:
     file_path = tmp_path / "empty.parquet"
     lf = pl.LazyFrame(schema=["a", "b", "c"]).cast(pl.String)