fix from_arrow to always use the arrowschema
coastalwhite committed Dec 13, 2024
1 parent 6ca9dd2 commit e1f46a4
Showing 7 changed files with 55 additions and 75 deletions.
10 changes: 9 additions & 1 deletion crates/polars-arrow/src/array/binview/mod.rs
@@ -481,13 +481,21 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
let views = self.views.make_mut();
let completed_buffers = self.buffers.to_vec();
let validity = self.validity.map(|bitmap| bitmap.make_mut());

// We need to know the total_bytes_len if we are going to mutate it.
let mut total_bytes_len = self.total_bytes_len.load(Ordering::Relaxed);
if total_bytes_len == UNKNOWN_LEN {
total_bytes_len = views.iter().map(|view| view.length as u64).sum();
}
let total_bytes_len = total_bytes_len as usize;

MutableBinaryViewArray {
views,
completed_buffers,
in_progress_buffer: vec![],
validity,
phantom: Default::default(),
total_bytes_len: self.total_bytes_len.load(Ordering::Relaxed) as usize,
total_bytes_len,
total_buffer_len: self.total_buffer_len,
stolen_buffers: PlHashMap::new(),
}
9 changes: 8 additions & 1 deletion crates/polars-arrow/src/io/avro/read/deserialize.rs
@@ -508,9 +508,16 @@ pub fn deserialize(
}
}

let projected_schema = fields
.iter_values()
.zip(projection)
.filter_map(|(f, p)| (*p).then_some(f))
.cloned()
.collect();

RecordBatchT::try_new(
rows,
Arc::new(fields.iter_values().cloned().collect()),
Arc::new(projected_schema),
arrays
.iter_mut()
.zip(projection.iter())
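The hunk above builds the batch schema from only the projected Avro fields, so the schema of the returned RecordBatchT matches the arrays that survive projection. A minimal sketch of the user-visible effect, assuming a local file data.avro whose columns "id", "name", and "score" are illustrative only:

import polars as pl

# Read only a subset of the Avro columns; the resulting schema should
# describe just the projected columns, not the full file schema.
df = pl.read_avro("data.avro", columns=["id", "score"])
assert df.columns == ["id", "score"]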
29 changes: 24 additions & 5 deletions crates/polars-python/src/dataframe/export.rs
@@ -108,26 +108,45 @@ impl PyDataFrame {
})
.map(|(i, _)| i)
.collect::<Vec<_>>();

let enum_and_categorical_dtype = ArrowDataType::Dictionary(
IntegerType::Int64,
Box::new(ArrowDataType::LargeUtf8),
false,
);

let mut replaced_schema = None;
let rbs = self
.df
.iter_chunks(CompatLevel::oldest(), true)
.map(|rb| {
let length = rb.len();
let (schema, mut arrays) = rb.into_schema_and_arrays();

// Pandas does not allow unsigned dictionary indices so we replace them.
replaced_schema =
(replaced_schema.is_none() && !cat_columns.is_empty()).then(|| {
let mut schema = schema.as_ref().clone();
for i in &cat_columns {
let (_, field) = schema.get_at_index_mut(*i).unwrap();
field.dtype = enum_and_categorical_dtype.clone();
}
Arc::new(schema)
});

for i in &cat_columns {
let arr = arrays.get_mut(*i).unwrap();
let out = polars_core::export::cast::cast(
&**arr,
&ArrowDataType::Dictionary(
IntegerType::Int64,
Box::new(ArrowDataType::LargeUtf8),
false,
),
&enum_and_categorical_dtype,
CastOptionsImpl::default(),
)
.unwrap();
*arr = out;
}
let schema = replaced_schema
.as_ref()
.map_or(schema, |replaced| replaced.clone());
let rb = RecordBatch::new(length, schema, arrays);

interop::arrow::to_py::to_py_rb(&rb, py, &pyarrow)
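The export path above hoists the Int64-dictionary dtype into enum_and_categorical_dtype and, when enum or categorical columns are present, also patches the matching schema fields, so the RecordBatch handed to pyarrow describes the re-cast arrays (pandas does not accept unsigned dictionary indices). A minimal sketch of the conversion this serves, assuming it is exercised by DataFrame.to_pandas(); the column name and values are illustrative:

import polars as pl

df = pl.DataFrame({"fruit": ["apple", "banana", "apple"]}).with_columns(
    pl.col("fruit").cast(pl.Categorical)
)

# Categorical columns travel to pandas as int64-indexed dictionary arrays,
# so the exported schema and arrays must agree on that dtype.
pdf = df.to_pandas()
print(pdf["fruit"].dtype)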
2 changes: 1 addition & 1 deletion crates/polars-python/src/interop/arrow/to_py.rs
@@ -19,7 +19,7 @@ pub(crate) fn to_py_array(
field: &ArrowField,
pyarrow: &Bound<PyModule>,
) -> PyResult<PyObject> {
let schema = Box::new(ffi::export_field_to_c(&field));
let schema = Box::new(ffi::export_field_to_c(field));
let array = Box::new(ffi::export_array_to_c(array));

let schema_ptr: *const ffi::ArrowSchema = &*schema;
69 changes: 12 additions & 57 deletions py-polars/polars/_utils/construction/dataframe.py
@@ -1161,81 +1161,36 @@ def arrow_to_pydf(
rechunk: bool = True,
) -> PyDataFrame:
"""Construct a PyDataFrame from an Arrow Table or RecordBatch."""
original_schema = schema
data_column_names = data.schema.names
column_names, schema_overrides = _unpack_schema(
(schema or data_column_names), schema_overrides=schema_overrides
(schema or data.schema.names), schema_overrides=schema_overrides
)

try:
if column_names != data_column_names:
if isinstance(data, pa.RecordBatch):
data = pa.Table.from_batches([data])
if column_names != data.schema.names:
data = data.rename_columns(column_names)
except pa.lib.ArrowInvalid as e:
msg = "dimensions of columns arg must match data dimensions"
raise ValueError(msg) from e

data_dict = {}
# struct columns don't work properly if they contain multiple chunks.
struct_cols = {}
names = []
for i, column in enumerate(data):
# extract the name before casting
name = f"column_{i}" if column._name is None else column._name
names.append(name)

column = plc.coerce_arrow(column)
if (
isinstance(column.type, pa.StructType)
and hasattr(column, "num_chunks")
and column.num_chunks > 1
):
ps = plc.arrow_to_pyseries(name, column, rechunk=rechunk)
struct_cols[i] = wrap_s(ps)
else:
data_dict[name] = column
batches: list[pa.RecordBatch]
if isinstance(data, pa.RecordBatch):
batches = [data]
else:
batches = data.to_batches()

if len(data_dict) > 0:
tbl = pa.table(data_dict)
# supply the arrow schema so the metadata is intact
pydf = PyDataFrame.from_arrow_record_batches(batches, data.schema)

# path for table without rows that keeps datatype
pydf = PyDataFrame.from_arrow_record_batches(tbl.to_batches(), tbl.schema)
else:
pydf = pl.DataFrame([])._df
if rechunk:
pydf = pydf.rechunk()

def broadcastable_s(s: Series, name: str) -> Expr:
if s.len() == 1:
return F.lit(s).first().alias(name)
return F.lit(s).alias(name)

reset_order = False

if len(struct_cols) > 0:
df = wrap_df(pydf)
df = df.with_columns([broadcastable_s(s, s.name) for s in struct_cols.values()])
reset_order = True

if reset_order:
df = df[names]
pydf = df._df

if column_names != original_schema and (schema_overrides or original_schema):
if schema_overrides is not None:
pydf = _post_apply_columns(
pydf,
original_schema,
column_names,
schema_overrides=schema_overrides,
strict=strict,
)
elif schema_overrides:
for col, dtype in zip(pydf.columns(), pydf.dtypes()):
override_dtype = schema_overrides.get(col)
if override_dtype is not None and dtype != override_dtype:
pydf = _post_apply_columns(
pydf, original_schema, schema_overrides=schema_overrides
)
break

return pydf
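After this rewrite, arrow_to_pydf passes the original record batches together with data.schema straight to PyDataFrame.from_arrow_record_batches instead of rebuilding a table column by column, so dtype information carried only by the Arrow schema survives, even when there are no rows to infer from. A minimal sketch of the behaviour the updated test_from_arrow assertions guard (column names and dtypes are illustrative):

import pyarrow as pa
import polars as pl

# An empty table: the dtypes can only come from the Arrow schema itself.
tbl = pa.table(
    {
        "x": pa.array([], type=pa.large_string()),
        "y": pa.array([], type=pa.int32()),
    }
)

df = pl.from_arrow(tbl)
assert df.schema == {"x": pl.String, "y": pl.Int32}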

1 change: 1 addition & 0 deletions py-polars/tests/unit/dataframe/test_df.py
@@ -257,6 +257,7 @@ def test_from_arrow(monkeypatch: Any) -> None:
assert df2.rows() == df.rows()[:3]

assert df0.schema == {"id": pl.String, "points": pl.Int64}
print(df1.schema)
assert df1.schema == {"x": pl.String, "y": pl.Int32}
assert df2.schema == {"x": pl.String, "y": pl.Int32}

10 changes: 0 additions & 10 deletions py-polars/tests/unit/io/test_parquet.py
@@ -1286,16 +1286,6 @@ def test_nested_non_uniform_primitive() -> None:
test_round_trip(df)


def test_parquet_lexical_categorical() -> None:
# @TODO: This should be fixed
# This test shows that we don't handle saving the ordering properly in
# parquet files
df = pl.DataFrame({"a": [None]}, schema={"a": pl.Categorical(ordering="lexical")})

with pytest.raises(AssertionError):
test_round_trip(df)


def test_parquet_nested_struct_17933() -> None:
df = pl.DataFrame(
{"a": [{"x": {"u": None}, "y": True}]},
