From 6ca9dd23a91496f93237ed6d75d5c3c329d48521 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Fri, 13 Dec 2024 17:32:46 +0100 Subject: [PATCH] fix arrow construction for empty tables --- .../src/io/avro/read/deserialize.rs | 3 + crates/polars-arrow/src/io/ipc/read/common.rs | 23 ++++-- crates/polars-arrow/src/mmap/mod.rs | 8 +- crates/polars-arrow/src/record_batch.rs | 29 ++++++- .../logical/categorical/revmap.rs | 4 +- crates/polars-core/src/datatypes/dtype.rs | 12 ++- crates/polars-core/src/datatypes/field.rs | 4 +- crates/polars-core/src/frame/mod.rs | 45 +++++++---- crates/polars-core/src/series/from.rs | 14 ++-- crates/polars-io/src/ipc/mmap.rs | 5 +- .../src/dataframe/construction.rs | 8 +- crates/polars-python/src/dataframe/export.rs | 12 ++- .../polars-python/src/interop/arrow/to_py.rs | 33 +++++--- .../src/interop/arrow/to_rust.rs | 79 +++++++++++++------ crates/polars-python/src/series/export.rs | 6 +- crates/polars-utils/src/pl_str.rs | 2 +- crates/polars/tests/it/arrow/io/ipc/mod.rs | 6 +- crates/polars/tests/it/io/avro/read.rs | 27 +++++-- crates/polars/tests/it/io/avro/write.rs | 29 ++++++- .../polars/tests/it/io/parquet/arrow/mod.rs | 20 +++-- .../polars/tests/it/io/parquet/arrow/write.rs | 6 +- .../polars/tests/it/io/parquet/read/file.rs | 9 ++- .../tests/it/io/parquet/read/row_group.rs | 17 +++- .../polars/tests/it/io/parquet/roundtrip.rs | 7 +- .../polars/_utils/construction/dataframe.py | 21 +---- py-polars/tests/unit/interop/test_interop.py | 30 +++++++ py-polars/tests/unit/io/test_parquet.py | 23 ++++++ 27 files changed, 353 insertions(+), 129 deletions(-) diff --git a/crates/polars-arrow/src/io/avro/read/deserialize.rs b/crates/polars-arrow/src/io/avro/read/deserialize.rs index f2f8af90c167..2929657aca6f 100644 --- a/crates/polars-arrow/src/io/avro/read/deserialize.rs +++ b/crates/polars-arrow/src/io/avro/read/deserialize.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use avro_schema::file::Block; use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema}; use polars_error::{polars_bail, polars_err, PolarsResult}; @@ -508,6 +510,7 @@ pub fn deserialize( RecordBatchT::try_new( rows, + Arc::new(fields.iter_values().cloned().collect()), arrays .iter_mut() .zip(projection.iter()) diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs index 0a1297bf1184..a81e40227ca4 100644 --- a/crates/polars-arrow/src/io/ipc/read/common.rs +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -1,5 +1,6 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; +use std::sync::Arc; use polars_error::{polars_bail, polars_err, PolarsResult}; use polars_utils::aliases::PlHashMap; @@ -197,7 +198,11 @@ pub fn read_record_batch( .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; let length = limit.map(|limit| limit.min(length)).unwrap_or(length); - RecordBatchT::try_new(length, columns) + let mut schema: ArrowSchema = fields.iter_values().cloned().collect(); + if let Some(projection) = projection { + schema = schema.try_project_indices(projection).unwrap(); + } + RecordBatchT::try_new(length, Arc::new(schema), columns) } fn find_first_dict_field_d<'a>( @@ -373,13 +378,21 @@ pub fn apply_projection( let length = chunk.len(); // re-order according to projection - let arrays = chunk.into_arrays(); + let (schema, arrays) = chunk.into_schema_and_arrays(); + let mut new_schema = schema.as_ref().clone(); let mut new_arrays = arrays.clone(); - map.iter() - .for_each(|(old, new)| new_arrays[*new] = 
arrays[*old].clone());
+    map.iter().for_each(|(old, new)| {
+        let (old_name, old_field) = schema.get_at_index(*old).unwrap();
+        let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();
+
+        *new_name = old_name.clone();
+        *new_field = old_field.clone();
+
+        new_arrays[*new] = arrays[*old].clone();
+    });
 
-    RecordBatchT::new(length, new_arrays)
+    RecordBatchT::new(length, Arc::new(new_schema), new_arrays)
 }
 
 #[cfg(test)]
diff --git a/crates/polars-arrow/src/mmap/mod.rs b/crates/polars-arrow/src/mmap/mod.rs
index 6ad0ca776d7d..ed026eb845b5 100644
--- a/crates/polars-arrow/src/mmap/mod.rs
+++ b/crates/polars-arrow/src/mmap/mod.rs
@@ -111,7 +111,13 @@ pub(crate) unsafe fn mmap_record<T: AsRef<[u8]>>(
             )
         })
         .collect::<PolarsResult<_>>()
-        .and_then(|arr| RecordBatchT::try_new(length, arr))
+        .and_then(|arr| {
+            RecordBatchT::try_new(
+                length,
+                Arc::new(fields.iter_values().cloned().collect()),
+                arr,
+            )
+        })
 }
 
 /// Memory maps a record batch from an IPC file into a [`RecordBatchT`].
diff --git a/crates/polars-arrow/src/record_batch.rs b/crates/polars-arrow/src/record_batch.rs
index 2b0b8112ea9e..a884a7335b01 100644
--- a/crates/polars-arrow/src/record_batch.rs
+++ b/crates/polars-arrow/src/record_batch.rs
@@ -4,12 +4,14 @@
 use polars_error::{polars_ensure, PolarsResult};
 
 use crate::array::{Array, ArrayRef};
+use crate::datatypes::{ArrowSchema, ArrowSchemaRef};
 
 /// A vector of trait objects of [`Array`] where every item has
 /// the same length, [`RecordBatchT::len`].
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct RecordBatchT<A: AsRef<dyn Array>> {
     height: usize,
+    schema: ArrowSchemaRef,
     arrays: Vec<A>,
 }
 
@@ -21,8 +23,8 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
     /// # Panics
     ///
     /// I.f.f. the length does not match the length of any of the arrays
-    pub fn new(length: usize, arrays: Vec<A>) -> Self {
-        Self::try_new(length, arrays).unwrap()
+    pub fn new(length: usize, schema: ArrowSchemaRef, arrays: Vec<A>) -> Self {
+        Self::try_new(length, schema, arrays).unwrap()
     }
 
     /// Creates a new [`RecordBatchT`].
@@ -30,13 +32,21 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
     /// # Error
     ///
     /// I.f.f. the height does not match the length of any of the arrays
-    pub fn try_new(height: usize, arrays: Vec<A>) -> PolarsResult<Self> {
+    pub fn try_new(height: usize, schema: ArrowSchemaRef, arrays: Vec<A>) -> PolarsResult<Self> {
+        polars_ensure!(
+            schema.len() == arrays.len(),
+            ComputeError: "RecordBatch requires an equal number of fields and arrays",
+        );
         polars_ensure!(
             arrays.iter().all(|arr| arr.as_ref().len() == height),
             ComputeError: "RecordBatch requires all its arrays to have an equal number of rows",
         );
 
-        Ok(Self { height, arrays })
+        Ok(Self {
+            height,
+            schema,
+            arrays,
+        })
     }
 
     /// returns the [`Array`]s in [`RecordBatchT`]
@@ -44,7 +54,12 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
         &self.arrays
     }
 
+    /// returns the [`ArrowSchema`] of the [`RecordBatchT`]
+    pub fn schema(&self) -> &ArrowSchema {
+        &self.schema
+    }
+
     /// returns the [`Array`]s in [`RecordBatchT`]
     pub fn columns(&self) -> &[A] {
         &self.arrays
     }
@@ -74,6 +89,12 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
     pub fn into_arrays(self) -> Vec<A> {
         self.arrays
     }
+
+    /// Consumes [`RecordBatchT`] into its underlying schema and arrays.
+    /// The arrays are guaranteed to have the same length
+    pub fn into_schema_and_arrays(self) -> (ArrowSchemaRef, Vec<A>) {
+        (self.schema, self.arrays)
+    }
 }
 
 impl<A: AsRef<dyn Array>> From<RecordBatchT<A>> for Vec<A> {
diff --git a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs
index 9631848104c7..fd4d60cd709d 100644
--- a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs
+++ b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs
@@ -5,15 +5,17 @@ use arrow::array::*;
 use polars_utils::aliases::PlRandomState;
 #[cfg(any(feature = "serde-lazy", feature = "serde"))]
 use serde::{Deserialize, Serialize};
+use strum_macros::IntoStaticStr;
 
 use crate::datatypes::PlHashMap;
 use crate::using_string_cache;
 
-#[derive(Debug, Copy, Clone, PartialEq, Default)]
+#[derive(Debug, Copy, Clone, PartialEq, Default, IntoStaticStr)]
 #[cfg_attr(
     any(feature = "serde-lazy", feature = "serde"),
     derive(Serialize, Deserialize)
 )]
+#[strum(serialize_all = "snake_case")]
 pub enum CategoricalOrdering {
     #[default]
     Physical,
diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs
index 748f853c73d1..f28e252ae038 100644
--- a/crates/polars-core/src/datatypes/dtype.rs
+++ b/crates/polars-core/src/datatypes/dtype.rs
@@ -21,9 +21,13 @@ pub trait MetaDataExt: IntoMetadata {
         metadata.get(DTYPE_ENUM_VALUES).is_some()
     }
 
-    fn is_categorical(&self) -> bool {
+    fn categorical(&self) -> Option<CategoricalOrdering> {
         let metadata = self.into_metadata_ref();
-        metadata.get(DTYPE_CATEGORICAL).is_some()
+        match metadata.get(DTYPE_CATEGORICAL)?.as_str() {
+            "lexical" => Some(CategoricalOrdering::Lexical),
+            // Default is Physical
+            _ => Some(CategoricalOrdering::Physical),
+        }
     }
 
     fn maintain_type(&self) -> bool {
@@ -603,9 +607,9 @@ impl DataType {
                 )]))
             },
             #[cfg(feature = "dtype-categorical")]
-            DataType::Categorical(_, _) => Some(BTreeMap::from([(
+            DataType::Categorical(_, ordering) => Some(BTreeMap::from([(
                 PlSmallStr::from_static(DTYPE_CATEGORICAL),
-                PlSmallStr::EMPTY,
+                PlSmallStr::from_static(ordering.into()),
             )])),
             DataType::BinaryOffset => Some(BTreeMap::from([(
                 PlSmallStr::from_static(PL_KEY),
diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs
index 69060e258c27..f3a0e916926e 100644
--- a/crates/polars-core/src/datatypes/field.rs
+++ b/crates/polars-core/src/datatypes/field.rs
@@ -186,7 +186,9 @@ impl DataType {
                 encoded = remainder;
             }
             DataType::Enum(Some(Arc::new(RevMapping::build_local(cats.into()))), Default::default())
-        } else if md.map(|md| md.is_categorical()).unwrap_or(false) || matches!(value_type.as_ref(), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View) {
+        } else if let Some(ordering) = md.and_then(|md| md.categorical()) {
+            DataType::Categorical(None, ordering)
+        } else if matches!(value_type.as_ref(), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View) {
             DataType::Categorical(None, Default::default())
         } else {
             Self::from_arrow(value_type, bin_to_view, None)
diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs
index 16a31d1db00b..1b03a6c17911 100644
--- a/crates/polars-core/src/frame/mod.rs
+++ b/crates/polars-core/src/frame/mod.rs
@@ -2,6 +2,7 @@
 use std::sync::OnceLock;
 use std::{mem, ops};
 
+use arrow::datatypes::ArrowSchemaRef;
 use polars_row::ArrayRef;
 use polars_schema::schema::debug_ensure_matching_schema_names;
 use polars_utils::itertools::Itertools;
@@ -587,7 +588,7 @@ impl DataFrame {
     ) -> RecordBatchT<Box<dyn Array>> {
         let height = self.height();
 
-        let arrays = self
+        let (schema, arrays) = self
             .columns
             .into_iter()
             .map(|col| {
@@ -596,11 +597,14 @@ impl DataFrame {
                 if series.n_chunks() > 1 {
                     series = series.rechunk();
                 }
-                series.to_arrow(0, compat_level)
+                (
+                    series.field().to_arrow(compat_level),
+                    series.to_arrow(0, compat_level),
+                )
             })
             .collect();
 
-        RecordBatchT::new(height, arrays)
+        RecordBatchT::new(height, Arc::new(schema), arrays)
     }
 
     /// Returns true if the chunks of the columns do not align and re-chunking should be done
@@ -2766,6 +2770,12 @@ impl DataFrame {
 
         RecordBatchIter {
             columns: &self.columns,
+            schema: Arc::new(
+                self.columns
+                    .iter()
+                    .map(|c| c.field().to_arrow(compat_level))
+                    .collect(),
+            ),
             idx: 0,
             n_chunks: self.first_col_n_chunks(),
             compat_level,
@@ -2784,7 +2794,13 @@ impl DataFrame {
     /// as well.
     pub fn iter_chunks_physical(&self) -> PhysRecordBatchIter<'_> {
         PhysRecordBatchIter {
-            iters: self
+            schema: Arc::new(
+                self.get_columns()
+                    .iter()
+                    .map(|c| c.field().to_arrow(CompatLevel::newest()))
+                    .collect(),
+            ),
+            arr_iters: self
                 .materialized_column_iter()
                 .map(|s| s.chunks().iter())
                 .collect(),
@@ -3255,6 +3271,7 @@ impl DataFrame {
 
 pub struct RecordBatchIter<'a> {
     columns: &'a Vec<Column>,
+    schema: ArrowSchemaRef,
     idx: usize,
     n_chunks: usize,
     compat_level: CompatLevel,
@@ -3287,8 +3304,7 @@ impl Iterator for RecordBatchIter<'_> {
         self.idx += 1;
 
         let length = batch_cols.first().map_or(0, |arr| arr.len());
-
-        Some(RecordBatch::new(length, batch_cols))
+        Some(RecordBatch::new(length, self.schema.clone(), batch_cols))
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
@@ -3298,25 +3314,26 @@ impl Iterator for RecordBatchIter<'_> {
 }
 
 pub struct PhysRecordBatchIter<'a> {
-    iters: Vec<std::slice::Iter<'a, ArrayRef>>,
+    schema: ArrowSchemaRef,
+    arr_iters: Vec<std::slice::Iter<'a, ArrayRef>>,
 }
 
 impl Iterator for PhysRecordBatchIter<'_> {
     type Item = RecordBatch;
 
     fn next(&mut self) -> Option<Self::Item> {
-        self.iters
+        let arrs = self
+            .arr_iters
             .iter_mut()
             .map(|phys_iter| phys_iter.next().cloned())
-            .collect::<Option<Vec<_>>>()
-            .map(|arrs| {
-                let length = arrs.first().map_or(0, |arr| arr.len());
-                RecordBatch::new(length, arrs)
-            })
+            .collect::<Option<Vec<_>>>()?;
+
+        let length = arrs.first().map_or(0, |arr| arr.len());
+        Some(RecordBatch::new(length, self.schema.clone(), arrs))
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
-        if let Some(iter) = self.iters.first() {
+        if let Some(iter) = self.arr_iters.first() {
             iter.size_hint()
         } else {
             (0, None)
diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs
index 54f864e360b6..8a406a8b6fef 100644
--- a/crates/polars-core/src/series/from.rs
+++ b/crates/polars-core/src/series/from.rs
@@ -374,6 +374,7 @@ impl Series {
                     keys,
                     values,
                 );
+                let mut ordering = CategoricalOrdering::default();
                 if let Some(metadata) = md {
                     if metadata.is_enum() {
                         // SAFETY:
@@ -382,19 +383,16 @@ impl Series {
                             UInt32Chunked::with_chunk(name, keys),
                             Arc::new(RevMapping::build_local(values)),
                             true,
-                            Default::default(),
+                            CategoricalOrdering::Physical, // Enum always uses physical ordering
                         )
                         .into_series());
+                    } else if let Some(o) = metadata.categorical() {
+                        ordering = o;
                     }
                 }
 
-                // SAFETY:
-                // the invariants of an Arrow Dictionary guarantee the keys are in bounds
-                return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
-                    UInt32Chunked::with_chunk(name, keys),
-                    Arc::new(RevMapping::build_local(values)),
-                    false,
-                    Default::default(),
+                return Ok(CategoricalChunked::from_keys_and_values(
+                    name, &keys, &values, ordering,
                 )
                 .into_series());
             }
diff --git a/crates/polars-io/src/ipc/mmap.rs b/crates/polars-io/src/ipc/mmap.rs
index b8749c5a7392..e6cfa4043676 100644
--- a/crates/polars-io/src/ipc/mmap.rs
+++ b/crates/polars-io/src/ipc/mmap.rs
@@ -97,9 +97,10 @@ impl ArrowReader for MMapChunkIter<'_> {
             None => chunk,
             Some(proj) => {
                 let length = chunk.len();
-                let cols = chunk.into_arrays();
+                let (schema, cols) = chunk.into_schema_and_arrays();
+                let schema = schema.try_project_indices(proj).unwrap();
                 let arrays = proj.iter().map(|i| cols[*i].clone()).collect();
-                RecordBatch::new(length, arrays)
+                RecordBatch::new(length, Arc::new(schema), arrays)
             },
         };
         Ok(Some(chunk))
diff --git a/crates/polars-python/src/dataframe/construction.rs b/crates/polars-python/src/dataframe/construction.rs
index df051cab1741..15bce087d86c 100644
--- a/crates/polars-python/src/dataframe/construction.rs
+++ b/crates/polars-python/src/dataframe/construction.rs
@@ -52,8 +52,12 @@ impl PyDataFrame {
     }
 
     #[staticmethod]
-    pub fn from_arrow_record_batches(py: Python, rb: Vec<Bound<PyAny>>) -> PyResult<Self> {
-        let df = interop::arrow::to_rust::to_rust_df(py, &rb)?;
+    pub fn from_arrow_record_batches(
+        py: Python,
+        rb: Vec<Bound<PyAny>>,
+        schema: Bound<PyAny>,
+    ) -> PyResult<Self> {
+        let df = interop::arrow::to_rust::to_rust_df(py, &rb, schema)?;
         Ok(Self::from(df))
     }
 }
diff --git a/crates/polars-python/src/dataframe/export.rs b/crates/polars-python/src/dataframe/export.rs
index 5edb1ed5ed0a..7caa5a763fb9 100644
--- a/crates/polars-python/src/dataframe/export.rs
+++ b/crates/polars-python/src/dataframe/export.rs
@@ -76,12 +76,11 @@ impl PyDataFrame {
     pub fn to_arrow(&mut self, py: Python, compat_level: PyCompatLevel) -> PyResult<Vec<PyObject>> {
         py.allow_threads(|| self.df.align_chunks_par());
         let pyarrow = py.import("pyarrow")?;
-        let names = self.df.get_column_names_str();
 
         let rbs = self
             .df
             .iter_chunks(compat_level.0, true)
-            .map(|rb| interop::arrow::to_py::to_py_rb(&rb, &names, &pyarrow))
+            .map(|rb| interop::arrow::to_py::to_py_rb(&rb, py, &pyarrow))
             .collect::<PyResult<_>>()?;
         Ok(rbs)
     }
@@ -96,7 +95,6 @@ impl PyDataFrame {
         py.allow_threads(|| self.df.as_single_chunk_par());
         Python::with_gil(|py| {
             let pyarrow = py.import("pyarrow")?;
-            let names = self.df.get_column_names_str();
             let cat_columns = self
                 .df
                 .get_columns()
@@ -115,9 +113,9 @@ impl PyDataFrame {
                 .iter_chunks(CompatLevel::oldest(), true)
                 .map(|rb| {
                     let length = rb.len();
-                    let mut rb = rb.into_arrays();
+                    let (schema, mut arrays) = rb.into_schema_and_arrays();
                     for i in &cat_columns {
-                        let arr = rb.get_mut(*i).unwrap();
+                        let arr = arrays.get_mut(*i).unwrap();
                         let out = polars_core::export::cast::cast(
                             &**arr,
                             &ArrowDataType::Dictionary(
@@ -130,9 +128,9 @@ impl PyDataFrame {
                         .unwrap();
                         *arr = out;
                     }
-                    let rb = RecordBatch::new(length, rb);
+                    let rb = RecordBatch::new(length, schema, arrays);
 
-                    interop::arrow::to_py::to_py_rb(&rb, &names, &pyarrow)
+                    interop::arrow::to_py::to_py_rb(&rb, py, &pyarrow)
                 })
                 .collect::<PyResult<_>>()?;
             Ok(rbs)
diff --git a/crates/polars-python/src/interop/arrow/to_py.rs b/crates/polars-python/src/interop/arrow/to_py.rs
index abbd763ed36e..89db60b315f5 100644
--- a/crates/polars-python/src/interop/arrow/to_py.rs
+++ b/crates/polars-python/src/interop/arrow/to_py.rs
@@ -14,12 +14,12 @@ use pyo3::prelude::*;
 use pyo3::types::PyCapsule;
 
 /// Arrow array to Python.
-pub(crate) fn to_py_array(array: ArrayRef, pyarrow: &Bound<PyModule>) -> PyResult<PyObject> {
-    let schema = Box::new(ffi::export_field_to_c(&ArrowField::new(
-        PlSmallStr::EMPTY,
-        array.dtype().clone(),
-        true,
-    )));
+pub(crate) fn to_py_array(
+    array: ArrayRef,
+    field: &ArrowField,
+    pyarrow: &Bound<PyModule>,
+) -> PyResult<PyObject> {
+    let schema = Box::new(ffi::export_field_to_c(field));
     let array = Box::new(ffi::export_array_to_c(array));
 
     let schema_ptr: *const ffi::ArrowSchema = &*schema;
@@ -36,19 +36,30 @@
 /// RecordBatch to Python.
 pub(crate) fn to_py_rb(
     rb: &RecordBatch,
-    names: &[&str],
+    py: Python,
     pyarrow: &Bound<PyModule>,
 ) -> PyResult<PyObject> {
-    let mut arrays = Vec::with_capacity(rb.len());
+    let mut arrays = Vec::with_capacity(rb.width());
 
-    for array in rb.columns() {
-        let array_object = to_py_array(array.clone(), pyarrow)?;
+    for (array, field) in rb.columns().iter().zip(rb.schema().iter_values()) {
+        let array_object = to_py_array(array.clone(), field, pyarrow)?;
         arrays.push(array_object);
     }
 
+    let schema = Box::new(ffi::export_field_to_c(&ArrowField {
+        name: PlSmallStr::EMPTY,
+        dtype: ArrowDataType::Struct(rb.schema().iter_values().cloned().collect()),
+        is_nullable: false,
+        metadata: None,
+    }));
+    let schema_ptr: *const ffi::ArrowSchema = &*schema;
+
+    let schema = pyarrow
+        .getattr("Schema")?
+        .call_method1("_import_from_c", (schema_ptr as Py_uintptr_t,))?;
     let record = pyarrow
         .getattr("RecordBatch")?
-        .call_method1("from_arrays", (arrays, names.to_vec()))?;
+        .call_method1("from_arrays", (arrays, py.None(), schema))?;
 
     Ok(record.unbind())
 }
diff --git a/crates/polars-python/src/interop/arrow/to_rust.rs b/crates/polars-python/src/interop/arrow/to_rust.rs
index ee741c4279cc..81f005a26ebc 100644
--- a/crates/polars-python/src/interop/arrow/to_rust.rs
+++ b/crates/polars-python/src/interop/arrow/to_rust.rs
@@ -9,14 +9,18 @@
 use pyo3::types::PyList;
 
 use crate::error::PyPolarsErr;
 
-pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
-    let schema = Box::new(ffi::ArrowSchema::empty());
-    let schema_ptr = &*schema as *const ffi::ArrowSchema;
+pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {
+    let mut schema = Box::new(ffi::ArrowSchema::empty());
+    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
 
     // make the conversion through PyArrow's private API
     obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;
 
     let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };
-    Ok((&field).into())
+    Ok(field)
+}
+
+pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
+    field_to_rust_arrow(obj).map(|f| (&f).into())
+}
 
 // PyList which you get by calling `list(schema)`
@@ -26,11 +30,11 @@ pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {
 
 pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
     // prepare a pointer to receive the Array struct
-    let array = Box::new(ffi::ArrowArray::empty());
-    let schema = Box::new(ffi::ArrowSchema::empty());
+    let mut array = Box::new(ffi::ArrowArray::empty());
+    let mut schema = Box::new(ffi::ArrowSchema::empty());
 
-    let array_ptr = &*array as *const ffi::ArrowArray;
-    let schema_ptr = &*schema as *const ffi::ArrowSchema;
+    let array_ptr = array.as_mut() as *mut ffi::ArrowArray;
+    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;
 
     // make the conversion through PyArrow's private API
     // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
@@ -46,24 +50,31 @@ pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
     }
 }
 
-pub fn to_rust_df(py: Python, rb: &[Bound<PyAny>]) -> PyResult<DataFrame> {
-    let schema = rb
-        .first()
-        .ok_or_else(|| PyPolarsErr::Other("empty table".into()))?
-        .getattr("schema")?;
-    let names = schema
-        .getattr("names")?
-        .extract::<Vec<String>>()?
-        .into_iter()
-        .map(PlSmallStr::from_string)
-        .collect::<Vec<_>>();
+pub fn to_rust_df(py: Python, rb: &[Bound<PyAny>], schema: Bound<PyAny>) -> PyResult<DataFrame> {
+    let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {
+        return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());
+    };
+    let schema = ArrowSchema::from_iter(fields);
+
+    if rb.is_empty() {
+        let columns = schema
+            .iter_values()
+            .map(|field| {
+                let field = Field::from(field);
+                Series::new_empty(field.name, &field.dtype).into_column()
+            })
+            .collect::<Vec<_>>();
+
+        // no need to check as a record batch has the same guarantees
+        return Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) });
+    }
 
     let dfs = rb
         .iter()
         .map(|rb| {
             let mut run_parallel = false;
 
-            let columns = (0..names.len())
+            let columns = (0..schema.len())
                 .map(|i| {
                     let array = rb.call_method1("column", (i,))?;
                     let arr = array_to_rust(&array)?;
@@ -85,9 +96,17 @@
                     .into_par_iter()
                     .enumerate()
                     .map(|(i, arr)| {
-                        let s = Series::try_from((names[i].clone(), arr))
-                            .map_err(PyPolarsErr::from)?
-                            .into_column();
+                        let (_, field) = schema.get_at_index(i).unwrap();
+                        let s = unsafe {
+                            Series::_try_from_arrow_unchecked_with_md(
+                                field.name.clone(),
+                                vec![arr],
+                                field.dtype(),
+                                field.metadata.as_deref(),
+                            )
+                        }
+                        .map_err(PyPolarsErr::from)?
+                        .into_column();
                         Ok(s)
                     })
                     .collect::<PolarsResult<Vec<_>>>()
@@ -98,9 +117,17 @@
                 .into_iter()
                 .enumerate()
                 .map(|(i, arr)| {
-                    let s = Series::try_from((names[i].clone(), arr))
-                        .map_err(PyPolarsErr::from)?
-                        .into_column();
+                    let (_, field) = schema.get_at_index(i).unwrap();
+                    let s = unsafe {
+                        Series::_try_from_arrow_unchecked_with_md(
+                            field.name.clone(),
+                            vec![arr],
+                            field.dtype(),
+                            field.metadata.as_deref(),
+                        )
+                    }
+                    .map_err(PyPolarsErr::from)?
+                    .into_column();
                     Ok(s)
                 })
                .collect::<PolarsResult<Vec<_>>>()
diff --git a/crates/polars-python/src/series/export.rs b/crates/polars-python/src/series/export.rs
index f9c76e226ee3..c2fecd07088e 100644
--- a/crates/polars-python/src/series/export.rs
+++ b/crates/polars-python/src/series/export.rs
@@ -149,7 +149,11 @@ impl PySeries {
         self.rechunk(py, true);
         let pyarrow = py.import("pyarrow")?;
 
-        interop::arrow::to_py::to_py_array(self.series.to_arrow(0, compat_level.0), &pyarrow)
+        interop::arrow::to_py::to_py_array(
+            self.series.to_arrow(0, compat_level.0),
+            &self.series.field().to_arrow(compat_level.0),
+            &pyarrow,
+        )
     }
 
     #[allow(unused_variables)]
diff --git a/crates/polars-utils/src/pl_str.rs b/crates/polars-utils/src/pl_str.rs
index abd689853786..64d1bde78e80 100644
--- a/crates/polars-utils/src/pl_str.rs
+++ b/crates/polars-utils/src/pl_str.rs
@@ -3,7 +3,7 @@ macro_rules! format_pl_smallstr {
     ($($arg:tt)*) => {{
         use std::fmt::Write;
 
-        let mut string = PlSmallStr::EMPTY;
+        let mut string = $crate::pl_str::PlSmallStr::EMPTY;
         write!(string, $($arg)*).unwrap();
         string
     }}
diff --git a/crates/polars/tests/it/arrow/io/ipc/mod.rs b/crates/polars/tests/it/arrow/io/ipc/mod.rs
index 8004f2fc8eea..8e9dc0c8db01 100644
--- a/crates/polars/tests/it/arrow/io/ipc/mod.rs
+++ b/crates/polars/tests/it/arrow/io/ipc/mod.rs
@@ -62,7 +62,7 @@ fn prep_schema(array: &dyn Array) -> ArrowSchemaRef {
 fn write_boolean() -> PolarsResult<()> {
     let array = BooleanArray::from([Some(true), Some(false), None, Some(true)]).boxed();
     let schema = prep_schema(array.as_ref());
-    let columns = RecordBatchT::try_new(4, vec![array])?;
+    let columns = RecordBatchT::try_new(4, schema.clone(), vec![array])?;
     round_trip(columns, schema, None, Some(Compression::ZSTD))
 }
 
@@ -72,7 +72,7 @@ fn write_sliced_utf8() -> PolarsResult<()> {
         .sliced(1, 1)
         .boxed();
     let schema = prep_schema(array.as_ref());
-    let columns = RecordBatchT::try_new(array.len(), vec![array])?;
+    let columns = RecordBatchT::try_new(array.len(), schema.clone(), vec![array])?;
     round_trip(columns, schema, None, Some(Compression::ZSTD))
 }
 
@@ -80,6 +80,6 @@ fn write_binview() -> PolarsResult<()> {
     let array = Utf8ViewArray::from_slice([Some("foo"), Some("bar"), None, Some("hamlet")]).boxed();
     let schema = prep_schema(array.as_ref());
-    let columns = RecordBatchT::try_new(array.len(), vec![array])?;
+    let columns = RecordBatchT::try_new(array.len(), schema.clone(), vec![array])?;
     round_trip(columns, schema, None, Some(Compression::ZSTD))
 }
diff --git a/crates/polars/tests/it/io/avro/read.rs b/crates/polars/tests/it/io/avro/read.rs
index 57adac991d1a..6e7abe95a92b 100644
--- a/crates/polars/tests/it/io/avro/read.rs
+++ b/crates/polars/tests/it/io/avro/read.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use apache_avro::types::{Record, Value};
 use apache_avro::{Codec, Days, Duration, Millis, Months, Schema as AvroSchema, Writer};
 use arrow::array::*;
@@ -6,6 +8,7 @@ use arrow::io::avro::avro_schema::read::read_metadata;
 use arrow::io::avro::read;
 use arrow::record_batch::RecordBatchT;
 use polars_error::PolarsResult;
+use polars_utils::format_pl_smallstr;
 
 pub(super) fn schema() -> (AvroSchema, ArrowSchema) {
     let raw_schema = r#"
@@ -124,7 +127,13 @@ pub(super) fn data() -> RecordBatchT<Box<dyn Array>> {
         .boxed(),
     ];
 
-    RecordBatchT::try_new(2, columns).unwrap()
+    let schema = columns
+        .iter()
+        .enumerate()
+        .map(|(i, col)| Field::new(format_pl_smallstr!("c{i}"), col.dtype().clone(), true))
+        .collect();
+
+    RecordBatchT::try_new(2, Arc::new(schema), columns).unwrap()
 }
 
 pub(super) fn write_avro(codec: Codec) -> Result<Vec<u8>, apache_avro::Error> {
@@ -259,14 +268,20 @@ fn test_projected() -> PolarsResult<()> {
         projection[i] = true;
 
         let length = expected.first().map_or(0, |arr| arr.len());
-        let expected = expected
+        let (expected_schema_2, expected_arrays) = expected.clone().into_schema_and_arrays();
+        let expected_schema_2 = expected_schema_2
+            .as_ref()
             .clone()
-            .into_arrays()
             .into_iter()
             .zip(projection.iter())
             .filter_map(|x| if *x.1 { Some(x.0) } else { None })
             .collect();
-        let expected = RecordBatchT::new(length, expected);
+        let expected_arrays = expected_arrays
+            .into_iter()
+            .zip(projection.iter())
+            .filter_map(|x| if *x.1 { Some(x.0) } else { None })
+            .collect();
+        let expected = RecordBatchT::new(length, Arc::new(expected_schema_2), expected_arrays);
 
         let expected_schema = expected_schema
             .clone()
@@ -328,9 +343,11 @@ pub(super) fn data_list() -> RecordBatchT<Box<dyn Array>> {
     array.try_extend(data).unwrap();
 
     let length = array.len();
+    let field = Field::new("c1".into(), array.dtype().clone(), true);
+    let schema = ArrowSchema::from_iter([field]);
     let columns = vec![array.into_box()];
 
-    RecordBatchT::try_new(length, columns).unwrap()
+    RecordBatchT::try_new(length, Arc::new(schema), columns).unwrap()
 }
 
 pub(super) fn write_list(codec: Codec) -> Result<Vec<u8>, apache_avro::Error> {
diff --git a/crates/polars/tests/it/io/avro/write.rs b/crates/polars/tests/it/io/avro/write.rs
index 48633f39ce94..6c8ca171704e 100644
--- a/crates/polars/tests/it/io/avro/write.rs
+++ b/crates/polars/tests/it/io/avro/write.rs
@@ -1,4 +1,5 @@
 use std::io::Cursor;
+use std::sync::Arc;
 
 use arrow::array::*;
 use arrow::datatypes::*;
@@ -11,6 +12,7 @@ use polars::io::avro::{AvroReader, AvroWriter};
 use polars::io::{SerReader, SerWriter};
 use polars::prelude::df;
 use polars_error::PolarsResult;
+use polars_utils::format_pl_smallstr;
 
 use super::read::read_avro;
 
@@ -102,7 +104,13 @@ pub(super) fn data() -> RecordBatchT<Box<dyn Array>> {
         )),
     ];
 
-    RecordBatchT::new(2, columns)
+    let schema = columns
+        .iter()
+        .enumerate()
+        .map(|(i, col)| Field::new(format_pl_smallstr!("c{i}"), col.dtype().clone(), true))
+        .collect();
+
+    RecordBatchT::new(2, Arc::new(schema), columns)
 }
 
 pub(super) fn serialize_to_block<R: AsRef<dyn Array>>(
@@ -197,7 +205,13 @@ fn large_format_data() -> RecordBatchT<Box<dyn Array>> {
         Box::new(BinaryArray::<i64>::from_slice([b"foo", b"bar"])),
         Box::new(BinaryArray::<i64>::from([Some(b"foo"), None])),
     ];
-    RecordBatchT::new(2, columns)
+    let schema = columns
+        .iter()
+        .enumerate()
+        .map(|(i, col)| Field::new(format_pl_smallstr!("c{i}"), col.dtype().clone(), true))
+        .collect();
+
+    RecordBatchT::new(2, Arc::new(schema), columns)
 }
 
 fn large_format_expected_schema() -> ArrowSchema {
@@ -216,7 +230,13 @@ fn large_format_expected_data() -> RecordBatchT<Box<dyn Array>> {
         Box::new(BinaryArray::<i64>::from_slice([b"foo", b"bar"])),
         Box::new(BinaryArray::<i64>::from([Some(b"foo"), None])),
     ];
-    RecordBatchT::new(2, columns)
+    let schema = columns
+        .iter()
+        .enumerate()
+        .map(|(i, col)| Field::new(format_pl_smallstr!("c{i}"), col.dtype().clone(), true))
+        .collect();
+
+    RecordBatchT::new(2, Arc::new(schema), columns)
 }
 
 #[test]
@@ -264,9 +284,12 @@ fn struct_data() -> RecordBatchT<Box<dyn Array>> {
         Field::new("item1".into(), ArrowDataType::Int32, false),
         Field::new("item2".into(), ArrowDataType::Int32, true),
     ]);
+    let struct_field = Field::new("a".into(), struct_dt.clone(), false);
+    let schema = ArrowSchema::from_iter([("a".into(), struct_field)]);
 
     RecordBatchT::new(
         2,
+        Arc::new(schema),
         vec![
             Box::new(StructArray::new(
                 struct_dt.clone(),
diff --git a/crates/polars/tests/it/io/parquet/arrow/mod.rs b/crates/polars/tests/it/io/parquet/arrow/mod.rs
index cec228971ad8..933cbc7018a3 100644
--- a/crates/polars/tests/it/io/parquet/arrow/mod.rs
+++ b/crates/polars/tests/it/io/parquet/arrow/mod.rs
@@ -2,6 +2,7 @@ mod read;
 mod write;
 
 use std::io::{Cursor, Read, Seek};
+use std::sync::Arc;
 
 use arrow::array::*;
 use arrow::bitmap::Bitmap;
@@ -1382,7 +1383,7 @@ fn assert_roundtrip(
             .into_iter()
             .map(|x| x.sliced(0, limit))
            .collect::<Vec<_>>();
-        RecordBatchT::new(length, expected)
+        RecordBatchT::new(length, Arc::new(schema.clone()), expected)
     } else {
         chunk
     };
@@ -1435,7 +1436,7 @@ fn assert_array_roundtrip(
 ) -> PolarsResult<()> {
     let schema =
         ArrowSchema::from_iter([Field::new("a1".into(), array.dtype().clone(), is_nullable)]);
-    let chunk = RecordBatchT::try_new(array.len(), vec![array])?;
+    let chunk = RecordBatchT::try_new(array.len(), Arc::new(schema.clone()), vec![array])?;
 
     assert_roundtrip(schema, chunk, limit)
 }
@@ -1537,9 +1538,18 @@ fn limit_list() -> PolarsResult<()> {
 
 #[test]
 fn filter_chunk() -> PolarsResult<()> {
-    let chunk1 = RecordBatchT::new(2, vec![PrimitiveArray::from_slice([1i16, 3]).boxed()]);
-    let chunk2 = RecordBatchT::new(2, vec![PrimitiveArray::from_slice([2i16, 4]).boxed()]);
-    let schema = ArrowSchema::from_iter([Field::new("c1".into(), ArrowDataType::Int16, true)]);
+    let field = Field::new("c1".into(), ArrowDataType::Int16, true);
+    let schema = ArrowSchema::from_iter([field]);
+    let chunk1 = RecordBatchT::new(
+        2,
+        Arc::new(schema.clone()),
+        vec![PrimitiveArray::from_slice([1i16, 3]).boxed()],
+    );
+    let chunk2 = RecordBatchT::new(
+        2,
+        Arc::new(schema.clone()),
+        vec![PrimitiveArray::from_slice([2i16, 4]).boxed()],
+    );
 
     let r = integration_write(&schema, &[chunk1.clone(), chunk2.clone()])?;
 
diff --git a/crates/polars/tests/it/io/parquet/arrow/write.rs b/crates/polars/tests/it/io/parquet/arrow/write.rs
index e6c969f1d054..d767bc2cbdbb 100644
--- a/crates/polars/tests/it/io/parquet/arrow/write.rs
+++ b/crates/polars/tests/it/io/parquet/arrow/write.rs
@@ -50,7 +50,11 @@ fn round_trip_opt_stats(
         data_page_size: None,
     };
 
-    let iter = vec![RecordBatchT::try_new(array.len(), vec![array.clone()])];
+    let iter = vec![RecordBatchT::try_new(
+        array.len(),
+        Arc::new(schema.clone()),
+        vec![array.clone()],
+    )];
 
     let row_groups =
         RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?;
diff --git a/crates/polars/tests/it/io/parquet/read/file.rs b/crates/polars/tests/it/io/parquet/read/file.rs
index d2be2c5402d9..0005015387f2 100644
--- a/crates/polars/tests/it/io/parquet/read/file.rs
+++ b/crates/polars/tests/it/io/parquet/read/file.rs
@@ -1,4 +1,5 @@
 use std::io::{Read, Seek};
+use std::sync::Arc;
 
 use arrow::array::Array;
 use arrow::datatypes::ArrowSchema;
@@ -142,6 +143,7 @@ impl RowGroupReader {
 
         let num_rows = row_group.num_rows();
 
+        let column_schema = self.schema.iter_values().cloned().collect();
        let column_chunks = read_columns_many(
             &mut self.reader,
             &row_group,
@@ -149,7 +151,12 @@ impl RowGroupReader {
             Some(Filter::new_limited(self.remaining_rows)),
         )?;
 
-        let result = RowGroupDeserializer::new(column_chunks, num_rows, Some(self.remaining_rows));
+        let result = RowGroupDeserializer::new(
+            Arc::new(column_schema),
+            column_chunks,
+            num_rows,
+            Some(self.remaining_rows),
+        );
         self.remaining_rows = self.remaining_rows.saturating_sub(num_rows);
         Ok(Some(result))
     }
diff --git a/crates/polars/tests/it/io/parquet/read/row_group.rs b/crates/polars/tests/it/io/parquet/read/row_group.rs
index 8008c594ed19..d5e2ab7728ff 100644
--- a/crates/polars/tests/it/io/parquet/read/row_group.rs
+++ b/crates/polars/tests/it/io/parquet/read/row_group.rs
@@ -1,7 +1,7 @@
 use std::io::{Read, Seek};
 
 use arrow::array::Array;
-use arrow::datatypes::Field;
+use arrow::datatypes::{ArrowSchemaRef, Field};
 use arrow::record_batch::RecordBatchT;
 use polars::prelude::ArrowSchema;
 use polars_error::PolarsResult;
@@ -22,6 +22,7 @@ use polars_utils::mmap::MemReader;
 pub struct RowGroupDeserializer {
     num_rows: usize,
     remaining_rows: usize,
+    column_schema: ArrowSchemaRef,
     column_chunks: Vec<Box<dyn Array>>,
 }
 
@@ -31,10 +32,16 @@ impl RowGroupDeserializer {
     /// # Panic
     /// This function panics iff any of the `column_chunks`
     /// do not return an array with an equal length.
-    pub fn new(column_chunks: Vec<Box<dyn Array>>, num_rows: usize, limit: Option<usize>) -> Self {
+    pub fn new(
+        column_schema: ArrowSchemaRef,
+        column_chunks: Vec<Box<dyn Array>>,
+        num_rows: usize,
+        limit: Option<usize>,
+    ) -> Self {
         Self {
             num_rows,
             remaining_rows: limit.unwrap_or(usize::MAX).min(num_rows),
+            column_schema,
             column_chunks,
         }
     }
@@ -53,7 +60,11 @@ impl Iterator for RowGroupDeserializer {
             return None;
         }
         let length = self.column_chunks.first().map_or(0, |chunk| chunk.len());
-        let chunk = RecordBatchT::try_new(length, std::mem::take(&mut self.column_chunks));
+        let chunk = RecordBatchT::try_new(
+            length,
+            self.column_schema.clone(),
+            std::mem::take(&mut self.column_chunks),
+        );
         self.remaining_rows = self.remaining_rows.saturating_sub(
             chunk
                 .as_ref()
diff --git a/crates/polars/tests/it/io/parquet/roundtrip.rs b/crates/polars/tests/it/io/parquet/roundtrip.rs
index d20551432ec0..62a8ba412cab 100644
--- a/crates/polars/tests/it/io/parquet/roundtrip.rs
+++ b/crates/polars/tests/it/io/parquet/roundtrip.rs
@@ -1,4 +1,5 @@
 use std::io::Cursor;
+use std::sync::Arc;
 
 use arrow::array::{ArrayRef, Utf8ViewArray};
 use arrow::datatypes::{ArrowSchema, Field};
@@ -28,7 +29,11 @@ fn round_trip(
         data_page_size: None,
     };
 
-    let iter = vec![RecordBatchT::try_new(array.len(), vec![array.clone()])];
+    let iter = vec![RecordBatchT::try_new(
+        array.len(),
+        Arc::new(schema.clone()),
+        vec![array.clone()],
+    )];
 
     let row_groups =
         RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?;
diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py
index f1db5fa70859..4b35a65b6203 100644
--- a/py-polars/polars/_utils/construction/dataframe.py
+++ b/py-polars/polars/_utils/construction/dataframe.py
@@ -1176,9 +1176,6 @@ def arrow_to_pydf(
             raise ValueError(msg) from e
 
     data_dict = {}
-    # dictionaries cannot be built in different batches (categorical does not allow
-    # that) so we rechunk them and create them separately.
-    dictionary_cols = {}
    # struct columns don't work properly if they contain multiple chunks.
     struct_cols = {}
     names = []
@@ -1188,10 +1185,7 @@ def arrow_to_pydf(
         names.append(name)
 
         column = plc.coerce_arrow(column)
-        if pa.types.is_dictionary(column.type):
-            ps = plc.arrow_to_pyseries(name, column, rechunk=rechunk)
-            dictionary_cols[i] = wrap_s(ps)
-        elif (
+        if (
             isinstance(column.type, pa.StructType)
             and hasattr(column, "num_chunks")
             and column.num_chunks > 1
@@ -1205,12 +1199,7 @@ def arrow_to_pydf(
         tbl = pa.table(data_dict)
 
-        # path for table without rows that keeps datatype
-        if tbl.shape[0] == 0:
-            pydf = pl.DataFrame(
-                [pl.Series(name, c) for (name, c) in zip(tbl.column_names, tbl.columns)]
-            )._df
-        else:
-            pydf = PyDataFrame.from_arrow_record_batches(tbl.to_batches())
+        # passing the schema explicitly keeps datatypes, even for tables without rows
+        pydf = PyDataFrame.from_arrow_record_batches(tbl.to_batches(), tbl.schema)
     else:
         pydf = pl.DataFrame([])._df
     if rechunk:
@@ -1222,12 +1211,6 @@ def broadcastable_s(s: Series, name: str) -> Expr:
         return F.lit(s).alias(name)
 
     reset_order = False
-    if len(dictionary_cols) > 0:
-        df = wrap_df(pydf)
-        df = df.with_columns(
-            [broadcastable_s(s, s.name) for s in dictionary_cols.values()]
-        )
-        reset_order = True
 
     if len(struct_cols) > 0:
         df = wrap_df(pydf)
diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py
index 7ab6c196c807..d7c3fbd6d9d4 100644
--- a/py-polars/tests/unit/interop/test_interop.py
+++ b/py-polars/tests/unit/interop/test_interop.py
@@ -800,3 +800,33 @@ def test_misaligned_nested_arrow_19097() -> None:
     a = a.replace(2, None)  # then we add a validity mask with offset=0
     a = a.reshape((2, 1))  # then we make it nested
     assert_series_equal(pl.Series("a", a.to_arrow()), a)
+
+
+def test_arrow_roundtrip_lex_cat_20288() -> None:
+    tb = (
+        pl.Series("a", ["A", "B"], pl.Categorical(ordering="lexical"))
+        .to_frame()
+        .to_arrow()
+    )
+    df = pl.from_arrow(tb)
+    assert isinstance(df, pl.DataFrame)
+    dt = df.schema["a"]
+    assert isinstance(dt, pl.Categorical)
+    assert dt.ordering == "lexical"
+
+
+def test_from_arrow_string_cache_20271() -> None:
+    with pl.StringCache():
+        s = pl.Series("a", ["A", "B", "C"], pl.Categorical)
+        df = pl.from_arrow(
+            pa.table({"b": pa.DictionaryArray.from_arrays([0, 1], ["D", "E"])})
+        )
+        assert isinstance(df, pl.DataFrame)
+
+        assert_series_equal(
+            s.to_physical(), pl.Series("a", [0, 1, 2]), check_dtypes=False
+        )
+        assert_series_equal(df.to_series(), pl.Series("b", ["D", "E"], pl.Categorical))
+        assert_series_equal(
+            df.to_series().to_physical(), pl.Series("b", [3, 4]), check_dtypes=False
+        )
diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py
index 0fc7992dd7ea..f3efd8ce6ac6 100644
--- a/py-polars/tests/unit/io/test_parquet.py
+++ b/py-polars/tests/unit/io/test_parquet.py
@@ -2682,3 +2682,26 @@ def test_parquet_cast_to_cat() -> None:
         pl.Series("col1", ["A", "A", None, "B"], pl.Categorical),
         pl.read_parquet(f).to_series(),
     )
+
+
+def test_parquet_roundtrip_lex_cat_20288() -> None:
+    f = io.BytesIO()
+    df = pl.Series("a", ["A", "B"], pl.Categorical(ordering="lexical")).to_frame()
+    df.write_parquet(f)
+    f.seek(0)
+    dt = pl.scan_parquet(f).collect_schema()["a"]
+    assert isinstance(dt, pl.Categorical)
+    assert dt.ordering == "lexical"
+
+
+def test_from_parquet_string_cache_20271() -> None:
+    with pl.StringCache():
+        s = pl.Series("a", ["A", "B", "C"], pl.Categorical)
+        df = pl.from_arrow(
+            pa.table({"b": pa.DictionaryArray.from_arrays([0, 1], ["D", "E"])})
+        )
+        assert isinstance(df, pl.DataFrame)
+        assert_series_equal(df.to_series(), pl.Series("b", ["D", "E"], pl.Categorical))
+        assert_series_equal(
+            df.to_series().to_physical(), pl.Series("b", [3, 4]), check_dtypes=False
+        )
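---

Usage sketch (editor's note, not part of the patch): the core change is that `RecordBatchT` now carries an `ArrowSchemaRef` alongside its arrays, so a zero-row batch no longer loses its field types and `try_new` can validate that fields and arrays line up. A minimal illustration against the new constructor, assuming the `polars_arrow` re-export paths match this patch (`record_batch::RecordBatchT`, `datatypes::{ArrowDataType, ArrowSchema, Field}`):

    use std::sync::Arc;

    use polars_arrow::array::{Array, Int32Array};
    use polars_arrow::datatypes::{ArrowDataType, ArrowSchema, Field};
    use polars_arrow::record_batch::RecordBatchT;

    fn main() {
        // One nullable Int32 field, zero rows. Before this patch the batch only
        // stored arrays, so an empty table had no way to remember its dtypes.
        let schema: ArrowSchema =
            ArrowSchema::from_iter([Field::new("a".into(), ArrowDataType::Int32, true)]);
        let empty: Box<dyn Array> = Int32Array::from_vec(vec![]).boxed();

        // try_new now also checks that the number of fields matches the number of arrays.
        let batch = RecordBatchT::try_new(0, Arc::new(schema), vec![empty]).unwrap();
        assert_eq!(batch.len(), 0);
        assert_eq!(batch.schema().len(), 1);
    }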