diff --git a/py-polars/tests/unit/test_scalar.py b/py-polars/tests/unit/test_scalar.py index 0fdb74cb8843..b00d2a66e043 100644 --- a/py-polars/tests/unit/test_scalar.py +++ b/py-polars/tests/unit/test_scalar.py @@ -84,3 +84,21 @@ def test_scalar_identification_function_expr_in_binary() -> None: def test_scalar_rechunk_20627() -> None: df = pl.concat(2 * [pl.Series([1])]).filter(pl.Series([False, True])).to_frame() assert df.rechunk().to_series().n_chunks() == 1 + + +def test_split_scalar_21581() -> None: + df = pl.DataFrame({"a": [1.0, 2.0, 3.0]}) + df = df.with_columns( + [ + pl.col("a").shift(-1).alias("next_a"), + pl.lit(True).alias("lit"), + ] + ) + + assert df.filter(df["next_a"] != 99.0).with_columns( + [pl.lit(False).alias("lit")] + ).to_dict(as_series=False) == { + "a": [1.0, 2.0], + "next_a": [2.0, 3.0], + "lit": [False, False], + }