diff --git a/crates/polars-stream/src/async_primitives/distributor_channel.rs b/crates/polars-stream/src/async_primitives/distributor_channel.rs index 21af7b53d7d1..fa8685f88925 100644 --- a/crates/polars-stream/src/async_primitives/distributor_channel.rs +++ b/crates/polars-stream/src/async_primitives/distributor_channel.rs @@ -17,6 +17,11 @@ use super::task_parker::TaskParker; /// The FIFO order is only guaranteed per receiver. That is, each receiver is /// guaranteed to see a subset of the data sent by the sender in the order the /// sender sent it in, but not necessarily contiguously. +/// +/// When one or more receivers are closed no attempt is made to avoid filling +/// those receivers' buffers. The values in the buffer of a closed receiver are +/// lost forever, they're not redistributed among the others, and simply +/// dropped when the channel is dropped. pub fn distributor_channel( num_receivers: usize, bufsize: usize, @@ -108,6 +113,8 @@ unsafe impl Send for Sender {} unsafe impl Send for Receiver {} impl Sender { + /// Note: This intentionally takes `&mut` to ensure it is only accessed in a single-threaded + /// manner. pub async fn send(&mut self, mut value: T) -> Result<(), T> { let num_receivers = self.inner.receivers.len(); loop { @@ -128,7 +135,7 @@ impl Sender { } } - match self.try_send(hungriest_idx, value) { + match unsafe { self.try_send(hungriest_idx, value) } { Ok(()) => return Ok(()), Err(SendError::Full(v)) => value = v, Err(SendError::Closed(v)) => value = v, @@ -141,7 +148,7 @@ impl Sender { let mut idx = ((self.rng.gen::() as u64 * num_receivers as u64) >> 32) as usize; let mut all_closed = true; for _ in 0..num_receivers { - match self.try_send(idx, value) { + match unsafe { self.try_send(idx, value) } { Ok(()) => return Ok(()), Err(SendError::Full(v)) => { all_closed = false; @@ -164,6 +171,8 @@ impl Sender { } } + // Returns the upper bound on the length of the queue of the given receiver. + // It is an upper bound because racy reads can reduce it in the meantime. fn upper_bound_len(&self, recv_idx: usize) -> usize { let read_head = self.inner.receivers[recv_idx] .read_head @@ -172,7 +181,9 @@ impl Sender { write_head.wrapping_sub(read_head) } - fn try_send(&self, recv_idx: usize, value: T) -> Result<(), SendError> { + /// # Safety + /// May only be called from one thread at a time. + unsafe fn try_send(&self, recv_idx: usize, value: T) -> Result<(), SendError> { let read_head = self.inner.receivers[recv_idx] .read_head .load(Ordering::SeqCst); @@ -203,7 +214,7 @@ impl Receiver { pub async fn recv(&mut self) -> Result { loop { // Fast-path. - match self.try_recv() { + match unsafe { self.try_recv() } { Ok(v) => return Ok(v), Err(RecvError::Closed) => return Err(()), Err(RecvError::Empty) => {}, @@ -211,7 +222,7 @@ impl Receiver { // Try again, threatening to park if there's still nothing. let park = self.inner.receivers[self.index].parker.park(); - match self.try_recv() { + match unsafe { self.try_recv() } { Ok(v) => return Ok(v), Err(RecvError::Closed) => return Err(()), Err(RecvError::Empty) => {}, @@ -220,27 +231,36 @@ impl Receiver { } } - fn try_recv(&self) -> Result { - let read_head = self.inner.receivers[self.index] - .read_head - .load(Ordering::Relaxed); - let write_head = self.inner.write_heads[self.index].load(Ordering::SeqCst); - if read_head != write_head { - let idx = self.inner.reduce_index(read_head); - let read; - unsafe { - let ptr = self.inner.receivers[self.index].data[idx].get(); - read = ptr.read().assume_init(); - self.inner.receivers[self.index] - .read_head - .store(read_head.wrapping_add(1), Ordering::SeqCst); + /// # Safety + /// May only be called from one thread at a time. + unsafe fn try_recv(&self) -> Result { + loop { + let read_head = self.inner.receivers[self.index] + .read_head + .load(Ordering::Relaxed); + let write_head = self.inner.write_heads[self.index].load(Ordering::SeqCst); + if read_head != write_head { + let idx = self.inner.reduce_index(read_head); + let read; + unsafe { + let ptr = self.inner.receivers[self.index].data[idx].get(); + read = ptr.read().assume_init(); + self.inner.receivers[self.index] + .read_head + .store(read_head.wrapping_add(1), Ordering::SeqCst); + } + self.inner.send_parker.unpark(); + return Ok(read); + } else if self.inner.send_closed.load(Ordering::SeqCst) { + // Check write head again, sender could've sent something right + // before closing. We can do this relaxed because we'll read it + // again in the next iteration with SeqCst if it's a new value. + if write_head == self.inner.write_heads[self.index].load(Ordering::Relaxed) { + return Err(RecvError::Closed); + } + } else { + return Err(RecvError::Empty); } - self.inner.send_parker.unpark(); - Ok(read) - } else if self.inner.send_closed.load(Ordering::SeqCst) { - Err(RecvError::Closed) - } else { - Err(RecvError::Empty) } } }