Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add IPC source node for new streaming engine #19454

Merged
merged 15 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions crates/polars-arrow/src/io/ipc/read/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,14 @@ pub fn read_dictionary<R: Read + Seek>(
Ok(())
}

pub fn prepare_projection(
schema: &ArrowSchema,
mut projection: Vec<usize>,
) -> (Vec<usize>, PlHashMap<usize, usize>, ArrowSchema) {
/// Projection information for reading a subset of IPC columns.
///
/// Produced by `prepare_projection`; bundles the column indices to read, a
/// mapping used to re-order the decoded chunk back to the caller's requested
/// order, and the schema restricted to the projected columns.
#[derive(Clone)]
pub struct ProjectionInfo {
    /// Column indices to read; expected to be in increasing order.
    pub columns: Vec<usize>,
    /// Used to re-order a decoded chunk according to the projection
    /// (see `apply_projection`).
    pub map: PlHashMap<usize, usize>,
    /// Schema restricted to the projected columns.
    pub schema: ArrowSchema,
}

pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec<usize>) -> ProjectionInfo {
let schema = projection
.iter()
.map(|x| {
Expand Down Expand Up @@ -355,7 +359,11 @@ pub fn prepare_projection(
}
}

(projection, map, schema)
ProjectionInfo {
columns: projection,
map,
schema,
}
}

pub fn apply_projection(
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ fn get_message_from_block_offset<'a, R: Read + Seek>(
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))
}

fn get_message_from_block<'a, R: Read + Seek>(
pub(super) fn get_message_from_block<'a, R: Read + Seek>(
reader: &mut R,
block: &arrow_format::ipc::Block,
message_scratch: &'a mut Vec<u8>,
Expand Down
1 change: 1 addition & 0 deletions crates/polars-arrow/src/io/ipc/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod schema;
mod stream;

pub(crate) use common::first_dict_field;
pub use common::{prepare_projection, ProjectionInfo};
pub use error::OutOfSpecKind;
pub use file::{
deserialize_footer, get_row_count, read_batch, read_file_dictionaries, read_file_metadata,
Expand Down
90 changes: 80 additions & 10 deletions crates/polars-arrow/src/io/ipc/read/reader.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::io::{Read, Seek};

use polars_error::PolarsResult;
use polars_utils::aliases::PlHashMap;

use super::common::*;
use super::file::{get_message_from_block, get_record_batch};
use super::{read_batch, read_file_dictionaries, Dictionaries, FileMetadata};
use crate::array::Array;
use crate::datatypes::ArrowSchema;
Expand All @@ -16,7 +16,7 @@ pub struct FileReader<R: Read + Seek> {
// the dictionaries are going to be read
dictionaries: Option<Dictionaries>,
current_block: usize,
projection: Option<(Vec<usize>, PlHashMap<usize, usize>, ArrowSchema)>,
projection: Option<ProjectionInfo>,
remaining: usize,
data_scratch: Vec<u8>,
message_scratch: Vec<u8>,
Expand All @@ -32,10 +32,29 @@ impl<R: Read + Seek> FileReader<R> {
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Self {
let projection = projection.map(|projection| {
let (p, h, schema) = prepare_projection(&metadata.schema, projection);
(p, h, schema)
});
let projection =
projection.map(|projection| prepare_projection(&metadata.schema, projection));
Self {
reader,
metadata,
dictionaries: Default::default(),
projection,
remaining: limit.unwrap_or(usize::MAX),
current_block: 0,
data_scratch: Default::default(),
message_scratch: Default::default(),
}
}

/// Creates a new [`FileReader`]. Use `projection` to only take certain columns.
/// # Panic
/// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid)
pub fn new_with_projection_info(
reader: R,
metadata: FileMetadata,
projection: Option<ProjectionInfo>,
limit: Option<usize>,
) -> Self {
Self {
reader,
metadata,
Expand All @@ -52,7 +71,7 @@ impl<R: Read + Seek> FileReader<R> {
pub fn schema(&self) -> &ArrowSchema {
self.projection
.as_ref()
.map(|x| &x.2)
.map(|x| &x.schema)
.unwrap_or(&self.metadata.schema)
}

Expand All @@ -66,9 +85,23 @@ impl<R: Read + Seek> FileReader<R> {
self.reader
}

/// Set the index of the block that will be decoded next.
pub fn set_current_block(&mut self, idx: usize) {
    self.current_block = idx;
}

/// Return the index of the block that will be decoded next.
pub fn get_current_block(&self) -> usize {
    self.current_block
}

/// Take the [`ProjectionInfo`] out of this reader, leaving `None` in its place.
///
/// This allows the projection to be reused (instead of recomputed) by a new
/// reader created with `new_with_projection_info`.
pub fn take_projection_info(&mut self) -> Option<ProjectionInfo> {
    std::mem::take(&mut self.projection)
}

/// Get the inner memory scratches so they can be reused in a new writer.
/// This can be utilized to save memory allocations for performance reasons.
pub fn get_scratches(&mut self) -> (Vec<u8>, Vec<u8>) {
pub fn take_scratches(&mut self) -> (Vec<u8>, Vec<u8>) {
(
std::mem::take(&mut self.data_scratch),
std::mem::take(&mut self.message_scratch),
Expand All @@ -91,6 +124,43 @@ impl<R: Read + Seek> FileReader<R> {
};
Ok(())
}

/// Skip over blocks until we have seen at most `offset` rows, returning how many rows we
/// still have to skip.
///
/// This will never go over the `offset`. Meaning that if the `offset < current_block.len()`,
/// the block will not be skipped.
pub fn skip_blocks_till_limit(&mut self, offset: u64) -> PolarsResult<u64> {
    let mut remaining_offset = offset;

    for (i, block) in self.metadata.blocks.iter().enumerate() {
        // Only the message header is parsed here to obtain the row count; the
        // record-batch body is not decoded into arrays.
        let message =
            get_message_from_block(&mut self.reader, block, &mut self.message_scratch)?;
        let record_batch = get_record_batch(message)?;

        let length = record_batch.length()?;
        let length = length as u64;

        // This block holds more rows than are left to skip: stop before it so
        // we never skip past `offset`.
        if length > remaining_offset {
            self.current_block = i;
            return Ok(remaining_offset);
        }

        remaining_offset -= length;
    }

    // All blocks consumed: the file contains fewer than `offset` rows.
    self.current_block = self.metadata.blocks.len();
    Ok(remaining_offset)
}

/// Read the record-batch message of the current block and advance to the next block.
///
/// Returns `None` once all blocks have been consumed. This yields a flatbuffer
/// [`arrow_format::ipc::RecordBatchRef`]; the arrays themselves are not decoded.
pub fn next_record_batch(
    &mut self,
) -> Option<PolarsResult<arrow_format::ipc::RecordBatchRef<'_>>> {
    let block = self.metadata.blocks.get(self.current_block)?;
    self.current_block += 1;
    let message = get_message_from_block(&mut self.reader, block, &mut self.message_scratch);
    Some(message.and_then(|m| get_record_batch(m)))
}
}

impl<R: Read + Seek> Iterator for FileReader<R> {
Expand All @@ -114,15 +184,15 @@ impl<R: Read + Seek> Iterator for FileReader<R> {
&mut self.reader,
self.dictionaries.as_ref().unwrap(),
&self.metadata,
self.projection.as_ref().map(|x| x.0.as_ref()),
self.projection.as_ref().map(|x| x.columns.as_ref()),
Some(self.remaining),
block,
&mut self.message_scratch,
&mut self.data_scratch,
);
self.remaining -= chunk.as_ref().map(|x| x.len()).unwrap_or_default();

let chunk = if let Some((_, map, _)) = &self.projection {
let chunk = if let Some(ProjectionInfo { map, .. }) = &self.projection {
// re-order according to projection
chunk.map(|chunk| apply_projection(chunk, map))
} else {
Expand Down
17 changes: 7 additions & 10 deletions crates/polars-arrow/src/io/ipc/read/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::io::Read;

use arrow_format::ipc::planus::ReadAsRoot;
use polars_error::{polars_bail, polars_err, PolarsError, PolarsResult};
use polars_utils::aliases::PlHashMap;

use super::super::CONTINUATION_MARKER;
use super::common::*;
Expand Down Expand Up @@ -93,7 +92,7 @@ fn read_next<R: Read>(
dictionaries: &mut Dictionaries,
message_buffer: &mut Vec<u8>,
data_buffer: &mut Vec<u8>,
projection: &Option<(Vec<usize>, PlHashMap<usize, usize>, ArrowSchema)>,
projection: &Option<ProjectionInfo>,
scratch: &mut Vec<u8>,
) -> PolarsResult<Option<StreamState>> {
// determine metadata length
Expand Down Expand Up @@ -169,7 +168,7 @@ fn read_next<R: Read>(
batch,
&metadata.schema,
&metadata.ipc_schema,
projection.as_ref().map(|x| x.0.as_ref()),
projection.as_ref().map(|x| x.columns.as_ref()),
None,
dictionaries,
metadata.version,
Expand All @@ -179,7 +178,7 @@ fn read_next<R: Read>(
scratch,
);

if let Some((_, map, _)) = projection {
if let Some(ProjectionInfo { map, .. }) = projection {
// re-order according to projection
chunk
.map(|chunk| apply_projection(chunk, map))
Expand Down Expand Up @@ -238,7 +237,7 @@ pub struct StreamReader<R: Read> {
finished: bool,
data_buffer: Vec<u8>,
message_buffer: Vec<u8>,
projection: Option<(Vec<usize>, PlHashMap<usize, usize>, ArrowSchema)>,
projection: Option<ProjectionInfo>,
scratch: Vec<u8>,
}

Expand All @@ -249,10 +248,8 @@ impl<R: Read> StreamReader<R> {
/// encounter a schema.
/// To check if the reader is done, use `is_finished(self)`
pub fn new(reader: R, metadata: StreamMetadata, projection: Option<Vec<usize>>) -> Self {
let projection = projection.map(|projection| {
let (p, h, schema) = prepare_projection(&metadata.schema, projection);
(p, h, schema)
});
let projection =
projection.map(|projection| prepare_projection(&metadata.schema, projection));

Self {
reader,
Expand All @@ -275,7 +272,7 @@ impl<R: Read> StreamReader<R> {
pub fn schema(&self) -> &ArrowSchema {
self.projection
.as_ref()
.map(|x| &x.2)
.map(|x| &x.schema)
.unwrap_or(&self.metadata.schema)
}

Expand Down
22 changes: 16 additions & 6 deletions crates/polars-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::array::{Array, ArrayRef};
/// the same length, [`RecordBatchT::len`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordBatchT<A: AsRef<dyn Array>> {
length: usize,
height: usize,
arrays: Vec<A>,
}

Expand All @@ -29,14 +29,14 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
///
/// # Error
///
/// I.f.f. the length does not match the length of any of the arrays
pub fn try_new(length: usize, arrays: Vec<A>) -> PolarsResult<Self> {
/// I.f.f. the height does not match the length of any of the arrays
pub fn try_new(height: usize, arrays: Vec<A>) -> PolarsResult<Self> {
polars_ensure!(
arrays.iter().all(|arr| arr.as_ref().len() == length),
arrays.iter().all(|arr| arr.as_ref().len() == height),
ComputeError: "RecordBatch requires all its arrays to have an equal number of rows",
);

Ok(Self { length, arrays })
Ok(Self { height, arrays })
}

/// returns the [`Array`]s in [`RecordBatchT`]
Expand All @@ -51,7 +51,17 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {

/// returns the number of rows of every array
pub fn len(&self) -> usize {
self.length
self.height
}

/// returns the number of rows (height) shared by every array
pub fn height(&self) -> usize {
    self.height
}

/// returns the number of arrays (columns)
pub fn width(&self) -> usize {
    self.arrays.len()
}

/// returns whether the columns have any rows
Expand Down
26 changes: 26 additions & 0 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::borrow::Cow;
use std::{mem, ops};

use polars_row::ArrayRef;
use polars_utils::itertools::Itertools;
use rayon::prelude::*;

Expand Down Expand Up @@ -3334,6 +3335,31 @@ impl DataFrame {
pub(crate) fn infer_height(cols: &[Column]) -> usize {
cols.first().map_or(0, Column::len)
}

/// Appends a [`RecordBatchT`] to this [`DataFrame`], extending each column in place.
///
/// A zero-height record batch is a no-op.
///
/// # Errors
///
/// Errors if the width of the record batch does not match the width of the
/// `DataFrame`, or if appending any individual column fails.
pub fn append_record_batch(&mut self, rb: RecordBatchT<ArrayRef>) -> PolarsResult<()> {
    // Use the dedicated `width` accessor instead of `rb.arrays().len()`.
    polars_ensure!(
        rb.width() == self.width(),
        InvalidOperation: "attempt to extend dataframe of width {} with record batch of width {}",
        self.width(),
        rb.width(),
    );

    if rb.height() == 0 {
        return Ok(());
    }

    // SAFETY:
    // - we don't adjust the names of the columns
    // - each column gets appended the same number of rows, which is an invariant of
    //   record_batch.
    let columns = unsafe { self.get_columns_mut() };
    for (col, arr) in columns.iter_mut().zip(rb.into_arrays()) {
        let arr_series = Series::from_arrow_chunks(PlSmallStr::EMPTY, vec![arr])?.into_column();
        col.append(&arr_series)?;
    }

    Ok(())
}
}

pub struct RecordBatchIter<'a> {
Expand Down
28 changes: 28 additions & 0 deletions crates/polars-core/src/frame/upstream_traits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};

use arrow::record_batch::RecordBatchT;

use crate::prelude::*;

impl FromIterator<Series> for DataFrame {
Expand All @@ -22,6 +24,32 @@ impl FromIterator<Column> for DataFrame {
}
}

impl TryExtend<RecordBatchT<Box<dyn Array>>> for DataFrame {
    /// Append every record batch from `iter`, short-circuiting on the first error.
    fn try_extend<I: IntoIterator<Item = RecordBatchT<Box<dyn Array>>>>(
        &mut self,
        iter: I,
    ) -> PolarsResult<()> {
        iter.into_iter()
            .try_for_each(|rb| self.append_record_batch(rb))
    }
}

impl TryExtend<PolarsResult<RecordBatchT<Box<dyn Array>>>> for DataFrame {
    /// Append every record batch from `iter`, short-circuiting on the first
    /// item error or append error.
    fn try_extend<I: IntoIterator<Item = PolarsResult<RecordBatchT<Box<dyn Array>>>>>(
        &mut self,
        iter: I,
    ) -> PolarsResult<()> {
        iter.into_iter()
            .try_for_each(|rb| self.append_record_batch(rb?))
    }
}

impl Index<usize> for DataFrame {
type Output = Column;

Expand Down
Loading
Loading