Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support List/Array in search_sorted(), and fix edge case where length 1 Series of array couldn't use index_of() #21266

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
27 changes: 27 additions & 0 deletions crates/polars-ops/src/series/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars_core::chunked_array::ops::search_sorted::{binary_search_ca, SearchSortedSide};
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
use row_encode::encode_rows_unordered;

pub fn search_sorted(
s: &Series,
Expand All @@ -15,6 +16,14 @@ pub fn search_sorted(
polars_bail!(InvalidOperation: "'search_sorted' is not supported on dtype: {}", s.dtype())
}

let (s, search_values) = if s.dtype().is_array() || s.dtype().is_list() {
let s = encode_rows_unordered(&[s.clone().into_column()])?;
let search_values = encode_rows_unordered(&[search_values.clone().into_column()])?;
(&s.into_series(), &search_values.into_series())
} else {
(s, search_values)
};

let s = s.to_physical_repr();
let phys_dtype = s.dtype();

Expand Down Expand Up @@ -66,6 +75,24 @@ pub fn search_sorted(

Ok(IdxCa::new_vec(s.name().clone(), idx))
},
DataType::BinaryOffset => {
let ca = s.binary_offset().unwrap();

let idx = match search_values.dtype() {
DataType::BinaryOffset => {
let search_values = search_values.binary_offset().unwrap();
binary_search_ca(ca, search_values.iter(), side, descending)
},
DataType::Binary => {
let search_values = search_values.binary().unwrap();
binary_search_ca(ca, search_values.iter(), side, descending)
},
_ => unreachable!(),
};

Ok(IdxCa::new_vec(s.name().clone(), idx))
},

dt if dt.is_primitive_numeric() => {
let search_values = search_values.to_physical_repr();

Expand Down
6 changes: 1 addition & 5 deletions crates/polars-plan/src/dsl/function_expr/index_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ pub(super) fn index_of(s: &mut [Column]) -> PolarsResult<Column> {
let result = match is_sorted_flag {
// If the Series is sorted, we can use an optimized binary search to
// find the value.
IsSorted::Ascending | IsSorted::Descending
if !needle.is_null() &&
// search_sorted() doesn't support decimals at the moment.
!series.dtype().is_decimal() =>
{
IsSorted::Ascending | IsSorted::Descending if !needle.is_null() => {
search_sorted(
series,
needle_s.as_materialized_series(),
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,7 @@ impl Expr {
collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
fmt_str: "search_sorted",
cast_options: Some(CastingRules::Supertype(
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(),
)),
cast_options: Some(CastingRules::FirstArgLossless),
..Default::default()
},
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2382,7 +2382,7 @@ def search_sorted(
│ 0 ┆ 2 ┆ 4 │
└──────┴───────┴─────┘
"""
element = parse_into_expression(element, str_as_lit=True, list_as_series=True) # type: ignore[arg-type]
element = parse_into_expression(element, str_as_lit=True, list_as_series=False) # type: ignore[arg-type]
return self._from_pyexpr(self._pyexpr.search_sorted(element, side))

def sort_by(
Expand Down
62 changes: 61 additions & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3456,6 +3456,14 @@ def search_sorted(
----------
element
Expression or scalar value.
If this value matches the dtype of values in self, the return result is an
integer.
If self's dtype is ``pl.List``/``pl.Array``, we assume an element that is
``list`` or NumPy array is a single value, and return an integer. For other
dtypes, a ``list``/NumPy array element is assumed to be searching for
multiple values, and the return result is a ``Series``.
If this is a ``Series`` or ``Expr``, the return result is a ``Series``.

side : {'any', 'left', 'right'}
If 'any', the index of the first suitable location found is given.
If 'left', the index of the leftmost suitable location found is given.
Expand Down Expand Up @@ -3495,8 +3503,60 @@ def search_sorted(
6
]
"""
# A list or ndarray passed to search_sorted() has two possible meanings:
#
# 1. Searching for multiple values, for example the Series is Int64 and
# we're searching for multiple integers.
# 2. Searching for a single value, when the Series dtype is List or
# Array.
#
# Depending which it is, we need to return either a Series (i.e.
# multiple results), or a single integer. We can mostly distinguish
# which case it is by casting to the dtype of self: if that succeeds,
# we assume it's case 2, if it fails, we assume case 1. There is still
# an ambiguous case, though:
#
# Series([...], dtype=pl.List(pl.List(pl.Int64()))).search_sorted([])
#
# Does this mean searching multiple values, which should return a
# pl.Series of length 0, or does it mean searching for a single empty
# list and it should return an integer? Who can say! Arguably this API
# design was a mistake once you allow searching pl.List series, and
# probably the solution is deprecate searching for multiple values with
# lists, and force people to use Series for that case.
#
# For now, we disallow searching for multiple values via lists when
# self's dtype is pl.List or pl.Array, so that we have a non-ambiguous
# API.
if isinstance(element, (list, np.ndarray)):
if isinstance(self.dtype, (List, Array)):
# Catch (most) disallowed multi-value-search cases by casting
# the needle to the haystack's dtype:
try:
F.select(F.lit(element, dtype=self.dtype))
except TypeError as err:
message = (
f"{element} does not match dtype {self.dtype}. "
"If you were trying to search for multiple values, "
"use a ``pl.Series`` instead of a list/ndarray."
)
raise TypeError(message) from err
else:
# We're definitely searching for multiple values. Wrap in
# Series so we don't have issues with casting:
element = pl.Series(element)

df = F.select(F.lit(self).search_sorted(element, side))
if isinstance(element, (list, Series, pl.Expr, np.ndarray)):
# These types unambiguously return a Series:
#
# * Series means we want to search for multiple values.
# * Expr because that always returns a Series, matching Expr.search_sorted().
if isinstance(element, (Series, pl.Expr)):
return df.to_series()

if isinstance(element, (list, np.ndarray)) and not isinstance(
self.dtype, (List, Array)
):
return df.to_series()
else:
return df.item()
Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/unit/operations/test_index_of.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ def test_other_types(
[
series.sort(descending=False),
series.sort(descending=True),
# Length 1 series are marked as sorted; this catches regression
# in issue #21100:
pl.Series(series.to_list()[:1], dtype=series.dtype),
]
)
for s in series_variants:
Expand Down Expand Up @@ -317,6 +320,8 @@ def test_enum(convert_to_literal: bool) -> None:
series.drop_nulls(),
series.sort(descending=False),
series.sort(descending=True),
# Length 1 series, to check for #21100:
pl.Series(["a"], dtype=pl.Enum(["c", "b", "a"])),
]:
for value in expected_values:
assert_index_of(s, value, convert_to_literal=convert_to_literal)
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/operations/test_search_sorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,28 @@ def test_raise_literal_numeric_search_sorted_18096() -> None:

with pytest.raises(pl.exceptions.InvalidOperationError):
df.with_columns(idx=pl.col("foo").search_sorted("bar"))


def test_search_sorted_list() -> None:
series = pl.Series([[1], [2], [3]])
assert series.search_sorted([2]) == 1
assert series.search_sorted(pl.Series([[3], [2]])).to_list() == [2, 1]
assert series.search_sorted(pl.lit([3], dtype=pl.List(pl.Int64()))).to_list() == [2]
with pytest.raises(
TypeError, match="If you were trying to search for multiple values"
):
series.search_sorted([[1]]) # type: ignore[list-item]


def test_search_sorted_array() -> None:
dtype = pl.Array(pl.Int64(), 1)
series = pl.Series([[1], [2], [3]], dtype=dtype)
assert series.index_of([2]) == 1
assert series.search_sorted([2]) == 1
assert series.search_sorted(pl.Series([[3], [2]], dtype=dtype)).to_list() == [2, 1]
assert series.search_sorted(pl.lit([3])).to_list() == [2]
assert series.search_sorted(pl.lit([3], dtype=dtype)).to_list() == [2]
with pytest.raises(
TypeError, match="If you were trying to search for multiple values"
):
series.search_sorted([[1]]) # type: ignore[list-item]
Loading