Skip to content

Commit

Permalink
fix: Incorrect lazy schema for Expr.list.diff (#21484)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Feb 27, 2025
1 parent bf1b47f commit bee578b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
16 changes: 15 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,21 @@ impl ListFunction {
ArgMin => mapper.with_dtype(IDX_DTYPE),
ArgMax => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "diff")]
Diff { .. } => mapper.with_same_dtype(),
Diff { .. } => mapper.map_dtype(|dt| {
let inner_dt = match dt.inner_dtype().unwrap() {
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, _) => DataType::Duration(*tu),
#[cfg(feature = "dtype-date")]
DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
#[cfg(feature = "dtype-time")]
DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
DataType::UInt64 | DataType::UInt32 => DataType::Int64,
DataType::UInt16 => DataType::Int32,
DataType::UInt8 => DataType::Int16,
inner_dt => inner_dt.clone(),
};
DataType::List(Box::new(inner_dt))
}),
Sort(_) => mapper.with_same_dtype(),
Reverse => mapper.with_same_dtype(),
Unique(_) => mapper.with_same_dtype(),
Expand Down
26 changes: 26 additions & 0 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,3 +1001,29 @@ def test_list_agg_all_null(
) -> None:
s = pl.Series([None, None], dtype=pl.List(inner_dtype))
assert getattr(s.list, agg)().dtype == expected_dtype


@pytest.mark.parametrize(
    ("inner_dtype", "expected_inner_dtype"),
    [
        # Temporal inner dtypes diff to a Duration.
        (pl.Datetime("us"), pl.Duration("us")),
        (pl.Date(), pl.Duration("ms")),
        (pl.Time(), pl.Duration("ns")),
        # Unsigned integers are mapped to a signed integer dtype.
        (pl.UInt64(), pl.Int64()),
        (pl.UInt32(), pl.Int64()),
        (pl.UInt16(), pl.Int32()),
        (pl.UInt8(), pl.Int16()),
        # Signed integers and floats keep their dtype.
        (pl.Int8(), pl.Int8()),
        (pl.Float32(), pl.Float32()),
    ],
)
def test_list_diff_schema(
    inner_dtype: PolarsDataType, expected_inner_dtype: PolarsDataType
) -> None:
    """Check that the lazy schema of `list.diff` matches the collected schema.

    For every `inner_dtype`, a single-row LazyFrame with a list column is cast
    to `List(inner_dtype)` and `list.diff(1)` is applied. Both the schema
    resolved without executing the query (`collect_schema`) and the schema of
    the materialized result must equal `List(expected_inner_dtype)`.
    """
    lf = (
        pl.LazyFrame({"a": [[1, 2]]})
        .cast(pl.List(inner_dtype))
        .select(pl.col("a").list.diff(1))
    )
    expected = {"a": pl.List(expected_inner_dtype)}
    # Planner-side schema: this is what the fix in this commit corrects.
    assert lf.collect_schema() == expected
    # Engine-side schema: guards against the two drifting apart.
    assert lf.collect().schema == expected

0 comments on commit bee578b

Please sign in to comment.