Skip to content

Commit

Permalink
fix: fix pivot when multiple columns are passed. Output is now aligne…
Browse files Browse the repository at this point in the history
…d with what tidyverse / pandas.pivot_table would do (#14048)

Co-authored-by: ritchie <ritchie46@gmail.com>
  • Loading branch information
MarcoGorelli and ritchie46 authored Jan 30, 2024
1 parent 9e87515 commit 1683ea1
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 140 deletions.
6 changes: 6 additions & 0 deletions crates/polars-core/src/chunked_array/logical/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::offset::OffsetsBuffer;
use smartstring::alias::String as SmartString;

use self::sort::arg_sort_multiple::_get_rows_encoded_ca;
use super::*;
use crate::datatypes::*;
use crate::utils::index_to_chunked_index;
Expand Down Expand Up @@ -411,6 +412,11 @@ impl StructChunked {
}
self.cast_impl(dtype, true)
}

pub fn rows_encode(&self) -> PolarsResult<BinaryOffsetChunked> {
let descending = vec![false; self.fields.len()];
_get_rows_encoded_ca(self.name(), &self.fields, &descending, false)
}
}

impl LogicalType for StructChunked {
Expand Down
224 changes: 131 additions & 93 deletions crates/polars-ops/src/frame/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,117 +187,155 @@ fn pivot_impl(
// used as separator/delimiter in generated column names.
separator: Option<&str>,
) -> PolarsResult<DataFrame> {
let sep = separator.unwrap_or("_");
polars_ensure!(!index.is_empty(), ComputeError: "index cannot be zero length");
polars_ensure!(!columns.is_empty(), ComputeError: "columns cannot be zero length");
if !stable {
println!("unstable pivot not yet supported, using stable pivot");
};
if columns.len() > 1 {
let schema = Arc::new(pivot_df.schema());
let binding = pivot_df.select_with_schema(columns, &schema)?;
let fields = binding.get_columns();
let column = format!("{{\"{}\"}}", columns.join("\",\""));
if schema.contains(column.as_str()) {
polars_bail!(ComputeError: "cannot use column name {column} that \
already exists in the DataFrame. Please rename it prior to calling `pivot`.")
}
let columns_struct = StructChunked::new(&column, fields).unwrap().into_series();
let mut binding = pivot_df.clone();
let pivot_df = unsafe { binding.with_column_unchecked(columns_struct) };
pivot_impl_single_column(
pivot_df,
&column,
values,
index,
agg_fn,
sort_columns,
separator,
)
} else {
pivot_impl_single_column(
pivot_df,
unsafe { columns.get_unchecked(0) },
values,
index,
agg_fn,
sort_columns,
separator,
)
}
}

fn pivot_impl_single_column(
pivot_df: &DataFrame,
column: &str,
values: &[String],
index: &[String],
agg_fn: Option<PivotAgg>,
sort_columns: bool,
separator: Option<&str>,
) -> PolarsResult<DataFrame> {
let sep = separator.unwrap_or("_");
let mut final_cols = vec![];

let mut count = 0;
let out: PolarsResult<()> = POOL.install(|| {
for column_column_name in columns {
let mut group_by = index.to_vec();
group_by.push(column_column_name.clone());
let mut group_by = index.to_vec();
group_by.push(column.to_string());

let groups = pivot_df.group_by_stable(group_by)?.take_groups();
let groups = pivot_df.group_by_stable(group_by)?.take_groups();

// these are the row locations
if !stable {
println!("unstable pivot not yet supported, using stable pivot");
};

let (col, row) = POOL.join(
|| positioning::compute_col_idx(pivot_df, column_column_name, &groups),
|| positioning::compute_row_idx(pivot_df, index, &groups, count),
);
let (col_locations, column_agg) = col?;
let (row_locations, n_rows, mut row_index) = row?;
let (col, row) = POOL.join(
|| positioning::compute_col_idx(pivot_df, column, &groups),
|| positioning::compute_row_idx(pivot_df, index, &groups, count),
);
let (col_locations, column_agg) = col?;
let (row_locations, n_rows, mut row_index) = row?;

for value_col_name in values {
let value_col = pivot_df.column(value_col_name)?;
for value_col_name in values {
let value_col = pivot_df.column(value_col_name)?;

use PivotAgg::*;
let value_agg = unsafe {
match &agg_fn {
None => match value_col.len() > groups.len() {
true => polars_bail!(ComputeError: "found multiple elements in the same group, please specify an aggregation function"),
false => value_col.agg_first(&groups),
}
Some(agg_fn) => match agg_fn {
Sum => value_col.agg_sum(&groups),
Min => value_col.agg_min(&groups),
Max => value_col.agg_max(&groups),
Last => value_col.agg_last(&groups),
First => value_col.agg_first(&groups),
Mean => value_col.agg_mean(&groups),
Median => value_col.agg_median(&groups),
Count => groups.group_count().into_series(),
Expr(ref expr) => {
let name = expr.root_name()?;
let mut value_col = value_col.clone();
value_col.rename(name);
let tmp_df = DataFrame::new_no_checks(vec![value_col]);
let mut aggregated = expr.evaluate(&tmp_df, &groups)?;
aggregated.rename(value_col_name);
aggregated
}
},
use PivotAgg::*;
let value_agg = unsafe {
match &agg_fn {
None => match value_col.len() > groups.len() {
true => polars_bail!(ComputeError: "found multiple elements in the same group, please specify an aggregation function"),
false => value_col.agg_first(&groups),
}
};

let headers = column_agg.unique_stable()?.cast(&DataType::String)?;
let mut headers = headers.str().unwrap().clone();
if values.len() > 1 {
headers = headers.apply_values(|v| Cow::from(format!("{value_col_name}{sep}{column_column_name}{sep}{v}")))
Some(agg_fn) => match agg_fn {
Sum => value_col.agg_sum(&groups),
Min => value_col.agg_min(&groups),
Max => value_col.agg_max(&groups),
Last => value_col.agg_last(&groups),
First => value_col.agg_first(&groups),
Mean => value_col.agg_mean(&groups),
Median => value_col.agg_median(&groups),
Count => groups.group_count().into_series(),
Expr(ref expr) => {
let name = expr.root_name()?;
let mut value_col = value_col.clone();
value_col.rename(name);
let tmp_df = DataFrame::new_no_checks(vec![value_col]);
let mut aggregated = expr.evaluate(&tmp_df, &groups)?;
aggregated.rename(value_col_name);
aggregated
}
},
}
};

let n_cols = headers.len();
let value_agg_phys = value_agg.to_physical_repr();
let logical_type = value_agg.dtype();
let headers = column_agg.unique_stable()?.cast(&DataType::String)?;
let mut headers = headers.str().unwrap().clone();
if values.len() > 1 {
headers = headers.apply_values(|v| Cow::from(format!("{value_col_name}{sep}{v}")))
}

debug_assert_eq!(row_locations.len(), col_locations.len());
debug_assert_eq!(value_agg_phys.len(), row_locations.len());
let n_cols = headers.len();
let value_agg_phys = value_agg.to_physical_repr();
let logical_type = value_agg.dtype();

let mut cols = if value_agg_phys.dtype().is_numeric() {
macro_rules! dispatch {
($ca:expr) => {{
positioning::position_aggregates_numeric(
n_rows,
n_cols,
&row_locations,
&col_locations,
$ca,
logical_type,
&headers,
)
}};
}
downcast_as_macro_arg_physical!(value_agg_phys, dispatch)
} else {
positioning::position_aggregates(
n_rows,
n_cols,
&row_locations,
&col_locations,
&value_agg_phys,
logical_type,
&headers,
)
};
debug_assert_eq!(row_locations.len(), col_locations.len());
debug_assert_eq!(value_agg_phys.len(), row_locations.len());

if sort_columns {
cols.sort_unstable_by(|a, b| a.name().partial_cmp(b.name()).unwrap());
let mut cols = if value_agg_phys.dtype().is_numeric() {
macro_rules! dispatch {
($ca:expr) => {{
positioning::position_aggregates_numeric(
n_rows,
n_cols,
&row_locations,
&col_locations,
$ca,
logical_type,
&headers,
)
}};
}
downcast_as_macro_arg_physical!(value_agg_phys, dispatch)
} else {
positioning::position_aggregates(
n_rows,
n_cols,
&row_locations,
&col_locations,
&value_agg_phys,
logical_type,
&headers,
)
};

let cols = if count == 0 {
let mut final_cols = row_index.take().unwrap();
final_cols.extend(cols);
final_cols
} else {
cols
};
count += 1;
final_cols.extend_from_slice(&cols);
if sort_columns {
cols.sort_unstable_by(|a, b| a.name().partial_cmp(b.name()).unwrap());
}

let cols = if count == 0 {
let mut final_cols = row_index.take().unwrap();
final_cols.extend(cols);
final_cols
} else {
cols
};
count += 1;
final_cols.extend_from_slice(&cols);
}
Ok(())
});
Expand Down
Loading

0 comments on commit 1683ea1

Please sign in to comment.