diff --git a/Cargo.lock b/Cargo.lock index dcf382913791..53a9238a672e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3281,6 +3281,7 @@ dependencies = [ "polars-utils", "pyo3", "rayon", + "recursive", "tokio", ] diff --git a/crates/polars-mem-engine/Cargo.toml b/crates/polars-mem-engine/Cargo.toml index 3f9e44784407..562918bf0763 100644 --- a/crates/polars-mem-engine/Cargo.toml +++ b/crates/polars-mem-engine/Cargo.toml @@ -23,6 +23,7 @@ polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } pyo3 = { workspace = true, optional = true } rayon = { workspace = true } +recursive = { workspace = true } tokio = { workspace = true, optional = true } [features] diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 57759baffa42..4a93027a392f 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -4,6 +4,7 @@ use polars_expr::state::ExecutionState; use polars_plan::global::_set_n_rows_for_scan; use polars_plan::plans::expr_ir::ExprIR; use polars_utils::format_pl_smallstr; +use recursive::recursive; use self::expr_ir::OutputName; use self::predicates::{aexpr_to_column_predicates, aexpr_to_skip_batch_predicate}; @@ -50,22 +51,24 @@ fn partitionable_gb( #[derive(Clone)] struct ConversionState { expr_depth: u16, - has_cache: bool, + has_cache_child: bool, + has_cache_parent: bool, } impl ConversionState { fn new() -> PolarsResult { Ok(ConversionState { expr_depth: get_expr_depth_limit()?, - has_cache: false, + has_cache_child: false, + has_cache_parent: false, }) } fn with_new_branch K>(&mut self, func: F) -> (K, Self) { let mut new_state = self.clone(); - new_state.has_cache = false; + new_state.has_cache_child = false; let out = func(&mut new_state); - self.has_cache = new_state.has_cache; + self.has_cache_child = new_state.has_cache_child; (out, new_state) } } @@ -79,6 +82,7 @@ pub fn create_physical_plan( create_physical_plan_impl(root, lp_arena, expr_arena, &mut state) } +#[recursive] fn create_physical_plan_impl( root: Node, lp_arena: &mut Arena, @@ -87,7 +91,12 @@ fn create_physical_plan_impl( ) -> PolarsResult> { use IR::*; - let logical_plan = lp_arena.take(root); + let logical_plan = if state.has_cache_parent { + lp_arena.get(root).clone() + } else { + lp_arena.take(root) + }; + match logical_plan { #[cfg(feature = "python")] PythonScan { mut options } => { @@ -156,7 +165,7 @@ fn create_physical_plan_impl( .map(|node| create_physical_plan_impl(node, lp_arena, expr_arena, new_state)) .collect::>>() }); - if new_state.has_cache { + if new_state.has_cache_child { options.parallel = false } let inputs = inputs?; @@ -174,7 +183,7 @@ fn create_physical_plan_impl( .collect::>>() }); - if new_state.has_cache { + if new_state.has_cache_child { options.parallel = false } let inputs = inputs?; @@ -413,8 +422,9 @@ fn create_physical_plan_impl( id, cache_hits, } => { + state.has_cache_parent = true; let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?; - state.has_cache = true; + state.has_cache_child = true; Ok(Box::new(executors::CacheExec { id, input, @@ -549,7 +559,7 @@ fn create_physical_plan_impl( let parallel = if options.force_parallel { true } else if options.allow_parallel { - !new_state.has_cache + !new_state.has_cache_child } else { false }; @@ -682,7 +692,7 @@ fn create_physical_plan_impl( input_left, input_right, key, - parallel: new_state.has_cache, + parallel: new_state.has_cache_child, }; Ok(Box::new(exec)) }, diff --git a/crates/polars-plan/src/plans/optimizer/cache_states.rs b/crates/polars-plan/src/plans/optimizer/cache_states.rs index 1dbf7e8fc340..5a568570b5e0 100644 --- a/crates/polars-plan/src/plans/optimizer/cache_states.rs +++ b/crates/polars-plan/src/plans/optimizer/cache_states.rs @@ -328,52 +328,63 @@ pub(super) fn set_cache_states( } return Ok(()); } + // Below we restart projection and predicates pushdown + // on the first cache node. As it are cache nodes, the others are the same + // and we can reuse the optimized state for all inputs. + // See #21637 // # RUN PROJECTION PUSHDOWN if !v.names_union.is_empty() { - for &child in &v.children { - let columns = &v.names_union; - let child_lp = lp_arena.take(child); - - // Make sure we project in the order of the schema - // if we don't a union may fail as we would project by the - // order we discovered all values. - let child_schema = child_lp.schema(lp_arena); - let child_schema = child_schema.as_ref(); - let projection = child_schema - .iter_names() - .flat_map(|name| columns.get(name.as_str()).cloned()) - .collect::>(); - - let new_child = lp_arena.add(child_lp); - - let lp = IRBuilder::new(new_child, expr_arena, lp_arena) - .project_simple(projection) - .expect("unique names") - .build(); - - let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?; - // Optimization can lead to a double projection. Only take the last. - let lp = if let IR::SimpleProjection { input, columns } = lp { - let input = if let IR::SimpleProjection { input: input2, .. } = - lp_arena.get(input) - { + let first_child = *v.children.first().expect("at least on child"); + + let columns = &v.names_union; + let child_lp = lp_arena.take(first_child); + + // Make sure we project in the order of the schema + // if we don't a union may fail as we would project by the + // order we discovered all values. + let child_schema = child_lp.schema(lp_arena); + let child_schema = child_schema.as_ref(); + let projection = child_schema + .iter_names() + .flat_map(|name| columns.get(name.as_str()).cloned()) + .collect::>(); + + let new_child = lp_arena.add(child_lp); + + let lp = IRBuilder::new(new_child, expr_arena, lp_arena) + .project_simple(projection) + .expect("unique names") + .build(); + + let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?; + // Optimization can lead to a double projection. Only take the last. + let lp = if let IR::SimpleProjection { input, columns } = lp { + let input = + if let IR::SimpleProjection { input: input2, .. } = lp_arena.get(input) { *input2 } else { input }; - IR::SimpleProjection { input, columns } - } else { - lp - }; - lp_arena.replace(child, lp); + IR::SimpleProjection { input, columns } + } else { + lp + }; + lp_arena.replace(first_child, lp.clone()); + + // Set the remaining children to the same node. + for &child in &v.children[1..] { + lp_arena.replace(child, lp.clone()); } } else { // No upper projections to include, run projection pushdown from cache node. - for &child in &v.children { - let child_lp = lp_arena.take(child); - let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?; - lp_arena.replace(child, lp); + let first_child = *v.children.first().expect("at least on child"); + let child_lp = lp_arena.take(first_child); + let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?; + lp_arena.replace(first_child, lp.clone()); + + for &child in &v.children[1..] { + lp_arena.replace(child, lp.clone()); } } @@ -387,17 +398,25 @@ pub(super) fn set_cache_states( *count == v.children.len() as u32 }; - for (&child, parents) in v.children.iter().zip(v.parents) { - if allow_parent_predicate_pushdown { + if allow_parent_predicate_pushdown { + let parents = *v.parents.first().unwrap(); + let node = get_filter_node(parents, lp_arena) + .expect("expected filter; this is an optimizer bug"); + let start_lp = lp_arena.take(node); + let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?; + lp_arena.replace(node, lp.clone()); + for &parents in &v.parents[1..] { let node = get_filter_node(parents, lp_arena) .expect("expected filter; this is an optimizer bug"); - let start_lp = lp_arena.take(node); - let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?; - lp_arena.replace(node, lp); - } else { - let child_lp = lp_arena.take(child); - let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?; - lp_arena.replace(child, lp); + lp_arena.replace(node, lp.clone()); + } + } else { + let child = *v.children.first().unwrap(); + let child_lp = lp_arena.take(child); + let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?; + lp_arena.replace(child, lp.clone()); + for &child in &v.children[1..] { + lp_arena.replace(child, lp.clone()); } } }