Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow for rolling_*_by to use index count as window #19071

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,19 @@ where
by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
&None,
),
DataType::Int64 => (
by.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
&None,
),
DataType::Int32 | DataType::UInt64 | DataType::UInt32 => (
by.cast(&DataType::Int64)?
.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
&None,
),
dt => polars_bail!(InvalidOperation:
"in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)",
dt,
"date/datetime"),
"Date/Datetime/Int64/Int32/UInt64/UInt32"),
};
let ca = ca.rechunk();
let by = by.rechunk();
Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6210,6 +6210,7 @@ def rolling_min_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6331,6 +6332,7 @@ def rolling_max_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6478,6 +6480,7 @@ def rolling_mean_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6630,6 +6633,7 @@ def rolling_sum_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6780,6 +6784,7 @@ def rolling_std_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -6936,6 +6941,7 @@ def rolling_var_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -7091,6 +7097,7 @@ def rolling_median_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down Expand Up @@ -7220,6 +7227,7 @@ def rolling_quantile_by(
- 1mo (1 calendar month)
- 1q (1 calendar quarter)
- 1y (1 calendar year)
- 1i (1 index count)

By "calendar day", we mean the corresponding time on the next day
(which may not be 24 hours, due to daylight savings). Similarly for
Expand Down
14 changes: 12 additions & 2 deletions py-polars/tests/unit/operations/rolling/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ def test_rolling_crossing_dst(


def test_rolling_by_invalid() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a")
msg = r"in `rolling_\*_by` operation, `by` argument of dtype `i64` is not supported"
df = pl.DataFrame(
{"a": [1, 2, 3], "b": [4, 5, 6]}, schema_overrides={"a": pl.Int16}
).sort("a")
msg = "unsupported data type: i16 for `window_size`, expected UInt64, UInt32, Int64, Int32, Datetime, Date, Duration, or Time"
with pytest.raises(InvalidOperationError, match=msg):
df.select(pl.col("b").rolling_min_by("a", "2i"))
df = pl.DataFrame({"a": [1, 2, 3], "b": [date(2020, 1, 1)] * 3}).sort("b")
Expand Down Expand Up @@ -818,6 +820,14 @@ def test_rolling_by_date() -> None:
assert_frame_equal(result, expected)


@pytest.mark.parametrize("dtype", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32])
def test_rolling_by_integer(dtype: PolarsDataType) -> None:
df = pl.DataFrame({"val": [1, 2, 3]}).with_row_index()
Comment on lines +824 to +825
Copy link
Contributor

@HansBambel HansBambel Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I missing something or is the parameterized dtype never used?

Shouldn't it be something like this:

df = pl.DataFrame({"val": [1, 2, 3]}, schema={"val": dtype}).with_row_index()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @HansBambel , the intention was probably

df = pl.DataFrame({"val": [1, 2, 3]}).with_row_index().with_columns(pl.col('index').cast(dtype))

fancy making a pr to address this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli I'll do that on the weekend at the latest!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Managed to do it now already: #19555

Local tests were failing because connectorx was missing even after running make requirements-all. Not sure why

result = df.with_columns(roll=pl.col("val").rolling_sum_by("index", "2i"))
expected = df.with_columns(roll=pl.Series([1, 3, 5]))
assert_frame_equal(result, expected)


def test_rolling_nanoseconds_11003() -> None:
df = pl.DataFrame(
{
Expand Down
Loading