Skip to content

Commit

Permalink
refactor(rust): Fix race condition in DistributorChannel
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp committed Sep 30, 2024
1 parent ab5200d commit 6da7881
Showing 1 changed file with 45 additions and 25 deletions.
70 changes: 45 additions & 25 deletions crates/polars-stream/src/async_primitives/distributor_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ use super::task_parker::TaskParker;
/// The FIFO order is only guaranteed per receiver. That is, each receiver is
/// guaranteed to see a subset of the data sent by the sender in the order the
/// sender sent it in, but not necessarily contiguously.
///
/// When one or more receivers are closed, no attempt is made to avoid filling
/// those receivers' buffers. The values in the buffer of a closed receiver are
/// lost forever: they are not redistributed among the other receivers, and are
/// simply dropped when the channel is dropped.
pub fn distributor_channel<T>(
num_receivers: usize,
bufsize: usize,
Expand Down Expand Up @@ -108,6 +113,8 @@ unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Send for Receiver<T> {}

impl<T: Send> Sender<T> {
/// Note: This intentionally takes `&mut` to ensure it is only accessed in a single-threaded
/// manner.
pub async fn send(&mut self, mut value: T) -> Result<(), T> {
let num_receivers = self.inner.receivers.len();
loop {
Expand All @@ -128,7 +135,7 @@ impl<T: Send> Sender<T> {
}
}

match self.try_send(hungriest_idx, value) {
match unsafe { self.try_send(hungriest_idx, value) } {
Ok(()) => return Ok(()),
Err(SendError::Full(v)) => value = v,
Err(SendError::Closed(v)) => value = v,
Expand All @@ -141,7 +148,7 @@ impl<T: Send> Sender<T> {
let mut idx = ((self.rng.gen::<u32>() as u64 * num_receivers as u64) >> 32) as usize;
let mut all_closed = true;
for _ in 0..num_receivers {
match self.try_send(idx, value) {
match unsafe { self.try_send(idx, value) } {
Ok(()) => return Ok(()),
Err(SendError::Full(v)) => {
all_closed = false;
Expand All @@ -164,6 +171,8 @@ impl<T: Send> Sender<T> {
}
}

// Returns an upper bound on the length of the queue of the given receiver.
// It is only an upper bound because racy reads by the receiver can shrink
// the queue in the meantime.
fn upper_bound_len(&self, recv_idx: usize) -> usize {
let read_head = self.inner.receivers[recv_idx]
.read_head
Expand All @@ -172,7 +181,9 @@ impl<T: Send> Sender<T> {
write_head.wrapping_sub(read_head)
}

fn try_send(&self, recv_idx: usize, value: T) -> Result<(), SendError<T>> {
/// # Safety
/// May only be called from one thread at a time.
unsafe fn try_send(&self, recv_idx: usize, value: T) -> Result<(), SendError<T>> {
let read_head = self.inner.receivers[recv_idx]
.read_head
.load(Ordering::SeqCst);
Expand Down Expand Up @@ -203,15 +214,15 @@ impl<T: Send> Receiver<T> {
pub async fn recv(&mut self) -> Result<T, ()> {
loop {
// Fast-path.
match self.try_recv() {
match unsafe { self.try_recv() } {
Ok(v) => return Ok(v),
Err(RecvError::Closed) => return Err(()),
Err(RecvError::Empty) => {},
}

// Try again, threatening to park if there's still nothing.
let park = self.inner.receivers[self.index].parker.park();
match self.try_recv() {
match unsafe { self.try_recv() } {
Ok(v) => return Ok(v),
Err(RecvError::Closed) => return Err(()),
Err(RecvError::Empty) => {},
Expand All @@ -220,27 +231,36 @@ impl<T: Send> Receiver<T> {
}
}

fn try_recv(&self) -> Result<T, RecvError> {
let read_head = self.inner.receivers[self.index]
.read_head
.load(Ordering::Relaxed);
let write_head = self.inner.write_heads[self.index].load(Ordering::SeqCst);
if read_head != write_head {
let idx = self.inner.reduce_index(read_head);
let read;
unsafe {
let ptr = self.inner.receivers[self.index].data[idx].get();
read = ptr.read().assume_init();
self.inner.receivers[self.index]
.read_head
.store(read_head.wrapping_add(1), Ordering::SeqCst);
/// # Safety
/// May only be called from one thread at a time.
unsafe fn try_recv(&self) -> Result<T, RecvError> {
loop {
let read_head = self.inner.receivers[self.index]
.read_head
.load(Ordering::Relaxed);
let write_head = self.inner.write_heads[self.index].load(Ordering::SeqCst);
if read_head != write_head {
let idx = self.inner.reduce_index(read_head);
let read;
unsafe {
let ptr = self.inner.receivers[self.index].data[idx].get();
read = ptr.read().assume_init();
self.inner.receivers[self.index]
.read_head
.store(read_head.wrapping_add(1), Ordering::SeqCst);
}
self.inner.send_parker.unpark();
return Ok(read);
} else if self.inner.send_closed.load(Ordering::SeqCst) {
// Check the write head again: the sender could have sent something right
// before closing. A relaxed load suffices here because, if the value is
// new, the next iteration reads it again with SeqCst.
if write_head == self.inner.write_heads[self.index].load(Ordering::Relaxed) {
return Err(RecvError::Closed);
}
} else {
return Err(RecvError::Empty);
}
self.inner.send_parker.unpark();
Ok(read)
} else if self.inner.send_closed.load(Ordering::SeqCst) {
Err(RecvError::Closed)
} else {
Err(RecvError::Empty)
}
}
}
Expand Down

0 comments on commit 6da7881

Please sign in to comment.