Skip to content

Commit

Permalink
feat: Improve numeric stability rolling_{std, var, cov, corr}
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 28, 2025
1 parent 69612d4 commit 89dd12d
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 221 deletions.
162 changes: 51 additions & 111 deletions crates/polars-compute/src/rolling/no_nulls/variance.rs
Original file line number Diff line number Diff line change
@@ -1,138 +1,88 @@
use num_traits::{FromPrimitive, ToPrimitive};
use polars_error::polars_ensure;

use super::*;
use crate::var_cov::VarState;

pub(super) struct SumSquaredWindow<'a, T> {
pub struct VarWindow<'a, T> {
slice: &'a [T],
sum_of_squares: T,
var: VarState,
ddof: u8,
last_start: usize,
last_end: usize,
// if we don't recompute every 'n' iterations
// we get an accumulated error/drift
last_recompute: u8,
}

impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<Output = T>>
RollingAggWindowNoNulls<'a, T> for SumSquaredWindow<'a, T>
impl<T: ToPrimitive + Copy> VarWindow<'_, T> {
    /// Rebuilds `self.var` from scratch over `self.slice[start..end]`.
    ///
    /// The state is first reset to `VarState::default()`; every value in the
    /// range is then cast to `f64` and inserted one at a time.
    ///
    /// # Panics
    /// Panics if `start..end` is not a valid range of `self.slice`, or if a
    /// value cannot be cast to `f64`.
    fn compute_var(&mut self, start: usize, end: usize) {
        self.var = VarState::default();
        let window = &self.slice[start..end];
        for item in window.iter().copied() {
            let as_float: f64 = NumCast::from(item).unwrap();
            self.var.insert_one(as_float);
        }
    }
}

impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive>
RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, _params: Option<RollingFnParams>) -> Self {
let sum = slice[start..end].iter().map(|v| *v * *v).sum::<T>();
Self {
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
let mut out = Self {
slice,
sum_of_squares: sum,
var: VarState::default(),
last_start: start,
last_end: end,
last_recompute: 0,
}
ddof: match params {
None => 1,
Some(pars) => {
let RollingFnParams::Var(pars) = pars else {
unreachable!("expected Var params");
};
pars.ddof
},
},
};
out.compute_var(start, end);
out
}

unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_sum = if start >= self.last_end || self.last_recompute > 128 {
self.last_recompute = 0;
let recompute_var = if start >= self.last_end {
true
} else {
self.last_recompute += 1;
// remove elements that should leave the window
let mut recompute_sum = false;
let mut recompute_var = false;
for idx in self.last_start..start {
// SAFETY:
// we are in bounds
let leaving_value = self.slice.get_unchecked(idx);
// SAFETY: we are in bounds
let leaving_value = *self.slice.get_unchecked(idx);

// if the leaving value is not finite (NaN or ±inf) we need to recompute the window
if T::is_float() && !leaving_value.is_finite() {
recompute_sum = true;
recompute_var = true;
break;
}

self.sum_of_squares -= *leaving_value * *leaving_value;
let leaving_value: f64 = NumCast::from(leaving_value).unwrap();
self.var.remove_one(leaving_value);
}
recompute_sum
recompute_var
};

self.last_start = start;

// we traverse all values and compute
if T::is_float() && recompute_sum {
self.sum_of_squares = self
.slice
.get_unchecked(start..end)
.iter()
.map(|v| *v * *v)
.sum::<T>();
if recompute_var {
self.compute_var(start, end);
} else {
for idx in self.last_end..end {
let entering_value = *self.slice.get_unchecked(idx);
self.sum_of_squares += entering_value * entering_value;
}
}
self.last_end = end;
Some(self.sum_of_squares)
}
}

// E[(xi - E[x])^2]
// can be expanded to
// E[x^2] - E[x]^2
pub struct VarWindow<'a, T> {
mean: MeanWindow<'a, T>,
sum_of_squares: SumSquaredWindow<'a, T>,
ddof: u8,
}
let entering_value: f64 = NumCast::from(entering_value).unwrap();

impl<
'a,
T: NativeType
+ IsFloat
+ Float
+ std::iter::Sum
+ AddAssign
+ SubAssign
+ Div<Output = T>
+ NumCast
+ One
+ Zero
+ PartialOrd
+ Sub<Output = T>,
> RollingAggWindowNoNulls<'a, T> for VarWindow<'a, T>
{
fn new(slice: &'a [T], start: usize, end: usize, params: Option<RollingFnParams>) -> Self {
Self {
mean: MeanWindow::new(slice, start, end, None),
sum_of_squares: SumSquaredWindow::new(slice, start, end, None),
ddof: match params {
None => 1,
Some(pars) => {
let RollingFnParams::Var(pars) = pars else {
unreachable!("expected Var params");
};
pars.ddof
},
},
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let count: T = NumCast::from(end - start).unwrap();
let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked();
let mean = self.mean.update(start, end).unwrap_unchecked();

let denom = count - NumCast::from(self.ddof).unwrap();
if denom <= T::zero() {
None
} else if end - start == 1 {
Some(T::zero())
} else {
let out = (sum_of_squares - count * mean * mean) / denom;
// variance cannot be negative.
// if it is negative it is due to numeric instability
if out < T::zero() {
Some(T::zero())
} else {
Some(out)
self.var.insert_one(entering_value);
}
}
self.last_end = end;
self.var
.finalize(self.ddof)
.map(|v| T::from_f64(v).unwrap())
}
}

Expand All @@ -145,17 +95,7 @@ pub fn rolling_var<T>(
params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
T: NativeType
+ Float
+ IsFloat
+ std::iter::Sum
+ AddAssign
+ SubAssign
+ Div<Output = T>
+ NumCast
+ One
+ Zero
+ Sub<Output = T>,
T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
{
let offset_fn = match center {
true => det_offsets_center,
Expand Down Expand Up @@ -231,7 +171,7 @@ mod test {
&[
None,
None,
Some(52.333333333333336),
Some(52.33333333333333),
Some(f64::nan()),
Some(f64::nan()),
Some(f64::nan()),
Expand Down
Loading

0 comments on commit 89dd12d

Please sign in to comment.