From a32d2564aa0704b16298bb45c93ce6ad81d3ad00 Mon Sep 17 00:00:00 2001 From: barak1412 Date: Sun, 6 Oct 2024 18:44:29 +0300 Subject: [PATCH] working version with fixed test --- crates/polars-core/src/series/mod.rs | 4 ++-- crates/polars-expr/src/expressions/mod.rs | 6 +++++- py-polars/tests/unit/dataframe/test_df.py | 8 ++++++++ py-polars/tests/unit/functions/range/test_time_range.py | 6 +++--- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 72cb3b67dc41..b32a7a348f09 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -528,13 +528,13 @@ impl Series { Ok(max) } - /// Explode a list Series. This expands every item to a new row.. + /// Explode a list/array Series. This expands every item to a new row. pub fn explode(&self) -> PolarsResult { match self.dtype() { DataType::List(_) => self.list().unwrap().explode(), #[cfg(feature = "dtype-array")] DataType::Array(_, _) => self.array().unwrap().explode(), - _ => Ok(self.clone()), + dt => polars_bail!(opq = explode, dt), } } diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 8a74033953dc..f814b8a20a5a 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -467,7 +467,11 @@ impl<'a> AggregationContext<'a> { AggState::AggregatedScalar(s) => (s, groups), AggState::Literal(s) => (s, groups), AggState::AggregatedList(s) => { - let flattened = s.explode().unwrap(); + // flat only if we have a list + let flattened = match s.dtype() { + DataType::List(_) => s.explode().unwrap(), + _ => s.clone(), + }; let groups = groups.into_owned(); // unroll the possible flattened state // say we have groups with overlapping windows: diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 015b64ce6303..d9f167739c2f 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -621,6 +621,14 @@ def test_explode() -> None: assert out["nrs"].to_list() == [1, 2, 1, 3] +def test_illegal_explode_19049() -> None: + df = pl.DataFrame({"a": [0]}, schema={"a": pl.Int64}) + with pytest.raises( + InvalidOperationError, match="`explode` operation not supported for dtype `i64`" + ): + df.select(pl.col.a.explode()) + + @pytest.mark.parametrize( ("stack", "exp_shape", "exp_columns"), [ diff --git a/py-polars/tests/unit/functions/range/test_time_range.py b/py-polars/tests/unit/functions/range/test_time_range.py index 63bf5e0ec73f..2c2fdb5a169f 100644 --- a/py-polars/tests/unit/functions/range/test_time_range.py +++ b/py-polars/tests/unit/functions/range/test_time_range.py @@ -163,7 +163,7 @@ def test_time_range_lit_eager() -> None: ).alias("tm") ) if not eager: - tm = tm.select(pl.col("tm").explode()) + tm = tm.select(pl.col("tm")) assert tm["tm"].to_list() == [ time(6, 47, 13, 333000), time(12, 32, 23, 666000), @@ -178,7 +178,7 @@ def test_time_range_lit_eager() -> None: ).alias("tm") ) if not eager: - tm = tm.select(pl.col("tm").explode()) + tm = tm.select(pl.col("tm")) assert tm["tm"].to_list() == [ time(0, 0), time(5, 45, 10, 333000), @@ -194,7 +194,7 @@ def test_time_range_lit_eager() -> None: eager=eager, ).alias("tm") ) - tm = tm.select(pl.col("tm").explode()) + tm = tm.select(pl.col("tm")) assert tm["tm"].to_list() == [ time(23, 59, 59, 999980), time(23, 59, 59, 999990),