fix arrow construction for empty tables
coastalwhite committed Dec 13, 2024
1 parent 28253f7 commit 6ca9dd2
Showing 27 changed files with 353 additions and 129 deletions.
3 changes: 3 additions & 0 deletions crates/polars-arrow/src/io/avro/read/deserialize.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use avro_schema::file::Block;
 use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema};
 use polars_error::{polars_bail, polars_err, PolarsResult};
@@ -508,6 +510,7 @@ pub fn deserialize(
 
     RecordBatchT::try_new(
         rows,
+        Arc::new(fields.iter_values().cloned().collect()),
        arrays
            .iter_mut()
            .zip(projection.iter())
23 changes: 18 additions & 5 deletions crates/polars-arrow/src/io/ipc/read/common.rs
@@ -1,5 +1,6 @@
 use std::collections::VecDeque;
 use std::io::{Read, Seek};
+use std::sync::Arc;
 
 use polars_error::{polars_bail, polars_err, PolarsResult};
 use polars_utils::aliases::PlHashMap;
@@ -197,7 +198,11 @@ pub fn read_record_batch<R: Read + Seek>(
         .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
     let length = limit.map(|limit| limit.min(length)).unwrap_or(length);
 
-    RecordBatchT::try_new(length, columns)
+    let mut schema: ArrowSchema = fields.iter_values().cloned().collect();
+    if let Some(projection) = projection {
+        schema = schema.try_project_indices(projection).unwrap();
+    }
+    RecordBatchT::try_new(length, Arc::new(schema), columns)
 }
 
 fn find_first_dict_field_d<'a>(
@@ -373,13 +378,21 @@ pub fn apply_projection(
     let length = chunk.len();
 
     // re-order according to projection
-    let arrays = chunk.into_arrays();
+    let (schema, arrays) = chunk.into_schema_and_arrays();
+    let mut new_schema = schema.as_ref().clone();
     let mut new_arrays = arrays.clone();
 
-    map.iter()
-        .for_each(|(old, new)| new_arrays[*new] = arrays[*old].clone());
+    map.iter().for_each(|(old, new)| {
+        let (old_name, old_field) = schema.get_at_index(*old).unwrap();
+        let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();
+
+        *new_name = old_name.clone();
+        *new_field = old_field.clone();
+
+        new_arrays[*new] = arrays[*old].clone();
+    });
 
-    RecordBatchT::new(length, new_arrays)
+    RecordBatchT::new(length, Arc::new(new_schema), new_arrays)
 }
 
 #[cfg(test)]
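
For illustration (not part of the commit): the `apply_projection` change above keeps the schema and the arrays permuted in lockstep. A minimal stand-alone sketch of the same reordering idea, with plain values standing in for fields and arrays:

// `map` sends an old position to its new position; fields and arrays
// must move together or the batch's schema would no longer match its data.
fn reorder<T: Clone>(items: &[T], map: &[(usize, usize)]) -> Vec<T> {
    let mut out = items.to_vec();
    for &(old, new) in map {
        out[new] = items[old].clone();
    }
    out
}

fn main() {
    let fields = ["a", "b", "c"];
    let arrays = [1, 2, 3];
    let map = [(2, 0), (1, 1), (0, 2)];
    assert_eq!(reorder(&fields, &map), ["c", "b", "a"]);
    assert_eq!(reorder(&arrays, &map), [3, 2, 1]);
}
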
8 changes: 7 additions & 1 deletion crates/polars-arrow/src/mmap/mod.rs
@@ -111,7 +111,13 @@ pub(crate) unsafe fn mmap_record<T: AsRef<[u8]>>(
             )
         })
         .collect::<PolarsResult<_>>()
-        .and_then(|arr| RecordBatchT::try_new(length, arr))
+        .and_then(|arr| {
+            RecordBatchT::try_new(
+                length,
+                Arc::new(fields.iter_values().cloned().collect()),
+                arr,
+            )
+        })
 }
 
 /// Memory maps an record batch from an IPC file into a [`RecordBatchT`].
29 changes: 25 additions & 4 deletions crates/polars-arrow/src/record_batch.rs
@@ -4,12 +4,14 @@
 use polars_error::{polars_ensure, PolarsResult};
 
 use crate::array::{Array, ArrayRef};
+use crate::datatypes::{ArrowSchema, ArrowSchemaRef};
 
 /// A vector of trait objects of [`Array`] where every item has
 /// the same length, [`RecordBatchT::len`].
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct RecordBatchT<A: AsRef<dyn Array>> {
     height: usize,
+    schema: ArrowSchemaRef,
     arrays: Vec<A>,
 }
 
@@ -21,29 +23,42 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
     /// # Panics
     ///
     /// I.f.f. the length does not match the length of any of the arrays
-    pub fn new(length: usize, arrays: Vec<A>) -> Self {
-        Self::try_new(length, arrays).unwrap()
+    pub fn new(length: usize, schema: ArrowSchemaRef, arrays: Vec<A>) -> Self {
+        Self::try_new(length, schema, arrays).unwrap()
     }
 
     /// Creates a new [`RecordBatchT`].
     ///
     /// # Error
    ///
    /// I.f.f. the height does not match the length of any of the arrays
-    pub fn try_new(height: usize, arrays: Vec<A>) -> PolarsResult<Self> {
+    pub fn try_new(height: usize, schema: ArrowSchemaRef, arrays: Vec<A>) -> PolarsResult<Self> {
+        polars_ensure!(
+            schema.len() == arrays.len(),
+            ComputeError: "RecordBatch requires an equal number of fields and arrays",
+        );
         polars_ensure!(
             arrays.iter().all(|arr| arr.as_ref().len() == height),
             ComputeError: "RecordBatch requires all its arrays to have an equal number of rows",
         );
 
-        Ok(Self { height, arrays })
+        Ok(Self {
+            height,
+            schema,
+            arrays,
+        })
     }
 
     /// returns the [`Array`]s in [`RecordBatchT`]
     pub fn arrays(&self) -> &[A] {
         &self.arrays
     }
 
+    /// returns the [`ArrowSchema`]s in [`RecordBatchT`]
+    pub fn schema(&self) -> &ArrowSchema {
+        &self.schema
+    }
+
     /// returns the [`Array`]s in [`RecordBatchT`]
     pub fn columns(&self) -> &[A] {
         &self.arrays
@@ -74,6 +89,12 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
     pub fn into_arrays(self) -> Vec<A> {
         self.arrays
     }
+
+    /// Consumes [`RecordBatchT`] into its underlying schema and arrays.
+    /// The arrays are guaranteed to have the same length
+    pub fn into_schema_and_arrays(self) -> (ArrowSchemaRef, Vec<A>) {
+        (self.schema, self.arrays)
+    }
 }
 
 impl<A: AsRef<dyn Array>> From<RecordBatchT<A>> for Vec<A> {
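
For illustration (not part of the commit): with the schema carried alongside the arrays, a zero-column batch can still report both a height and a schema, which is exactly the empty-table case this commit fixes. A minimal sketch, assuming `ArrowSchema::default()` yields an empty schema:

use std::sync::Arc;

use polars_arrow::array::Array;
use polars_arrow::datatypes::ArrowSchema;
use polars_arrow::record_batch::RecordBatchT;

fn main() {
    // Zero fields and zero arrays: both try_new checks pass vacuously,
    // yet the batch still knows it is 3 rows tall.
    let batch: RecordBatchT<Box<dyn Array>> =
        RecordBatchT::try_new(3, Arc::new(ArrowSchema::default()), Vec::new()).unwrap();
    assert_eq!(batch.len(), 3);
    assert_eq!(batch.schema().len(), 0);
}
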
@@ -5,15 +5,17 @@ use arrow::array::*;
 use polars_utils::aliases::PlRandomState;
 #[cfg(any(feature = "serde-lazy", feature = "serde"))]
 use serde::{Deserialize, Serialize};
+use strum_macros::IntoStaticStr;
 
 use crate::datatypes::PlHashMap;
 use crate::using_string_cache;
 
-#[derive(Debug, Copy, Clone, PartialEq, Default)]
+#[derive(Debug, Copy, Clone, PartialEq, Default, IntoStaticStr)]
 #[cfg_attr(
     any(feature = "serde-lazy", feature = "serde"),
     derive(Serialize, Deserialize)
 )]
+#[strum(serialize_all = "snake_case")]
 pub enum CategoricalOrdering {
     #[default]
     Physical,
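
For illustration (not part of the commit): `IntoStaticStr` with `serialize_all = "snake_case"` derives a `From<CategoricalOrdering> for &'static str` impl, and the dtype metadata change below relies on its output. A self-contained sketch of the same derive:

use strum_macros::IntoStaticStr;

#[derive(Debug, Copy, Clone, Default, IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
enum CategoricalOrdering {
    #[default]
    Physical,
    Lexical,
}

fn main() {
    // `.into()` uses the derived impl; variant names render in snake_case.
    let lexical: &'static str = CategoricalOrdering::Lexical.into();
    let physical: &'static str = CategoricalOrdering::default().into();
    assert_eq!(lexical, "lexical");
    assert_eq!(physical, "physical");
}
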
12 changes: 8 additions & 4 deletions crates/polars-core/src/datatypes/dtype.rs
@@ -21,9 +21,13 @@ pub trait MetaDataExt: IntoMetadata {
         metadata.get(DTYPE_ENUM_VALUES).is_some()
     }
 
-    fn is_categorical(&self) -> bool {
+    fn categorical(&self) -> Option<CategoricalOrdering> {
         let metadata = self.into_metadata_ref();
-        metadata.get(DTYPE_CATEGORICAL).is_some()
+        match metadata.get(DTYPE_CATEGORICAL)?.as_str() {
+            "lexical" => Some(CategoricalOrdering::Lexical),
+            // Default is Physical
+            _ => Some(CategoricalOrdering::Physical),
+        }
     }
 
     fn maintain_type(&self) -> bool {
@@ -603,9 +607,9 @@ impl DataType {
             )]))
         },
         #[cfg(feature = "dtype-categorical")]
-        DataType::Categorical(_, _) => Some(BTreeMap::from([(
+        DataType::Categorical(_, ordering) => Some(BTreeMap::from([(
             PlSmallStr::from_static(DTYPE_CATEGORICAL),
-            PlSmallStr::EMPTY,
+            PlSmallStr::from_static(ordering.into()),
         )])),
         DataType::BinaryOffset => Some(BTreeMap::from([(
             PlSmallStr::from_static(PL_KEY),
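
For illustration (not part of the commit): the reader treats any value other than "lexical", including the empty string that older files stored under `DTYPE_CATEGORICAL`, as physical ordering, so files written before this change still decode. A local stand-in sketch of that decode rule:

#[derive(Debug, PartialEq)]
enum Ordering {
    Physical,
    Lexical,
}

// Mirrors the match in `categorical()`: only "lexical" is significant,
// everything else falls back to the physical default.
fn decode(value: &str) -> Ordering {
    match value {
        "lexical" => Ordering::Lexical,
        _ => Ordering::Physical,
    }
}

fn main() {
    assert_eq!(decode("lexical"), Ordering::Lexical);
    assert_eq!(decode("physical"), Ordering::Physical);
    assert_eq!(decode(""), Ordering::Physical); // value written by older files
}
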
4 changes: 3 additions & 1 deletion crates/polars-core/src/datatypes/field.rs
@@ -186,7 +186,9 @@ impl DataType {
                 encoded = remainder;
             }
             DataType::Enum(Some(Arc::new(RevMapping::build_local(cats.into()))), Default::default())
-        } else if md.map(|md| md.is_categorical()).unwrap_or(false) || matches!(value_type.as_ref(), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View) {
+        } else if let Some(ordering) = md.and_then(|md| md.categorical()) {
+            DataType::Categorical(None, ordering)
+        } else if matches!(value_type.as_ref(), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View) {
             DataType::Categorical(None, Default::default())
         } else {
             Self::from_arrow(value_type, bin_to_view, None)
45 changes: 31 additions & 14 deletions crates/polars-core/src/frame/mod.rs
@@ -2,6 +2,7 @@
 use std::sync::OnceLock;
 use std::{mem, ops};
 
+use arrow::datatypes::ArrowSchemaRef;
 use polars_row::ArrayRef;
 use polars_schema::schema::debug_ensure_matching_schema_names;
 use polars_utils::itertools::Itertools;
@@ -587,7 +588,7 @@
     ) -> RecordBatchT<Box<dyn Array>> {
         let height = self.height();
 
-        let arrays = self
+        let (schema, arrays) = self
             .columns
             .into_iter()
             .map(|col| {
@@ -596,11 +597,14 @@
                 if series.n_chunks() > 1 {
                     series = series.rechunk();
                 }
-                series.to_arrow(0, compat_level)
+                (
+                    series.field().to_arrow(compat_level),
+                    series.to_arrow(0, compat_level),
+                )
             })
             .collect();
 
-        RecordBatchT::new(height, arrays)
+        RecordBatchT::new(height, Arc::new(schema), arrays)
     }
 
     /// Returns true if the chunks of the columns do not align and re-chunking should be done
@@ -2766,6 +2770,12 @@ impl DataFrame {
 
         RecordBatchIter {
             columns: &self.columns,
+            schema: Arc::new(
+                self.columns
+                    .iter()
+                    .map(|c| c.field().to_arrow(compat_level))
+                    .collect(),
+            ),
             idx: 0,
             n_chunks: self.first_col_n_chunks(),
             compat_level,
@@ -2784,7 +2794,13 @@
     /// as well.
     pub fn iter_chunks_physical(&self) -> PhysRecordBatchIter<'_> {
         PhysRecordBatchIter {
-            iters: self
+            schema: Arc::new(
+                self.get_columns()
+                    .iter()
+                    .map(|c| c.field().to_arrow(CompatLevel::newest()))
+                    .collect(),
+            ),
+            arr_iters: self
                 .materialized_column_iter()
                 .map(|s| s.chunks().iter())
                 .collect(),
@@ -3255,6 +3271,7 @@ impl DataFrame {
 
 pub struct RecordBatchIter<'a> {
     columns: &'a Vec<Column>,
+    schema: ArrowSchemaRef,
     idx: usize,
     n_chunks: usize,
     compat_level: CompatLevel,
@@ -3287,8 +3304,7 @@ impl Iterator for RecordBatchIter<'_> {
         self.idx += 1;
 
         let length = batch_cols.first().map_or(0, |arr| arr.len());
-
-        Some(RecordBatch::new(length, batch_cols))
+        Some(RecordBatch::new(length, self.schema.clone(), batch_cols))
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
@@ -3298,25 +3314,26 @@
 }
 
 pub struct PhysRecordBatchIter<'a> {
-    iters: Vec<std::slice::Iter<'a, ArrayRef>>,
+    schema: ArrowSchemaRef,
+    arr_iters: Vec<std::slice::Iter<'a, ArrayRef>>,
 }
 
 impl Iterator for PhysRecordBatchIter<'_> {
     type Item = RecordBatch;
 
     fn next(&mut self) -> Option<Self::Item> {
-        self.iters
+        let arrs = self
+            .arr_iters
             .iter_mut()
             .map(|phys_iter| phys_iter.next().cloned())
-            .collect::<Option<Vec<_>>>()
-            .map(|arrs| {
-                let length = arrs.first().map_or(0, |arr| arr.len());
-                RecordBatch::new(length, arrs)
-            })
+            .collect::<Option<Vec<_>>>()?;
+
+        let length = arrs.first().map_or(0, |arr| arr.len());
+        Some(RecordBatch::new(length, self.schema.clone(), arrs))
    }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
-        if let Some(iter) = self.iters.first() {
+        if let Some(iter) = self.arr_iters.first() {
             iter.size_hint()
         } else {
             (0, None)
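
For illustration (not part of the commit): both iterators build the Arrow schema once, up front, and hand every batch a clone of the same `Arc`, so the per-batch cost is a reference-count bump rather than a copy of the fields. A stand-alone sketch of that sharing:

use std::sync::Arc;

fn main() {
    // Stand-in for ArrowSchemaRef: one allocation, shared by every batch.
    let schema = Arc::new(vec!["a", "b", "c"]);
    let per_batch: Vec<_> = (0..3).map(|_| schema.clone()).collect();
    assert!(per_batch.iter().all(|s| Arc::ptr_eq(s, &schema)));
    assert_eq!(Arc::strong_count(&schema), 4); // original + 3 batches
}
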
14 changes: 6 additions & 8 deletions crates/polars-core/src/series/from.rs
@@ -374,6 +374,7 @@ impl Series {
             keys, values,
         );
 
+        let mut ordering = CategoricalOrdering::default();
         if let Some(metadata) = md {
             if metadata.is_enum() {
                 // SAFETY:
@@ -382,19 +383,16 @@
                     UInt32Chunked::with_chunk(name, keys),
                     Arc::new(RevMapping::build_local(values)),
                     true,
-                    Default::default(),
+                    CategoricalOrdering::Physical, // Enum always uses physical ordering
                 )
                 .into_series());
+            } else if let Some(o) = metadata.categorical() {
+                ordering = o;
             }
         }
 
-        // SAFETY:
-        // the invariants of an Arrow Dictionary guarantee the keys are in bounds
-        return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
-            UInt32Chunked::with_chunk(name, keys),
-            Arc::new(RevMapping::build_local(values)),
-            false,
-            Default::default(),
+        return Ok(CategoricalChunked::from_keys_and_values(
+            name, &keys, &values, ordering,
         )
         .into_series());
     }
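
For illustration (not part of the commit): the control flow above resolves the ordering with a fixed priority: enum metadata forces physical, then any categorical metadata supplies its stored ordering, then the default applies. A reduced sketch of that decision:

#[derive(Debug, PartialEq, Clone, Copy, Default)]
enum Ordering {
    #[default]
    Physical,
    Lexical,
}

fn pick(is_enum: bool, md_ordering: Option<Ordering>) -> Ordering {
    if is_enum {
        Ordering::Physical // enums always use physical ordering
    } else {
        md_ordering.unwrap_or_default()
    }
}

fn main() {
    assert_eq!(pick(true, Some(Ordering::Lexical)), Ordering::Physical);
    assert_eq!(pick(false, Some(Ordering::Lexical)), Ordering::Lexical);
    assert_eq!(pick(false, None), Ordering::Physical);
}
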
5 changes: 3 additions & 2 deletions crates/polars-io/src/ipc/mmap.rs
@@ -97,9 +97,10 @@ impl ArrowReader for MMapChunkIter<'_> {
             None => chunk,
             Some(proj) => {
                 let length = chunk.len();
-                let cols = chunk.into_arrays();
+                let (schema, cols) = chunk.into_schema_and_arrays();
+                let schema = schema.try_project_indices(proj).unwrap();
                 let arrays = proj.iter().map(|i| cols[*i].clone()).collect();
-                RecordBatch::new(length, arrays)
+                RecordBatch::new(length, Arc::new(schema), arrays)
             },
         };
         Ok(Some(chunk))
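
For illustration (not part of the commit): one index list now drives both the schema projection (`try_project_indices`) and the column selection, which keeps the two aligned by construction. A plain analogue:

// Project any parallel sequence by the same indices the reader uses.
fn project<T: Clone>(items: &[T], proj: &[usize]) -> Vec<T> {
    proj.iter().map(|&i| items[i].clone()).collect()
}

fn main() {
    let names = ["a", "b", "c"];
    let columns = [10, 20, 30];
    let proj = [2, 0];
    assert_eq!(project(&names, &proj), ["c", "a"]);
    assert_eq!(project(&columns, &proj), [30, 10]);
}
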
8 changes: 6 additions & 2 deletions crates/polars-python/src/dataframe/construction.rs
@@ -52,8 +52,12 @@ impl PyDataFrame {
     }
 
     #[staticmethod]
-    pub fn from_arrow_record_batches(py: Python, rb: Vec<Bound<PyAny>>) -> PyResult<Self> {
-        let df = interop::arrow::to_rust::to_rust_df(py, &rb)?;
+    pub fn from_arrow_record_batches(
+        py: Python,
+        rb: Vec<Bound<PyAny>>,
+        schema: Bound<PyAny>,
+    ) -> PyResult<Self> {
+        let df = interop::arrow::to_rust::to_rust_df(py, &rb, schema)?;
         Ok(Self::from(df))
     }
 }
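
For illustration (not part of the commit): with zero incoming record batches there is nothing to infer column count or dtypes from, presumably why the constructor now takes an explicit schema, matching the empty-table case in the commit title. A reduced sketch of the inference gap:

// Without data, inference yields nothing; an explicit schema fills the gap.
fn inferred_width(batches: &[Vec<i64>]) -> Option<usize> {
    batches.first().map(|first| first.len())
}

fn main() {
    assert_eq!(inferred_width(&[vec![1, 2], vec![3, 4]]), Some(2));
    assert_eq!(inferred_width(&[]), None); // empty table: must be told the schema
}
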