perf: Optimize only a single cache input #21644

Merged: 2 commits, Mar 7, 2025
1 change: 1 addition & 0 deletions Cargo.lock

Cargo.lock is a generated file; its diff is not rendered by default.

1 change: 1 addition & 0 deletions crates/polars-mem-engine/Cargo.toml
@@ -23,6 +23,7 @@ polars-time = { workspace = true, optional = true }
polars-utils = { workspace = true }
pyo3 = { workspace = true, optional = true }
rayon = { workspace = true }
recursive = { workspace = true }
tokio = { workspace = true, optional = true }

[features]
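
The new `recursive` dependency here backs the `#[recursive]` attribute added to `create_physical_plan_impl` in the lp.rs diff below; the crate is intended to guard deep recursion by growing the stack on demand rather than overflowing on very deep plans. A rough sketch of how that attribute is typically applied (the `Expr` type and `eval` function are invented for illustration and are not part of this PR):

```rust
// Toy example (not from this PR): a deeply nested expression tree whose
// evaluator is annotated with the same attribute the PR adds to
// create_physical_plan_impl.
use recursive::recursive;

enum Expr {
    Leaf(i64),
    Add(Box<Expr>, Box<Expr>),
}

// The attribute checks the remaining stack before recursing and allocates a
// fresh stack segment when it runs low, so very deep trees do not overflow.
#[recursive]
fn eval(e: &Expr) -> i64 {
    match e {
        Expr::Leaf(v) => *v,
        Expr::Add(l, r) => eval(l) + eval(r),
    }
}

fn main() {
    // Build a left-leaning chain deep enough to be risky without the attribute.
    let mut e = Expr::Leaf(0);
    for i in 1..100_000 {
        e = Expr::Add(Box::new(e), Box::new(Expr::Leaf(i)));
    }
    assert_eq!(eval(&e), (0..100_000).sum::<i64>());
}
```
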
30 changes: 20 additions & 10 deletions crates/polars-mem-engine/src/planner/lp.rs
@@ -4,6 +4,7 @@ use polars_expr::state::ExecutionState;
use polars_plan::global::_set_n_rows_for_scan;
use polars_plan::plans::expr_ir::ExprIR;
use polars_utils::format_pl_smallstr;
use recursive::recursive;

use self::expr_ir::OutputName;
use self::predicates::{aexpr_to_column_predicates, aexpr_to_skip_batch_predicate};
@@ -50,22 +51,24 @@ fn partitionable_gb(
#[derive(Clone)]
struct ConversionState {
expr_depth: u16,
has_cache: bool,
has_cache_child: bool,
has_cache_parent: bool,
}

impl ConversionState {
fn new() -> PolarsResult<Self> {
Ok(ConversionState {
expr_depth: get_expr_depth_limit()?,
has_cache: false,
has_cache_child: false,
has_cache_parent: false,
})
}

fn with_new_branch<K, F: FnOnce(&mut Self) -> K>(&mut self, func: F) -> (K, Self) {
let mut new_state = self.clone();
new_state.has_cache = false;
new_state.has_cache_child = false;
let out = func(&mut new_state);
self.has_cache = new_state.has_cache;
self.has_cache_child = new_state.has_cache_child;
(out, new_state)
}
}
@@ -79,6 +82,7 @@ pub fn create_physical_plan(
create_physical_plan_impl(root, lp_arena, expr_arena, &mut state)
}

#[recursive]
fn create_physical_plan_impl(
root: Node,
lp_arena: &mut Arena<IR>,
@@ -87,7 +91,12 @@ fn create_physical_plan_impl(
) -> PolarsResult<Box<dyn Executor>> {
use IR::*;

let logical_plan = lp_arena.take(root);
let logical_plan = if state.has_cache_parent {
lp_arena.get(root).clone()
} else {
lp_arena.take(root)
};

match logical_plan {
#[cfg(feature = "python")]
PythonScan { mut options } => {
@@ -156,7 +165,7 @@
.map(|node| create_physical_plan_impl(node, lp_arena, expr_arena, new_state))
.collect::<PolarsResult<Vec<_>>>()
});
if new_state.has_cache {
if new_state.has_cache_child {
options.parallel = false
}
let inputs = inputs?;
@@ -174,7 +183,7 @@
.collect::<PolarsResult<Vec<_>>>()
});

if new_state.has_cache {
if new_state.has_cache_child {
options.parallel = false
}
let inputs = inputs?;
Expand Down Expand Up @@ -413,8 +422,9 @@ fn create_physical_plan_impl(
id,
cache_hits,
} => {
state.has_cache_parent = true;
let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?;
state.has_cache = true;
state.has_cache_child = true;
Ok(Box::new(executors::CacheExec {
id,
input,
@@ -549,7 +559,7 @@
let parallel = if options.force_parallel {
true
} else if options.allow_parallel {
!new_state.has_cache
!new_state.has_cache_child
} else {
false
};
@@ -682,7 +692,7 @@
input_left,
input_right,
key,
parallel: new_state.has_cache,
parallel: new_state.has_cache_child,
};
Ok(Box::new(exec))
},
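
In the lp.rs diff above, `ConversionState` now distinguishes a cache below the current node (`has_cache_child`, which, as before, steers the parallelism decisions in the union and join hunks) from a cache above it (`has_cache_parent`, which makes the planner clone IR nodes out of the arena instead of taking them, because a cached subplan may be lowered again for another consumer). A stripped-down sketch of that take-vs-clone decision, using toy `Plan` and `Arena` types rather than the real Polars IR:

```rust
// Simplified sketch (toy types, not the Polars IR) of the new take-vs-clone
// logic in create_physical_plan_impl.
#[derive(Clone, Debug)]
enum Plan {
    Scan(String),
    Cache { input: usize },
    Taken, // placeholder left behind when a node is moved out of the arena
}

struct Arena(Vec<Plan>);

impl Arena {
    fn take(&mut self, i: usize) -> Plan {
        std::mem::replace(&mut self.0[i], Plan::Taken)
    }
    fn get(&self, i: usize) -> &Plan {
        &self.0[i]
    }
}

#[derive(Clone, Default)]
struct ConversionState {
    has_cache_child: bool,  // a cache exists below: callers adjust parallelism
    has_cache_parent: bool, // a cache exists above: keep arena nodes intact
}

fn lower(node: usize, arena: &mut Arena, state: &mut ConversionState) -> Plan {
    // Under a cache parent this subtree may be lowered again for another
    // consumer of the cache, so clone instead of consuming the node.
    let plan = if state.has_cache_parent {
        arena.get(node).clone()
    } else {
        arena.take(node)
    };
    match plan {
        Plan::Cache { input } => {
            state.has_cache_parent = true;
            let inner = lower(input, arena, state);
            state.has_cache_child = true;
            inner
        },
        other => other,
    }
}

fn main() {
    let mut arena = Arena(vec![
        Plan::Scan("source".into()),
        Plan::Cache { input: 0 },
    ]);
    let mut state = ConversionState::default();
    let _plan = lower(1, &mut arena, &mut state);
    // The scan under the cache is still in the arena, and the flag consulted
    // by the union/join parallelism checks is set.
    assert!(matches!(arena.0[0], Plan::Scan(_)));
    assert!(state.has_cache_child);
}
```

In the actual diff, `has_cache_parent` is set before lowering the input of a Cache node and `has_cache_child` is set afterwards, which is what the union and join hunks above inspect when deciding whether to run in parallel.
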
109 changes: 64 additions & 45 deletions crates/polars-plan/src/plans/optimizer/cache_states.rs
@@ -328,52 +328,63 @@ pub(super) fn set_cache_states(
}
return Ok(());
}
// Below we restart projection and predicate pushdown
// on the first cache node. As these are cache nodes, the others are the same
// and we can reuse the optimized state for all inputs.
// See #21637

// # RUN PROJECTION PUSHDOWN
if !v.names_union.is_empty() {
for &child in &v.children {
let columns = &v.names_union;
let child_lp = lp_arena.take(child);

// Make sure we project in the order of the schema
// if we don't a union may fail as we would project by the
// order we discovered all values.
let child_schema = child_lp.schema(lp_arena);
let child_schema = child_schema.as_ref();
let projection = child_schema
.iter_names()
.flat_map(|name| columns.get(name.as_str()).cloned())
.collect::<Vec<_>>();

let new_child = lp_arena.add(child_lp);

let lp = IRBuilder::new(new_child, expr_arena, lp_arena)
.project_simple(projection)
.expect("unique names")
.build();

let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;
// Optimization can lead to a double projection. Only take the last.
let lp = if let IR::SimpleProjection { input, columns } = lp {
let input = if let IR::SimpleProjection { input: input2, .. } =
lp_arena.get(input)
{
let first_child = *v.children.first().expect("at least one child");

let columns = &v.names_union;
let child_lp = lp_arena.take(first_child);

// Make sure we project in the order of the schema
// if we don't a union may fail as we would project by the
// order we discovered all values.
let child_schema = child_lp.schema(lp_arena);
let child_schema = child_schema.as_ref();
let projection = child_schema
.iter_names()
.flat_map(|name| columns.get(name.as_str()).cloned())
.collect::<Vec<_>>();

let new_child = lp_arena.add(child_lp);

let lp = IRBuilder::new(new_child, expr_arena, lp_arena)
.project_simple(projection)
.expect("unique names")
.build();

let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;
// Optimization can lead to a double projection. Only take the last.
let lp = if let IR::SimpleProjection { input, columns } = lp {
let input =
if let IR::SimpleProjection { input: input2, .. } = lp_arena.get(input) {
*input2
} else {
input
};
IR::SimpleProjection { input, columns }
} else {
lp
};
lp_arena.replace(child, lp);
IR::SimpleProjection { input, columns }
} else {
lp
};
lp_arena.replace(first_child, lp.clone());

// Set the remaining children to the same node.
for &child in &v.children[1..] {
lp_arena.replace(child, lp.clone());
}
} else {
// No upper projections to include, run projection pushdown from cache node.
for &child in &v.children {
let child_lp = lp_arena.take(child);
let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?;
lp_arena.replace(child, lp);
let first_child = *v.children.first().expect("at least one child");
let child_lp = lp_arena.take(first_child);
let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?;
lp_arena.replace(first_child, lp.clone());

for &child in &v.children[1..] {
lp_arena.replace(child, lp.clone());
}
}

@@ -387,17 +398,25 @@
*count == v.children.len() as u32
};

for (&child, parents) in v.children.iter().zip(v.parents) {
if allow_parent_predicate_pushdown {
if allow_parent_predicate_pushdown {
let parents = *v.parents.first().unwrap();
let node = get_filter_node(parents, lp_arena)
.expect("expected filter; this is an optimizer bug");
let start_lp = lp_arena.take(node);
let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?;
lp_arena.replace(node, lp.clone());
for &parents in &v.parents[1..] {
let node = get_filter_node(parents, lp_arena)
.expect("expected filter; this is an optimizer bug");
let start_lp = lp_arena.take(node);
let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?;
lp_arena.replace(node, lp);
} else {
let child_lp = lp_arena.take(child);
let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?;
lp_arena.replace(child, lp);
lp_arena.replace(node, lp.clone());
}
} else {
let child = *v.children.first().unwrap();
let child_lp = lp_arena.take(child);
let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?;
lp_arena.replace(child, lp.clone());
for &child in &v.children[1..] {
lp_arena.replace(child, lp.clone());
}
}
}
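
The cache_states.rs change applies the same idea to both pushdown passes: every child of a given cache is the same plan, so projection pushdown (and, where allowed, predicate pushdown through the parent filters) is run once on the first child and the optimized result is cloned into the siblings. A minimal illustration of that pattern, with a stand-in `optimize` function instead of the real `proj_pd`/`pred_pd` passes and a toy arena:

```rust
// Minimal sketch (toy types; `optimize` stands in for projection/predicate
// pushdown) of the "optimize once, reuse for all cache children" pattern.
#[derive(Clone, Debug, PartialEq)]
struct Plan(String);

struct Arena(Vec<Plan>);

impl Arena {
    fn take(&mut self, i: usize) -> Plan {
        std::mem::replace(&mut self.0[i], Plan(String::new()))
    }
    fn replace(&mut self, i: usize, p: Plan) {
        self.0[i] = p;
    }
}

fn optimize(p: Plan) -> Plan {
    // Pretend this is an expensive pushdown pass.
    Plan(format!("optimized({})", p.0))
}

fn optimize_cache_children(children: &[usize], arena: &mut Arena) {
    // All children of one cache are the same plan, so optimizing the first
    // one is enough.
    let first = *children.first().expect("at least one child");
    let lp = optimize(arena.take(first));
    arena.replace(first, lp.clone());
    // Reuse the already-optimized plan for the remaining children.
    for &child in &children[1..] {
        arena.replace(child, lp.clone());
    }
}

fn main() {
    let mut arena = Arena(vec![Plan("scan".into()), Plan("scan".into())]);
    optimize_cache_children(&[0, 1], &mut arena);
    assert_eq!(arena.0[0], arena.0[1]);
    assert_eq!(arena.0[0].0, "optimized(scan)");
}
```

In the real diff, the same reuse also applies when parent-level predicate pushdown is allowed: the filter above the first cache consumer is optimized and the result is cloned into the filters above the remaining consumers.
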