Skip to content

Commit

Permalink
optimize quantile
Browse files Browse the repository at this point in the history
  • Loading branch information
GreyRaphael committed Nov 2, 2024
1 parent 89c25f4 commit b7c0e37
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 76 deletions.
19 changes: 0 additions & 19 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ edition = "2021"
name = "ta"
crate-type = ["cdylib"]

[dependencies]
ordered-float = "4"

[dependencies.pyo3]
version = "0.22"
# "abi3-py38" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.8
Expand Down
8 changes: 4 additions & 4 deletions examples/rolling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
miner = rolling.Miner(3)
deltaer = rolling.Deltaer(3)
pctchanger = rolling.Pctchanger(3)
quantiler = rolling.Quantiler(11, 0.3)
quantiler = rolling.Quantiler(10, 0.3)

# rollingers = [sumer, meaner, maxer, miner, deltaer, pctchanger, quantiler]
rollingers = [deltaer, pctchanger]
rollingers = [sumer, meaner, maxer, miner, deltaer, pctchanger, quantiler]
# rollingers = [deltaer, pctchanger]

for rollinger in rollingers:
for i in range(20):
# print(i, rollinger.update(i), rollinger.get(0), rolling.get(2))
print(i, rollinger.update(i), rollinger.get(0), rollinger.get(2))
print(i, rollinger.update(i))
print("-" * 20)
71 changes: 21 additions & 50 deletions src/rolling/quantile.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
use ordered_float::OrderedFloat;
use pyo3::prelude::*;
use std::collections::BTreeMap;
use std::f64::NAN;
use std::ops::Bound::{Excluded, Unbounded};

use crate::utils::is_nan_or_inf;

#[pyclass]
pub struct Quantiler {
dataset: BTreeMap<OrderedFloat<f64>, usize>,
dataset: Vec<f64>,
buf: Vec<f64>,
n: usize,
cur_idx: usize,
nan_count: usize,
quantile: f64,
Expand All @@ -21,9 +17,8 @@ impl Quantiler {
#[new]
pub fn new(n: usize, quantile: f64) -> Self {
Self {
dataset: BTreeMap::new(),
dataset: Vec::new(),
buf: vec![NAN; n],
n,
cur_idx: 0,
nan_count: n,
quantile,
Expand All @@ -33,69 +28,45 @@ impl Quantiler {
pub fn update(&mut self, new_val: f64) -> f64 {
let old_val = self.buf[self.cur_idx];
self.buf[self.cur_idx] = new_val;
self.cur_idx = (self.cur_idx + 1) % self.n;
self.cur_idx = (self.cur_idx + 1) % self.buf.len();

// Update nan_count and dataset based on new_val
if is_nan_or_inf(new_val) {
self.nan_count += 1;
} else {
let ordered_val = OrderedFloat(new_val);
*self.dataset.entry(ordered_val).or_insert(0) += 1;
let pos = self
.dataset
.binary_search_by(|v| v.partial_cmp(&new_val).unwrap())
.unwrap_or_else(|e| e);
self.dataset.insert(pos, new_val);
}

// Update nan_count and dataset based on old_val
if is_nan_or_inf(old_val) {
self.nan_count -= 1;
} else {
let ordered_old_val = OrderedFloat(old_val);
if let Some(count) = self.dataset.get_mut(&ordered_old_val) {
*count -= 1;
if *count == 0 {
self.dataset.remove(&ordered_old_val);
}
}
let pos = self
.dataset
.binary_search_by(|v| v.partial_cmp(&old_val).unwrap())
.unwrap();
self.dataset.remove(pos);
}

if self.nan_count > 0 {
NAN
} else {
let size: usize = self.dataset.iter().map(|(_, &count)| count).sum();

let index = (size as f64 - 1.0) * self.quantile;
let index = (self.dataset.len() - 1) as f64 * self.quantile;
let lower_index = index.floor() as usize;
let fraction = index - lower_index as f64;

// Iterate through the sorted dataset to find the lower and upper values
let mut cumulative = 0;
let mut lower_value = NAN;
let mut upper_value = NAN;

for (&key, &count) in &self.dataset {
if cumulative + count > lower_index {
lower_value = key.0;
if fraction == 0.0 {
upper_value = key.0;
break;
} else {
// Find the next value
if let Some((&next_key, _)) =
self.dataset.range((Excluded(&key), Unbounded)).next()
{
upper_value = next_key.0;
} else {
upper_value = key.0;
}
break;
}
}
cumulative += count;
}

if lower_value.is_nan() || upper_value.is_nan() {
NAN
let lower_value = self.dataset[lower_index];
let upper_value = if lower_index + 1 < self.dataset.len() {
self.dataset[lower_index + 1]
} else {
lower_value + fraction * (upper_value - lower_value)
}
lower_value
};

lower_value + fraction * (upper_value - lower_value)
}
}
}

0 comments on commit b7c0e37

Please sign in to comment.