Skip to content

Commit

Permalink
Return head
Browse files Browse the repository at this point in the history
  • Loading branch information
khalidmammadov committed Oct 16, 2024
1 parent d960e63 commit a4a9c0b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 24 deletions.
7 changes: 7 additions & 0 deletions py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,3 +629,10 @@ def re_escape(s: str) -> str:
# escapes _only_ those metachars with meaning to the rust regex crate
re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
return re.sub(f"([{re_rust_metachars}])", r"\\\1", s)


def try_head(seq: Sequence[Any] | Any, default: Any) -> Any:
try:
return seq[0]
except TypeError:
return default
15 changes: 3 additions & 12 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
issue_warning,
normalize_filepath,
parse_percentiles,
try_head,
)
from polars._utils.wrap import wrap_df, wrap_expr
from polars.datatypes import (
Expand Down Expand Up @@ -1370,18 +1371,8 @@ def sort(
"""
# Fast path for sorting by a single existing column
if isinstance(by, str) and not more_by:
if (isinstance(descending, list) and len(descending) != 1) or (
isinstance(nulls_last, list) and len(nulls_last) != 1
):
msg = (
"size of `descending` or `nulls_last` "
"must be 1 when defined as list"
)
raise ValueError(msg)
if isinstance(descending, list):
descending = descending[0]
if isinstance(nulls_last, list):
nulls_last = nulls_last[0]
descending = try_head(descending, descending)
nulls_last = try_head(nulls_last, nulls_last)
return self._from_pyldf(
self._ldf.sort(
by, descending, nulls_last, maintain_order, multithreaded
Expand Down
12 changes: 0 additions & 12 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,18 +362,6 @@ def test_sort_multi_output_exprs_01() -> None:
):
df.sort("dts", "strs", nulls_last=[True, False, True])

with pytest.raises(
ValueError,
match="size of `descending` or `nulls_last` must be 1 when defined as list",
):
df.sort(by="dts", descending=[True, False])

with pytest.raises(
ValueError,
match="size of `descending` or `nulls_last` must be 1 when defined as list",
):
df.sort(by="dts", nulls_last=[True, False])

# No columns selected - return original input.
assert_frame_equal(df, df.sort(pl.col("^xxx$")))

Expand Down

0 comments on commit a4a9c0b

Please sign in to comment.