From 800c9ac00ae863dc242c6db06610621a19e86e96 Mon Sep 17 00:00:00 2001 From: blaginin Date: Sat, 15 Feb 2025 18:14:33 +0000 Subject: [PATCH] Reuse last projection layer when renaming column --- datafusion/core/src/dataframe/mod.rs | 99 +++++++++++++++++++------- datafusion/core/tests/dataframe/mod.rs | 18 +++-- 2 files changed, 86 insertions(+), 31 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index b1eb2a19e31d..19ea10ceada0 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -52,7 +52,7 @@ use datafusion_common::{ SchemaError, UnnestOptions, }; use datafusion_expr::dml::InsertOp; -use datafusion_expr::{case, is_null, lit, SortExpr}; +use datafusion_expr::{case, is_null, lit, Projection, SortExpr}; use datafusion_expr::{ utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; @@ -1770,7 +1770,7 @@ impl DataFrame { /// # } /// ``` pub fn with_column_renamed( - self, + mut self, old_name: impl Into, new_name: &str, ) -> Result { @@ -1779,37 +1779,82 @@ impl DataFrame { .config_options() .sql_parser .enable_ident_normalization; + let old_column: Column = if ident_opts { Column::from_qualified_name(old_name) } else { Column::from_qualified_name_ignore_case(old_name) }; - let (qualifier_rename, field_rename) = - match self.plan.schema().qualified_field_from_column(&old_column) { - Ok(qualifier_and_field) => qualifier_and_field, - // no-op if field not found - Err(DataFusionError::SchemaError( - SchemaError::FieldNotFound { .. }, - _, - )) => return Ok(self), - Err(err) => return Err(err), - }; - let projection = self - .plan - .schema() - .iter() - .map(|(qualifier, field)| { - if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { - col(Column::from((qualifier, field))).alias(new_name) - } else { - col(Column::from((qualifier, field))) - } - }) - .collect::>(); - let project_plan = LogicalPlanBuilder::from(self.plan) - .project(projection)? - .build()?; + let project_plan = if let LogicalPlan::Projection(Projection { + expr, + input, + schema, + .. + }) = self.plan + { + // special case: we already have a projection on top, so we can reuse it rather than creating a new one + let (qualifier_rename, field_rename) = + match schema.qualified_field_from_column(&old_column) { + Ok(qualifier_and_field) => qualifier_and_field, + // no-op if field not found + Err(DataFusionError::SchemaError( + SchemaError::FieldNotFound { .. }, + _, + )) => { + self.plan = LogicalPlan::Projection( + Projection::try_new_with_schema(expr, input, schema)?, + ); + return Ok(self); + } + Err(err) => return Err(err), + }; + + let expr: Vec<_> = expr + .into_iter() + .map(|e| { + let (qualifier, field) = e.qualified_name(); + + if qualifier.as_ref().eq(&qualifier_rename) + && field.as_str() == field_rename.name() + { + e.alias_qualified(qualifier, new_name.to_string()) + } else { + e + } + }) + .collect(); + LogicalPlan::Projection(Projection::try_new(expr, input)?) + } else { + let (qualifier_rename, field_rename) = + match self.plan.schema().qualified_field_from_column(&old_column) { + Ok(qualifier_and_field) => qualifier_and_field, + // no-op if field not found + Err(DataFusionError::SchemaError( + SchemaError::FieldNotFound { .. }, + _, + )) => return Ok(self), + Err(err) => return Err(err), + }; + + let projection = self + .plan + .schema() + .iter() + .map(|(qualifier, field)| { + if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { + col(Column::from((qualifier, field))).alias(new_name) + } else { + col(Column::from((qualifier, field))) + } + }) + .collect::>(); + + LogicalPlanBuilder::from(self.plan) + .project(projection)? + .build()? + }; + Ok(DataFrame { session_state: self.session_state, plan: project_plan, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 8155fd6a2ff9..9a03ef518308 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1617,9 +1617,19 @@ async fn with_column_renamed() -> Result<()> { // accepts table qualifier .with_column_renamed("aggregate_test_100.c2", "two")? // no-op for missing column - .with_column_renamed("c4", "boom")? - .collect() - .await?; + .with_column_renamed("c4", "boom")?; + + assert_eq!("\ + Projection: aggregate_test_100.c1 AS one, aggregate_test_100.c2 AS two, aggregate_test_100.c3, aggregate_test_100.c2 + aggregate_test_100.c3 AS sum AS total\ + \n Limit: skip=0, fetch=1\ + \n Sort: aggregate_test_100.c1 ASC NULLS FIRST, aggregate_test_100.c2 ASC NULLS FIRST, aggregate_test_100.c3 ASC NULLS FIRST\ + \n Filter: aggregate_test_100.c2 = Int32(3) AND aggregate_test_100.c1 = Utf8(\"a\")\ + \n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3\ + \n TableScan: aggregate_test_100", + format!("{}", df_sum_renamed.logical_plan()) // one projection is reused for all renames + ); + + let batches = df_sum_renamed.collect().await?; assert_batches_sorted_eq!( [ @@ -1629,7 +1639,7 @@ async fn with_column_renamed() -> Result<()> { "| a | 3 | -72 | -69 |", "+-----+-----+-----+-------+", ], - &df_sum_renamed + &batches ); Ok(())