Skip to content

Commit

Permalink
simplify code and add some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Mar 6, 2025
1 parent 37d1dfb commit 94cb400
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
24 changes: 9 additions & 15 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,26 +156,20 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
pub(super) fn length(s: &Column) -> PolarsResult<Column> {
let array = s.array()?;
let width = array.width();
let width = IdxSize::try_from(width)
.map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;

let width_val = match IDX_DTYPE {
DataType::UInt32 => AnyValue::UInt32(width as u32),
DataType::UInt64 => AnyValue::UInt64(width as u64),
_ => unreachable!("IDX_DTYPE should be UInt32 or UInt64"),
};

let mut c = Column::new_scalar(
array.name().clone(),
Scalar::new(IDX_DTYPE, width_val),
array.len(),
);

let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
if let Some(validity) = array.rechunk_validity() {
let mut series = c.into_materialized_series().clone();

// SAFETY: We keep datatypes intact and call compute_len afterwards.
let chunks = unsafe { series.chunks_mut() };
let arr = &mut chunks[0];
*arr = arr.with_validity(Some(validity));
series.compute_len();
assert_eq!(chunks.len(), 1);

chunks[0] = chunks[0].with_validity(Some(validity));

series.compute_len();
c = series.into_column();
}

Expand Down
22 changes: 19 additions & 3 deletions py-polars/tests/unit/operations/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,31 @@ def test_arr_sum(


def test_array_lengths() -> None:
s1 = pl.Series("a", [[1, 2, 3]], dtype=pl.Array(pl.Int64, 3))
s2 = pl.Series("b", [[4, 5]], dtype=pl.Array(pl.Int64, 2))
df = pl.DataFrame([s1, s2])
df = pl.DataFrame(
[
pl.Series("a", [[1, 2, 3]], dtype=pl.Array(pl.Int64, 3)),
pl.Series("b", [[4, 5]], dtype=pl.Array(pl.Int64, 2)),
]
)
out = df.select(pl.col("a").arr.len(), pl.col("b").arr.len())
expected_df = pl.DataFrame(
{"a": [3], "b": [2]}, schema={"a": pl.UInt32, "b": pl.UInt32}
)
assert_frame_equal(out, expected_df)

assert pl.Series("a", [[], []], pl.Array(pl.Null, 0)).arr.len().to_list() == [0, 0]
assert pl.Series("a", [None, []], pl.Array(pl.Null, 0)).arr.len().to_list() == [
None,
0,
]
assert pl.Series("a", [None], pl.Array(pl.Null, 0)).arr.len().to_list() == [None]

assert pl.Series("a", [], pl.Array(pl.Null, 0)).arr.len().to_list() == []
assert pl.Series("a", [], pl.Array(pl.Null, 1)).arr.len().to_list() == []
assert pl.Series(
"a", [[1, 2, 3], None, [7, 8, 9]], pl.Array(pl.Int32, 3)
).arr.len().to_list() == [3, None, 3]


def test_arr_unique() -> None:
df = pl.DataFrame(
Expand Down

0 comments on commit 94cb400

Please sign in to comment.