Skip to content

Commit

Permalink
perf: Pre-fill caches (#21646)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Mar 7, 2025
1 parent 36c49ed commit 251fab5
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 176 deletions.
25 changes: 18 additions & 7 deletions crates/polars-expr/src/state/execution_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ type CachedValue = Arc<(AtomicI64, OnceCell<DataFrame>)>;
/// State/cache that is maintained during the execution of the physical plan.
pub struct ExecutionState {
// cached by a `.cache` call and kept in memory for the duration of the plan.
df_cache: Arc<Mutex<PlHashMap<usize, CachedValue>>>,
df_cache: Arc<RwLock<PlHashMap<usize, CachedValue>>>,
pub schema_cache: RwLock<Option<SchemaRef>>,
/// Used by Window Expressions to cache intermediate state
pub window_cache: Arc<WindowCache>,
Expand Down Expand Up @@ -218,15 +218,26 @@ impl ExecutionState {
}

/// Returns the cache slot registered under `key`, creating it when absent.
///
/// A newly created slot is an `Arc` holding an `AtomicI64` initialised to
/// `cache_hits` and an empty `OnceCell`. The returned value is a clone of the
/// `Arc` stored in the map. Panics if acquiring the lock fails (`unwrap`).
// NOTE(review): this block is rendered from a diff without +/- markers. The
// `.lock()` body directly below looks like the pre-change (Mutex) version and
// the `.read()`/`.write()` body after it like the post-change (RwLock)
// version — confirm against the repository.
pub fn get_df_cache(&self, key: usize, cache_hits: u32) -> CachedValue {
let mut guard = self.df_cache.lock().unwrap();
guard
.entry(key)
.or_insert_with(|| Arc::new((AtomicI64::new(cache_hits as i64), OnceCell::new())))
.clone()
// Look the key up under the read guard first.
let guard = self.df_cache.read().unwrap();

match guard.get(&key) {
Some(v) => v.clone(),
None => {
// Release the read guard before requesting the write guard.
drop(guard);
let mut guard = self.df_cache.write().unwrap();

// `entry` checks the key again, so a slot inserted by another caller
// between the two guards is reused instead of being replaced.
guard
.entry(key)
.or_insert_with(|| {
Arc::new((AtomicI64::new(cache_hits as i64), OnceCell::new()))
})
.clone()
},
}
}

/// Removes the cache slot stored under `key`.
///
/// Both `unwrap` calls panic on failure: the first when acquiring the lock
/// fails, the second on the result of the removal (presumably when no slot
/// exists for `key` — confirm against the map type).
// NOTE(review): rendered from a diff without +/- markers. The `.lock()` line
// looks like the pre-change version and the `.write()` line like its
// replacement — confirm against the repository.
pub fn remove_df_cache(&self, key: usize) {
let mut guard = self.df_cache.lock().unwrap();
let mut guard = self.df_cache.write().unwrap();
let _ = guard.remove(&key).unwrap();
}

Expand Down
64 changes: 44 additions & 20 deletions crates/polars-mem-engine/src/executors/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,60 @@ use std::sync::atomic::Ordering;
use super::*;

/// Executor for a cached sub-plan whose cache slot is identified by `id`.
// NOTE(review): rendered from a diff without +/- markers. The two `input`
// lines look like the old and new declaration of the same field — confirm
// against the repository.
pub struct CacheExec {
pub input: Box<dyn Executor>,
/// `Some(executor)` marks a "Cache Prefill node": `execute` runs the executor
/// and stores its result in the cache slot. `None` marks a "Cache node":
/// `execute` reads the already stored frame.
pub input: Option<Box<dyn Executor>>,
/// Key passed to `get_df_cache` / `remove_df_cache` on the execution state.
pub id: usize,
/// Passed to `get_df_cache` as `cache_hits`, i.e. the initial value of the
/// slot's counter when the slot is first created.
pub count: u32,
}

// NOTE(review): this impl is rendered from a diff without +/- markers. The
// lines that use `cache_hit` / `get_or_try_init` look like the pre-change body,
// and the `match &mut self.input` arms like the post-change body — confirm
// against the repository.
impl Executor for CacheExec {
/// Either fills the cache slot for `self.id` (when `input` is `Some`) or
/// returns a clone of the frame already stored there (when `input` is `None`).
///
/// The `None` arm panics with "prefilled" if the slot is still empty. The
/// `Some` arm panics with "should be empty" if the slot was already set.
/// Errors from executing the input are propagated with `?`.
fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult<DataFrame> {
let cache = state.get_df_cache(self.id, self.count);
let mut cache_hit = true;
let previous = cache.0.fetch_sub(1, Ordering::Relaxed);
debug_assert!(previous >= 0);
match &mut self.input {
// Cache node
None => {
if state.verbose() {
eprintln!("CACHE HIT: cache id: {:x}", self.id);
}
let cache = state.get_df_cache(self.id, self.count);
// Clone the stored frame out of the cell before touching the counter.
let out = cache.1.get().expect("prefilled").clone();
// `fetch_sub` returns the value held before the decrement; the slot is
// removed from the state when that value was 0.
let previous = cache.0.fetch_sub(1, Ordering::Relaxed);
if previous == 0 {
state.remove_df_cache(self.id);
}

let df = cache.1.get_or_try_init(|| {
cache_hit = false;
self.input.execute(state)
})?;

// Decrement count on cache hits.
if cache_hit && previous == 0 {
state.remove_df_cache(self.id);
Ok(out)
},
// Cache Prefill node
Some(input) => {
if state.verbose() {
eprintln!("CACHE SET: cache id: {:x}", self.id);
}
// Run the input plan and store its result in the slot for `self.id`.
let df = input.execute(state)?;
let cache = state.get_df_cache(self.id, self.count);
cache.1.set(df).expect("should be empty");
// The computed frame is kept only in the cache; this arm returns an
// empty frame.
Ok(DataFrame::empty())
},
}
}
}

/// Executor that first runs every executor in `caches` and then runs
/// `phys_plan`, returning the result of the latter.
pub struct CachePrefiller {
/// Executors keyed by a `usize` (presumably the cache id — confirm with the
/// code that builds this map). `execute` runs them in the map's iteration
/// order and discards their returned frames.
pub caches: PlIndexMap<usize, Box<dyn Executor>>,
/// Plan executed after all entries of `caches` have run.
pub phys_plan: Box<dyn Executor>,
}

// NOTE(review): this impl is rendered from a diff without +/- markers. The
// `if cache_hit { ... }` lines and `Ok(df.clone())` look like removed lines of
// the previous `CacheExec::execute` body — confirm against the repository.
impl Executor for CachePrefiller {
/// Executes every executor in `self.caches`, then executes `self.phys_plan`
/// and returns its result.
///
/// An error from any cache executor is propagated with `?`, in which case
/// `phys_plan` is not executed.
fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult<DataFrame> {
if state.verbose() {
if cache_hit {
eprintln!("CACHE HIT: cache id: {:x}", self.id);
} else {
eprintln!("CACHE SET: cache id: {:x}", self.id);
}
eprintln!("PREFILL CACHES")
}

Ok(df.clone())
// Ensure we traverse in discovery order. This will ensure that caches aren't dependent on each
// other.
for cache in self.caches.values_mut() {
// The frame returned by each cache executor is discarded here.
let _df = cache.execute(state)?;
}
if state.verbose() {
eprintln!("EXECUTE PHYS PLAN")
}
self.phys_plan.execute(state)
}
}
8 changes: 1 addition & 7 deletions crates/polars-mem-engine/src/executors/merge_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pub(crate) struct MergeSorted {
pub(crate) input_left: Box<dyn Executor>,
pub(crate) input_right: Box<dyn Executor>,
pub(crate) key: PlSmallStr,
pub(crate) parallel: bool,
}

impl Executor for MergeSorted {
Expand All @@ -18,19 +17,14 @@ impl Executor for MergeSorted {
eprintln!("run MergeSorted")
}
}
let (left, right) = if self.parallel {
let (left, right) = {
let mut state2 = state.split();
state2.branch_idx += 1;
let (left, right) = POOL.join(
|| self.input_left.execute(state),
|| self.input_right.execute(&mut state2),
);
(left?, right?)
} else {
(
self.input_left.execute(state)?,
self.input_right.execute(state)?,
)
};

let profile_name = Cow::Borrowed("Merge Sorted");
Expand Down
Loading

0 comments on commit 251fab5

Please sign in to comment.