Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Mar 5, 2025
1 parent 241d34b commit 3cb7e64
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 68 deletions.
92 changes: 77 additions & 15 deletions crates/polars-io/src/cloud/adaptors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! Interface with the object_store crate and define AsyncSeek, AsyncRead.
use std::future::Future;
use std::sync::Arc;
use std::task::Poll;

use futures::FutureExt;
use object_store::buffered::BufWriter;
use object_store::path::Path;
use object_store::ObjectStore;
Expand All @@ -11,6 +14,13 @@ use tokio::io::AsyncWriteExt;
use super::{object_path_from_str, CloudOptions};
use crate::pl_async::{get_runtime, get_upload_chunk_size};

/// CloudWriter's synchronous functions should be callable from async contexts,
/// so we ensure we are `block_in_place` using this util function.
fn block_in_place_on<T>(func: impl Future<Output = T>) -> T {
let rt = get_runtime();
tokio::task::block_in_place(|| rt.block_on(func))
}

enum WriterState {
Open(BufWriter),
/// Note: `Err` state is also used as the close state on success.
Expand All @@ -26,7 +36,7 @@ impl WriterState {
Self::Open(writer) => match func(writer) {
Ok(v) => Ok(v),
Err(e) => {
let _ = get_runtime().block_on_potential_spawn(writer.abort());
let _ = block_in_place_on(writer.abort());
*self = Self::Err(e);
self.try_with_writer(func)
},
Expand Down Expand Up @@ -87,49 +97,101 @@ impl CloudWriter {
)
}

pub fn close(&mut self) -> PolarsResult<()> {
pub async fn close(&mut self) -> PolarsResult<()> {
let WriterState::Open(writer) = &mut self.inner else {
panic!();
};

get_runtime()
.block_on_potential_spawn(async { writer.shutdown().await })
.map_err(to_compute_err)?;
writer.shutdown().await.map_err(to_compute_err)?;

self.inner = WriterState::Err(std::io::Error::new(
std::io::ErrorKind::Other,
"impl error: file was closed",
"already closed",
));

Ok(())
}

pub fn close_sync(&mut self) -> PolarsResult<()> {
block_in_place_on(self.close())
}
}

impl std::io::Write for CloudWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
// SAFETY:
// We extend the lifetime for the duration of this function. This is safe as well block the
// We extend the lifetime for the duration of this function. This is safe as we block the
// async runtime here
let buf = unsafe { std::mem::transmute::<&[u8], &'static [u8]>(buf) };

self.inner.try_with_writer(|writer| {
get_runtime()
.block_on_potential_spawn(async { writer.write_all(buf).await.map(|_t| buf.len()) })
block_in_place_on(async { writer.write_all(buf).await.map(|_t| buf.len()) })
})
}

fn flush(&mut self) -> std::io::Result<()> {
self.inner.try_with_writer(|writer| {
get_runtime().block_on_potential_spawn(async { writer.flush().await })
})
self.inner
.try_with_writer(|writer| block_in_place_on(async { writer.flush().await }))
}
}

impl tokio::io::AsyncWrite for CloudWriter {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut self.inner {
WriterState::Open(writer) => match Box::pin(writer.write(buf)).poll_unpin(cx) {
Poll::Ready(Err(e)) => {
self.inner = WriterState::Err(e);
Self::poll_write(self, cx, buf)
},
v => v,
},
WriterState::Err(e) => Poll::Ready(Err(std::io::Error::new(e.kind(), e.to_string()))),
}
}

fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.inner {
WriterState::Open(writer) => match Box::pin(writer.flush()).poll_unpin(cx) {
Poll::Ready(Err(e)) => {
self.inner = WriterState::Err(e);
Self::poll_flush(self, cx)
},
v => v,
},
WriterState::Err(e) => Poll::Ready(Err(std::io::Error::new(e.kind(), e.to_string()))),
}
}

fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.inner {
WriterState::Open(writer) => match Box::pin(writer.shutdown()).poll_unpin(cx) {
Poll::Ready(Err(e)) => {
self.inner = WriterState::Err(e);
Self::poll_shutdown(self, cx)
},
v => v,
},
WriterState::Err(e) => Poll::Ready(Err(std::io::Error::new(e.kind(), e.to_string()))),
}
}
}

impl Drop for CloudWriter {
fn drop(&mut self) {
// TODO: Properly raise this error instead of panicking.
// TODO: Once we are properly calling `close()` from all contexts this can instead be a
// debug_assert that we are in an `Err(_)` state when dropping.
match self.inner {
WriterState::Open(_) => self.close().unwrap(),
WriterState::Open(_) => self.close_sync().unwrap(),
WriterState::Err(_) => {},
}
}
Expand Down Expand Up @@ -194,7 +256,7 @@ mod tests {
.finish(&mut df)
.expect("Could not write DataFrame as CSV to remote location");

cloud_writer.close().unwrap();
cloud_writer.close_sync().unwrap();

assert_eq!(
CsvReadOptions::default()
Expand Down
Loading

0 comments on commit 3cb7e64

Please sign in to comment.