Skip to content

Commit

Permalink
refactor: Replace DynArgs with an enum containing all its variants
Browse files Browse the repository at this point in the history
  • Loading branch information
3ok committed Sep 15, 2024
1 parent 962b576 commit 4ada469
Show file tree
Hide file tree
Showing 22 changed files with 150 additions and 120 deletions.
18 changes: 13 additions & 5 deletions crates/polars-arrow/src/legacy/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ pub mod nulls;
pub mod quantile_filter;
mod window;

use std::any::Any;
use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
use std::sync::Arc;

use num_traits::{Bounded, Float, NumCast, One, Zero};
use polars_utils::float::IsFloat;
use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use window::*;

use crate::array::{ArrayRef, PrimitiveArray};
Expand All @@ -23,7 +23,13 @@ type End = usize;
type Idx = usize;
type WindowSize = usize;
type Len = usize;
pub type DynArgs = Option<Arc<dyn Any + Sync + Send>>;

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RollingFnParams {
Quantile(RollingQuantileParams),
Var(RollingVarParams),
}

fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
(i.saturating_sub(window_size - 1), i + 1)
Expand Down Expand Up @@ -77,12 +83,14 @@ where
}

// Parameters allowed for rolling operations.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingVarParams {
pub ddof: u8,
}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingQuantileParams {
pub prob: f64,
pub interpol: QuantileInterpolOptions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ impl<
T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Div<Output = T> + NumCast,
> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, params: DynArgs) -> Self {
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
Self {
sum: SumWindow::new(slice, start, end, params),
}
Expand All @@ -29,7 +29,7 @@ pub fn rolling_mean<T>(
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
T: NativeType + Float + std::iter::Sum<T> + SubAssign + AddAssign + IsFloat,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,12 @@ macro_rules! minmax_window {
impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T>
for $m_window<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, _params: DynArgs) -> Self {
fn new(
slice: &'a [T],
start: usize,
end: usize,
_params: Option<RollingFnParams>,
) -> Self {
let (idx, m) =
unsafe { $get_m_and_idx(slice, start, end, 0).unwrap_or((0, &slice[start])) };
Self {
Expand Down Expand Up @@ -238,7 +243,7 @@ macro_rules! rolling_minmax_func {
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::legacy::error::PolarsResult;
use crate::types::NativeType;

pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
fn new(slice: &'a [T], start: usize, end: usize, params: DynArgs) -> Self;
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self;

/// Update and recompute the window
///
Expand All @@ -36,7 +36,7 @@ pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
window_size: usize,
min_periods: usize,
det_offsets_fn: Fo,
params: DynArgs,
params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
Expand Down
30 changes: 19 additions & 11 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@ impl<
+ Sub<Output = T>,
> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, params: DynArgs) -> Self {
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
let params = params.unwrap();
let params = params.downcast_ref::<RollingQuantileParams>().unwrap();
let RollingFnParams::Quantile(params) = params else {
unreachable!("expected Quantile params");
};

Self {
sorted: SortedBuf::new(slice, start, end),
prob: params.prob,
Expand Down Expand Up @@ -103,7 +106,7 @@ pub fn rolling_quantile<T>(
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
params: DynArgs,
params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
T: NativeType
Expand All @@ -127,7 +130,9 @@ where
None => {
if !center {
let params = params.as_ref().unwrap();
let params = params.downcast_ref::<RollingQuantileParams>().unwrap();
let RollingFnParams::Quantile(params) = params else {
unreachable!("expected Quantile params");
};
let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
params.interpol,
min_periods,
Expand Down Expand Up @@ -158,7 +163,10 @@ where
ComputeError: "Weighted quantile is undefined if weights sum to 0"
);
let params = params.unwrap();
let params = params.downcast_ref::<RollingQuantileParams>().unwrap();
let RollingFnParams::Quantile(params) = params else {
unreachable!("expected Quantile params");
};

Ok(rolling_apply_weighted_quantile(
values,
params.prob,
Expand Down Expand Up @@ -263,10 +271,10 @@ mod test {
#[test]
fn test_rolling_median() {
let values = &[1.0, 2.0, 3.0, 4.0];
let med_pars = Some(Arc::new(RollingQuantileParams {
let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.5,
interpol: Linear,
}) as Arc<dyn Any + Send + Sync>);
}));
let out = rolling_quantile(values, 2, 2, false, None, med_pars.clone()).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
Expand Down Expand Up @@ -306,10 +314,10 @@ mod test {
];

for interpol in interpol_options {
let min_pars = Some(Arc::new(RollingQuantileParams {
let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.0,
interpol,
}) as Arc<dyn Any + Send + Sync>);
}));
let out1 = rolling_min(values, 2, 2, false, None, None).unwrap();
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
Expand All @@ -318,10 +326,10 @@ mod test {
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);

let max_pars = Some(Arc::new(RollingQuantileParams {
let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 1.0,
interpol,
}) as Arc<dyn Any + Send + Sync>);
}));
let out1 = rolling_max(values, 2, 2, false, None, None).unwrap();
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct SumWindow<'a, T> {
impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, _params: DynArgs) -> Self {
fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
let sum = slice[start..end].iter().copied().sum::<T>();
Self {
slice,
Expand Down Expand Up @@ -70,7 +70,7 @@ pub fn rolling_sum<T>(
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
T: NativeType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub(super) struct SumSquaredWindow<'a, T> {
impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<Output = T>>
RollingAggWindowNoNulls<'a, T> for SumSquaredWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, _params: DynArgs) -> Self {
fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
let sum = slice[start..end].iter().map(|v| *v * *v).sum::<T>();
Self {
slice,
Expand Down Expand Up @@ -97,13 +97,18 @@ impl<
+ Sub<Output = T>,
> RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, params: DynArgs) -> Self {
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
Self {
mean: MeanWindow::new(slice, start, end, None),
sum_of_squares: SumSquaredWindow::new(slice, start, end, None),
ddof: match params {
None => 1,
Some(pars) => pars.downcast_ref::<RollingVarParams>().unwrap().ddof,
Some(pars) => {
let RollingFnParams::Var(pars) = pars else {
unreachable!("expected Var params");
};
pars.ddof
},
},
}
}
Expand Down Expand Up @@ -137,7 +142,7 @@ pub fn rolling_var<T>(
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
params: DynArgs,
params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
T: NativeType
Expand Down Expand Up @@ -199,7 +204,7 @@ mod test {
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]);

let testpars = Some(Arc::new(RollingVarParams { ddof: 0 }) as Arc<dyn Any + Send + Sync>);
let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
let out = rolling_var(values, 2, 2, false, None, testpars).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ impl<
validity: &'a Bitmap,
start: usize,
end: usize,
params: DynArgs,
params: Option<RollingFnParams>,
) -> Self {
Self {
sum: SumWindow::new(slice, validity, start, end, params),
Expand All @@ -36,7 +36,7 @@ pub fn rolling_mean<T>(
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> ArrayRef
where
T: NativeType
Expand Down
10 changes: 5 additions & 5 deletions crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<'a, T: NativeType> RollingAggWindowNulls<'a, T> for SortedMinMax<'a, T> {
validity: &'a Bitmap,
start: usize,
end: usize,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> Self {
let mut out = Self {
slice,
Expand Down Expand Up @@ -258,7 +258,7 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for
validity: &'a Bitmap,
start: usize,
end: usize,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> Self {
Self {
inner: MinMaxWindow::new(
Expand Down Expand Up @@ -287,7 +287,7 @@ pub fn rolling_min<T>(
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> ArrayRef
where
T: NativeType + std::iter::Sum + Zero + AddAssign + Copy + PartialOrd + Bounded + IsFloat,
Expand Down Expand Up @@ -326,7 +326,7 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for
validity: &'a Bitmap,
start: usize,
end: usize,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> Self {
Self {
inner: MinMaxWindow::new(
Expand Down Expand Up @@ -355,7 +355,7 @@ pub fn rolling_max<T>(
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
_params: DynArgs,
_params: Option<RollingFnParams>,
) -> ArrayRef
where
T: NativeType + std::iter::Sum + Zero + AddAssign + Copy + PartialOrd + Bounded + IsFloat,
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub trait RollingAggWindowNulls<'a, T: NativeType> {
validity: &'a Bitmap,
start: usize,
end: usize,
params: DynArgs,
params: Option<RollingFnParams>,
) -> Self;

/// # Safety
Expand All @@ -37,7 +37,7 @@ pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
window_size: usize,
min_periods: usize,
det_offsets_fn: Fo,
params: DynArgs,
params: Option<RollingFnParams>,
) -> ArrayRef
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,
Expand Down Expand Up @@ -175,7 +175,7 @@ mod test {

assert_eq!(out, &[0.0, 0.0, 2.0, 12.5]);

let testpars = Some(Arc::new(RollingVarParams { ddof: 0 }) as Arc<dyn Any + Send + Sync>);
let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
let out = rolling_var(arr, 3, 1, false, None, testpars.clone());
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out
Expand Down
Loading

0 comments on commit 4ada469

Please sign in to comment.