From 89dd12d6b51c87e1acfb859ec27e13c99a448e28 Mon Sep 17 00:00:00 2001 From: ritchie Date: Fri, 28 Feb 2025 13:55:29 +0100 Subject: [PATCH 1/2] feat: Improve numeric stability rolling_{std, var, cov, corr} --- .../src/rolling/no_nulls/variance.rs | 162 ++++++------------ .../src/rolling/nulls/variance.rs | 157 ++++++----------- crates/polars-compute/src/var_cov.rs | 8 +- .../rolling_kernels/no_nulls.rs | 3 +- .../unit/operations/rolling/test_rolling.py | 2 +- 5 files changed, 111 insertions(+), 221 deletions(-) diff --git a/crates/polars-compute/src/rolling/no_nulls/variance.rs b/crates/polars-compute/src/rolling/no_nulls/variance.rs index 5513a72758cb..61b1f7f069df 100644 --- a/crates/polars-compute/src/rolling/no_nulls/variance.rs +++ b/crates/polars-compute/src/rolling/no_nulls/variance.rs @@ -1,138 +1,88 @@ +use num_traits::{FromPrimitive, ToPrimitive}; use polars_error::polars_ensure; use super::*; +use crate::var_cov::VarState; -pub(super) struct SumSquaredWindow<'a, T> { +pub struct VarWindow<'a, T> { slice: &'a [T], - sum_of_squares: T, + var: VarState, + ddof: u8, last_start: usize, last_end: usize, - // if we don't recompute every 'n' iterations - // we get a accumulated error/drift - last_recompute: u8, } -impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul> - RollingAggWindowNoNulls<'a, T> for SumSquaredWindow<'a, T> +impl VarWindow<'_, T> { + fn compute_var(&mut self, start: usize, end: usize) { + self.var = VarState::default(); + for value in &self.slice[start..end] { + let value: f64 = NumCast::from(*value).unwrap(); + self.var.insert_one(value); + } + } +} + +impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive> + RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T> { - fn new(slice: &'a [T], start: usize, end: usize, _params: Option) -> Self { - let sum = slice[start..end].iter().map(|v| *v * *v).sum::(); - Self { + fn new(slice: &'a [T], start: usize, end: usize, params: Option) -> Self { + let mut out = Self { slice, - sum_of_squares: sum, + var: VarState::default(), last_start: start, last_end: end, - last_recompute: 0, - } + ddof: match params { + None => 1, + Some(pars) => { + let RollingFnParams::Var(pars) = pars else { + unreachable!("expected Var params"); + }; + pars.ddof + }, + }, + }; + out.compute_var(start, end); + out } unsafe fn update(&mut self, start: usize, end: usize) -> Option { - // if we exceed the end, we have a completely new window - // so we recompute - let recompute_sum = if start >= self.last_end || self.last_recompute > 128 { - self.last_recompute = 0; + let recompute_var = if start >= self.last_end { true } else { - self.last_recompute += 1; // remove elements that should leave the window - let mut recompute_sum = false; + let mut recompute_var = false; for idx in self.last_start..start { - // SAFETY: - // we are in bounds - let leaving_value = self.slice.get_unchecked(idx); + // SAFETY: we are in bounds + let leaving_value = *self.slice.get_unchecked(idx); + // if the leaving value is nan we need to recompute the window if T::is_float() && !leaving_value.is_finite() { - recompute_sum = true; + recompute_var = true; break; } - - self.sum_of_squares -= *leaving_value * *leaving_value; + let leaving_value: f64 = NumCast::from(leaving_value).unwrap(); + self.var.remove_one(leaving_value); } - recompute_sum + recompute_var }; self.last_start = start; // we traverse all values and compute - if T::is_float() && recompute_sum { - self.sum_of_squares = self - .slice - .get_unchecked(start..end) - .iter() - .map(|v| *v * *v) - .sum::(); + if recompute_var { + self.compute_var(start, end); } else { for idx in self.last_end..end { let entering_value = *self.slice.get_unchecked(idx); - self.sum_of_squares += entering_value * entering_value; - } - } - self.last_end = end; - Some(self.sum_of_squares) - } -} - -// E[(xi - E[x])^2] -// can be expanded to -// E[x^2] - E[x]^2 -pub struct VarWindow<'a, T> { - mean: MeanWindow<'a, T>, - sum_of_squares: SumSquaredWindow<'a, T>, - ddof: u8, -} + let entering_value: f64 = NumCast::from(entering_value).unwrap(); -impl< - 'a, - T: NativeType - + IsFloat - + Float - + std::iter::Sum - + AddAssign - + SubAssign - + Div - + NumCast - + One - + Zero - + PartialOrd - + Sub, - > RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T> -{ - fn new(slice: &'a [T], start: usize, end: usize, params: Option) -> Self { - Self { - mean: MeanWindow::new(slice, start, end, None), - sum_of_squares: SumSquaredWindow::new(slice, start, end, None), - ddof: match params { - None => 1, - Some(pars) => { - let RollingFnParams::Var(pars) = pars else { - unreachable!("expected Var params"); - }; - pars.ddof - }, - }, - } - } - - unsafe fn update(&mut self, start: usize, end: usize) -> Option { - let count: T = NumCast::from(end - start).unwrap(); - let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked(); - let mean = self.mean.update(start, end).unwrap_unchecked(); - - let denom = count - NumCast::from(self.ddof).unwrap(); - if denom <= T::zero() { - None - } else if end - start == 1 { - Some(T::zero()) - } else { - let out = (sum_of_squares - count * mean * mean) / denom; - // variance cannot be negative. - // if it is negative it is due to numeric instability - if out < T::zero() { - Some(T::zero()) - } else { - Some(out) + self.var.insert_one(entering_value); } } + self.last_end = end; + self.var + .finalize(self.ddof) + .map(|v| T::from_f64(v).unwrap()) } } @@ -145,17 +95,7 @@ pub fn rolling_var( params: Option, ) -> PolarsResult where - T: NativeType - + Float - + IsFloat - + std::iter::Sum - + AddAssign - + SubAssign - + Div - + NumCast - + One - + Zero - + Sub, + T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign, { let offset_fn = match center { true => det_offsets_center, @@ -231,7 +171,7 @@ mod test { &[ None, None, - Some(52.333333333333336), + Some(52.33333333333333), Some(f64::nan()), Some(f64::nan()), Some(f64::nan()), diff --git a/crates/polars-compute/src/rolling/nulls/variance.rs b/crates/polars-compute/src/rolling/nulls/variance.rs index 5d303eab982c..f3ba0a89a087 100644 --- a/crates/polars-compute/src/rolling/nulls/variance.rs +++ b/crates/polars-compute/src/rolling/nulls/variance.rs @@ -1,67 +1,76 @@ +use num_traits::{FromPrimitive, ToPrimitive}; + use super::*; +use crate::var_cov::VarState; -pub(super) struct SumSquaredWindow<'a, T> { +pub struct VarWindow<'a, T> { slice: &'a [T], validity: &'a Bitmap, - sum_of_squares: Option, + var: Option, last_start: usize, last_end: usize, null_count: usize, + ddof: u8, } -impl + Sub + Mul> - SumSquaredWindow<'_, T> -{ +impl VarWindow<'_, T> { // compute sum from the entire window - unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { - let mut sum_of_squares = None; + unsafe fn compute_var_and_null_count(&mut self, start: usize, end: usize) { + let mut var = None; let mut idx = start; self.null_count = 0; for value in &self.slice[start..end] { let valid = self.validity.get_bit_unchecked(idx); if valid { - match sum_of_squares { - None => sum_of_squares = Some(*value * *value), - Some(current) => sum_of_squares = Some(*value * *value + current), + let value: f64 = NumCast::from(*value).unwrap(); + match &mut var { + None => var = Some(VarState::new_single(value)), + Some(current) => current.insert_one(value), } } else { self.null_count += 1; } idx += 1; } - self.sum_of_squares = sum_of_squares; - sum_of_squares + self.var = var; } } -impl<'a, T: NativeType + IsFloat + Add + Sub + Mul> - RollingAggWindowNulls<'a, T> for SumSquaredWindow<'a, T> +impl<'a, T: NativeType + ToPrimitive + IsFloat + FromPrimitive> RollingAggWindowNulls<'a, T> + for VarWindow<'a, T> { unsafe fn new( slice: &'a [T], validity: &'a Bitmap, start: usize, end: usize, - _params: Option, + params: Option, ) -> Self { + let ddof = if let Some(RollingFnParams::Var(params)) = params { + params.ddof + } else { + 1 + }; + let mut out = Self { slice, validity, - sum_of_squares: None, + var: None, last_start: start, last_end: end, null_count: 0, + ddof, }; - out.compute_sum_and_null_count(start, end); + out.compute_var_and_null_count(start, end); out } unsafe fn update(&mut self, start: usize, end: usize) -> Option { - let recompute_sum = if start >= self.last_end { + let recompute_var = if start >= self.last_end { true } else { // remove elements that should leave the window - let mut recompute_sum = false; + let mut recompute_var = false; for idx in self.last_start..start { // SAFETY: // we are in bounds @@ -71,42 +80,44 @@ impl<'a, T: NativeType + IsFloat + Add + Sub + Mul self.sum_of_squares = Some(value), - Some(current) => self.sum_of_squares = Some(current + value), + let entering_value = *self.slice.get_unchecked(idx); + let entering_value: f64 = NumCast::from(entering_value).unwrap(); + + match &mut self.var { + None => self.var = Some(VarState::new_single(entering_value)), + Some(current) => current.insert_one(entering_value), } } else { // null value entering the window @@ -115,82 +126,14 @@ impl<'a, T: NativeType + IsFloat + Add + Sub + Mul bool { - ((self.last_end - self.last_start) - self.null_count) >= min_periods - } -} -// E[(xi - E[x])^2] -// can be expanded to -// E[x^2] - E[x]^2 -pub struct VarWindow<'a, T> { - mean: MeanWindow<'a, T>, - sum_of_squares: SumSquaredWindow<'a, T>, - ddof: u8, -} - -impl< - 'a, - T: NativeType - + IsFloat - + Float - + std::iter::Sum - + AddAssign - + SubAssign - + Div - + NumCast - + One - + Zero - + PartialOrd - + Add - + Sub, - > RollingAggWindowNulls<'a, T> for VarWindow<'a, T> -{ - unsafe fn new( - slice: &'a [T], - validity: &'a Bitmap, - start: usize, - end: usize, - params: Option, - ) -> Self { - Self { - mean: MeanWindow::new(slice, validity, start, end, None), - sum_of_squares: SumSquaredWindow::new(slice, validity, start, end, None), - ddof: match params { - None => 1, - Some(pars) => match pars { - RollingFnParams::Var(p) => p.ddof, - _ => unreachable!("expected Var params"), - }, - }, - } - } - - unsafe fn update(&mut self, start: usize, end: usize) -> Option { - let sum_of_squares = self.sum_of_squares.update(start, end)?; - let null_count = self.sum_of_squares.null_count; - let count: T = NumCast::from(end - start - null_count).unwrap(); - - let mean = self.mean.update(start, end)?; - let ddof = NumCast::from(self.ddof).unwrap(); - - let denom = count - ddof; - - if denom <= T::zero() { - None - } else if count == T::one() { - Some(T::zero()) - } else if denom <= T::zero() { - Some(T::infinity()) - } else { - let var = (sum_of_squares - count * mean * mean) / denom; - Some(if var < T::zero() { T::zero() } else { var }) - } - } fn is_valid(&self, min_periods: usize) -> bool { - self.mean.is_valid(min_periods) + ((self.last_end - self.last_start) - self.null_count) >= min_periods } } @@ -203,7 +146,7 @@ pub fn rolling_var( params: Option, ) -> ArrayRef where - T: NativeType + std::iter::Sum + Zero + AddAssign + SubAssign + IsFloat + Float, + T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float, { if weights.is_some() { panic!("weights not yet supported on array with null values") diff --git a/crates/polars-compute/src/var_cov.rs b/crates/polars-compute/src/var_cov.rs index deb9c0f8d853..0a47230bc521 100644 --- a/crates/polars-compute/src/var_cov.rs +++ b/crates/polars-compute/src/var_cov.rs @@ -47,7 +47,7 @@ pub struct PearsonState { } impl VarState { - fn new(x: &[f64]) -> Self { + pub fn new(x: &[f64]) -> Self { if x.is_empty() { return Self::default(); } @@ -61,6 +61,12 @@ impl VarState { } } + pub(crate) fn new_single(x: f64) -> Self { + let mut out = Self::default(); + out.insert_one(x); + out + } + pub fn insert_one(&mut self, x: f64) { // Just a specialized version of // self.combine(&Self { weight: 1.0, mean: x, dp: 0.0 }) diff --git a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs index ab893f151edd..95ecf93ee2be 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs @@ -4,6 +4,7 @@ use arrow::bitmap::MutableBitmap; use bytemuck::allocation::zeroed_vec; #[cfg(feature = "timezones")] use chrono_tz::Tz; +use num_traits::{FromPrimitive, ToPrimitive}; use polars_compute::rolling::no_nulls::{self, RollingAggWindowNoNulls}; use polars_compute::rolling::quantile_filter::SealedRolling; use polars_compute::rolling::RollingFnParams; @@ -312,7 +313,7 @@ pub(crate) fn rolling_var( sorting_indices: Option<&[IdxSize]>, ) -> PolarsResult where - T: NativeType + Float + std::iter::Sum + SubAssign + AddAssign + IsFloat, + T: NativeType + Float + ToPrimitive + FromPrimitive + AddAssign + IsFloat, { let offset_iter = match tz { #[cfg(feature = "timezones")] diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 5898278cf804..b46c3ab7d5dd 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -587,7 +587,7 @@ def test_rolling_cov_corr() -> None: pl.rolling_corr("x", "y", window_size=3).alias("corr"), ).to_dict(as_series=False) assert res["cov"][2:] == pytest.approx([0.0, 0.0, 5.333333333333336]) - assert res["corr"][2:] == pytest.approx([nan, nan, 0.9176629354822473], nan_ok=True) + assert res["corr"][2:] == pytest.approx([nan, 0.0, 0.9176629354822473], nan_ok=True) assert res["cov"][:2] == [None] * 2 assert res["corr"][:2] == [None] * 2 From 1d1e81625b69437cc3e64aa6037bc774b7137f19 Mon Sep 17 00:00:00 2001 From: ritchie Date: Fri, 28 Feb 2025 13:57:44 +0100 Subject: [PATCH 2/2] pub --- crates/polars-compute/src/var_cov.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-compute/src/var_cov.rs b/crates/polars-compute/src/var_cov.rs index 0a47230bc521..ac950963f923 100644 --- a/crates/polars-compute/src/var_cov.rs +++ b/crates/polars-compute/src/var_cov.rs @@ -47,7 +47,7 @@ pub struct PearsonState { } impl VarState { - pub fn new(x: &[f64]) -> Self { + fn new(x: &[f64]) -> Self { if x.is_empty() { return Self::default(); }