From a4a9c0b1a96d273761c74d6f76516ac31a02ef34 Mon Sep 17 00:00:00 2001 From: khalidmammadov Date: Wed, 16 Oct 2024 20:56:30 +0100 Subject: [PATCH] Return head --- py-polars/polars/_utils/various.py | 7 +++++++ py-polars/polars/lazyframe/frame.py | 15 +++------------ py-polars/tests/unit/dataframe/test_df.py | 12 ------------ 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index 4acad1df5237..04d5e0211a00 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -629,3 +629,10 @@ def re_escape(s: str) -> str: # escapes _only_ those metachars with meaning to the rust regex crate re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-" return re.sub(f"([{re_rust_metachars}])", r"\\\1", s) + + +def try_head(seq: Sequence[Any] | Any, default: Any) -> Any: + try: + return seq[0] + except TypeError: + return default diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 2d5dd67d01c5..54524685ae9c 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -45,6 +45,7 @@ issue_warning, normalize_filepath, parse_percentiles, + try_head, ) from polars._utils.wrap import wrap_df, wrap_expr from polars.datatypes import ( @@ -1370,18 +1371,8 @@ def sort( """ # Fast path for sorting by a single existing column if isinstance(by, str) and not more_by: - if (isinstance(descending, list) and len(descending) != 1) or ( - isinstance(nulls_last, list) and len(nulls_last) != 1 - ): - msg = ( - "size of `descending` or `nulls_last` " - "must be 1 when defined as list" - ) - raise ValueError(msg) - if isinstance(descending, list): - descending = descending[0] - if isinstance(nulls_last, list): - nulls_last = nulls_last[0] + descending = try_head(descending, descending) + nulls_last = try_head(nulls_last, nulls_last) return self._from_pyldf( self._ldf.sort( by, descending, nulls_last, maintain_order, multithreaded diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 3c0e5bc6894c..4c5953d6c249 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -362,18 +362,6 @@ def test_sort_multi_output_exprs_01() -> None: ): df.sort("dts", "strs", nulls_last=[True, False, True]) - with pytest.raises( - ValueError, - match="size of `descending` or `nulls_last` must be 1 when defined as list", - ): - df.sort(by="dts", descending=[True, False]) - - with pytest.raises( - ValueError, - match="size of `descending` or `nulls_last` must be 1 when defined as list", - ): - df.sort(by="dts", nulls_last=[True, False]) - # No columns selected - return original input. assert_frame_equal(df, df.sort(pl.col("^xxx$")))