Skip to content

Commit

Permalink
working with test
Browse files Browse the repository at this point in the history
  • Loading branch information
barak1412 committed Oct 14, 2024
1 parent 57b8f23 commit cd34234
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
Shift => mapper.with_same_dtype(),
Explode => mapper.map_to_list_and_array_inner_dtype(),
Explode => mapper.try_map_to_array_inner_dtype(),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,16 @@ impl<'a> FieldsMapper<'a> {
Ok(first)
}

#[cfg(feature = "dtype-array")]
/// Map the dtype to the dtype of the array elements, with typo validation.
pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
let dt = self.fields[0].dtype();
match dt {
DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(),
_ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt),
}
}

/// Map the dtypes to the "supertype" of a list of lists.
pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
self.try_map_dtypes(|dts| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ pub(super) fn optimize_functions(
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
let out = match function {
// arr.explode() -> explode
#[cfg(feature = "dtype-array")]
// arr.explode() -> explode()
FunctionExpr::ArrayExpr(ArrayFunction::Explode) => {
let input_input = input[0].node();
Some(AExpr::Explode(input_input))
let input_node = input[0].node();
Some(AExpr::Explode(input_node))
},
// is_null().any() -> null_count() > 0
// is_not_null().any() -> null_count() < len()
Expand Down
13 changes: 12 additions & 1 deletion py-polars/tests/unit/operations/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -449,3 +449,14 @@ def test_array_n_unique() -> None:
{"n_unique": [2, 1, 1, None]}, schema={"n_unique": pl.UInt32}
)
assert_frame_equal(out, expected)


def test_explode_19049() -> None:
df = pl.DataFrame({"a": [[1, 2, 3]]}, schema={"a": pl.Array(pl.Int64, 3)})
result_df = df.select(pl.col.a.arr.explode())
expected_df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64})
assert_frame_equal(result_df, expected_df)

df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64})
with pytest.raises(InvalidOperationError, match="expected Array type, got: i64"):
df.select(pl.col.a.arr.explode())

0 comments on commit cd34234

Please sign in to comment.