Skip to content

Commit

Permalink
fix: implement Serialize and Deserialize for `RollingOptionsFixed…
Browse files Browse the repository at this point in the history
…Window` and `RollingOptionsDynamicWindow`
  • Loading branch information
3ok committed Sep 14, 2024
1 parent 962b576 commit 71eacbb
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 17 deletions.
37 changes: 35 additions & 2 deletions crates/polars-arrow/src/legacy/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use std::sync::Arc;
use num_traits::{Bounded, Float, NumCast, One, Zero};
use polars_utils::float::IsFloat;
use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use window::*;

use crate::array::{ArrayRef, PrimitiveArray};
Expand All @@ -25,6 +27,35 @@ type WindowSize = usize;
type Len = usize;
pub type DynArgs = Option<Arc<dyn Any + Sync + Send>>;

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RollingFnParams {
Quantile(RollingQuantileParams),
Var(RollingVarParams),
}

impl RollingFnParams {
pub fn from_dyn_args(dyn_args: &DynArgs) -> Option<Self> {
dyn_args.as_ref().and_then(|params| {
params
.downcast_ref::<RollingQuantileParams>()
.map(|params| RollingFnParams::Quantile(*params))
.or_else(|| {
params
.downcast_ref::<RollingVarParams>()
.map(|params| RollingFnParams::Var(*params))
})
})
}

pub fn to_dyn_args(&self) -> DynArgs {
match self {
RollingFnParams::Quantile(params) => Some(Arc::new(*params)),
RollingFnParams::Var(params) => Some(Arc::new(*params)),
}
}
}

fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
(i.saturating_sub(window_size - 1), i + 1)
}
Expand Down Expand Up @@ -77,12 +108,14 @@ where
}

// Parameters allowed for rolling operations.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingVarParams {
pub ddof: u8,
}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingQuantileParams {
pub prob: f64,
pub interpol: QuantileInterpolOptions,
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-arrow/src/legacy/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ pub use crate::legacy::array::default_arrays::*;
pub use crate::legacy::array::*;
pub use crate::legacy::index::*;
pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions;
pub use crate::legacy::kernels::rolling::{DynArgs, RollingQuantileParams, RollingVarParams};
pub use crate::legacy::kernels::rolling::{
DynArgs, RollingFnParams, RollingQuantileParams, RollingVarParams,
};
pub use crate::legacy::kernels::{Ambiguous, NonExistent};

pub type LargeStringArray = Utf8Array<i64>;
Expand Down
60 changes: 54 additions & 6 deletions crates/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use arrow::legacy::prelude::DynArgs;
use arrow::legacy::prelude::{DynArgs, RollingFnParams};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingOptionsFixedWindow {
/// The length of the window.
pub window_size: usize,
Expand All @@ -14,19 +13,68 @@ pub struct RollingOptionsFixedWindow {
pub weights: Option<Vec<f64>>,
/// Set the labels at the center of the window.
pub center: bool,
#[cfg_attr(feature = "serde", serde(skip))]
pub fn_params: DynArgs,
}

#[cfg(feature = "serde")]
impl Serialize for RollingOptionsFixedWindow {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let rolling_fn_params = RollingFnParams::from_dyn_args(&self.fn_params);
let mut state = serializer.serialize_struct("RollingOptionsFixedWindow", 5)?;

state.serialize_field("window_size", &self.window_size)?;
state.serialize_field("min_periods", &self.min_periods)?;
state.serialize_field("weights", &self.weights)?;
state.serialize_field("center", &self.center)?;
state.serialize_field("fn_params", &rolling_fn_params)?;

state.end()
}
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RollingOptionsFixedWindow {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize, Debug)]
struct Helper {
window_size: usize,
min_periods: usize,
weights: Option<Vec<f64>>,
center: bool,
#[serde(default)]
fn_params: Option<RollingFnParams>,
}

let helper = Helper::deserialize(deserializer)?;
let fn_params = helper
.fn_params
.as_ref()
.and_then(|param| param.to_dyn_args());
Ok(RollingOptionsFixedWindow {
window_size: helper.window_size,
min_periods: helper.min_periods,
weights: helper.weights,
center: helper.center,
fn_params,
})
}
}

#[cfg(feature = "rolling_window")]
impl PartialEq for RollingOptionsFixedWindow {
fn eq(&self, other: &Self) -> bool {
self.window_size == other.window_size
&& self.min_periods == other.min_periods
&& self.weights == other.weights
&& self.center == other.center
&& self.fn_params.is_none()
&& other.fn_params.is_none()
&& RollingFnParams::from_dyn_args(&self.fn_params)
== RollingFnParams::from_dyn_args(&other.fn_params)
}
}

Expand Down
56 changes: 51 additions & 5 deletions crates/polars-time/src/chunkedarray/rolling_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ use arrow::legacy::kernels::rolling;
pub use dispatch::*;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};

use crate::prelude::*;

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingOptionsDynamicWindow {
/// The length of the window.
pub window_size: Duration,
Expand All @@ -21,17 +20,64 @@ pub struct RollingOptionsDynamicWindow {
/// Which side windows should be closed.
pub closed_window: ClosedWindow,
/// Optional parameters for the rolling function
#[cfg_attr(feature = "serde", serde(skip))]
pub fn_params: DynArgs,
}

#[cfg(feature = "serde")]
impl Serialize for RollingOptionsDynamicWindow {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let rolling_fn_params = RollingFnParams::from_dyn_args(&self.fn_params);
let mut state = serializer.serialize_struct("RollingOptionsDynamicWindow", 4)?;

state.serialize_field("window_size", &self.window_size)?;
state.serialize_field("min_periods", &self.min_periods)?;
state.serialize_field("closed_window", &self.closed_window)?;
state.serialize_field("fn_params", &rolling_fn_params)?;

state.end()
}
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RollingOptionsDynamicWindow {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct Helper {
window_size: Duration,
min_periods: usize,
closed_window: ClosedWindow,
#[serde(default)]
fn_params: Option<RollingFnParams>,
}

let helper = Helper::deserialize(deserializer)?;
let fn_params = helper
.fn_params
.as_ref()
.and_then(|param| param.to_dyn_args());

Ok(RollingOptionsDynamicWindow {
window_size: helper.window_size,
min_periods: helper.min_periods,
closed_window: helper.closed_window,
fn_params,
})
}
}

#[cfg(feature = "rolling_window_by")]
impl PartialEq for RollingOptionsDynamicWindow {
fn eq(&self, other: &Self) -> bool {
self.window_size == other.window_size
&& self.min_periods == other.min_periods
&& self.closed_window == other.closed_window
&& self.fn_params.is_none()
&& other.fn_params.is_none()
&& RollingFnParams::from_dyn_args(&self.fn_params)
== RollingFnParams::from_dyn_args(&other.fn_params)
}
}
25 changes: 22 additions & 3 deletions py-polars/tests/unit/expr/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,33 @@
from polars.exceptions import ComputeError


def test_expr_serde_roundtrip_binary() -> None:
expr = pl.col("foo").sum().over("bar")
@pytest.mark.parametrize(
"expr",
[
pl.col("foo").sum().over("bar"),
pl.col("foo").rolling_quantile(0.25, window_size=5),
pl.col("foo").rolling_var(window_size=4, ddof=2),
pl.col("foo").rolling_min(window_size=2),
pl.col("foo").rolling_quantile_by("bar", window_size="1mo", quantile=0.75),
],
)
def test_expr_serde_roundtrip_binary(expr: pl.Expr) -> None:
json = expr.meta.serialize(format="binary")
round_tripped = pl.Expr.deserialize(io.BytesIO(json), format="binary")
assert round_tripped.meta == expr


def test_expr_serde_roundtrip_json() -> None:
@pytest.mark.parametrize(
"expr",
[
pl.col("foo").sum().over("bar"),
pl.col("foo").rolling_quantile(0.25, window_size=5),
pl.col("foo").rolling_var(window_size=4, ddof=2),
pl.col("foo").rolling_min(window_size=2),
pl.col("foo").rolling_quantile_by("bar", window_size="1mo", quantile=0.75),
],
)
def test_expr_serde_roundtrip_json(expr: pl.Expr) -> None:
expr = pl.col("foo").sum().over("bar")
json = expr.meta.serialize(format="json")
round_tripped = pl.Expr.deserialize(io.StringIO(json), format="json")
Expand Down

0 comments on commit 71eacbb

Please sign in to comment.