Skip to content

Commit

Permalink
Fix unnest rewriting logic (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov authored Dec 24, 2024
1 parent 04487ba commit 72374f6
Showing 1 changed file with 57 additions and 20 deletions.
77 changes: 57 additions & 20 deletions sources/sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{any::Any, collections::HashMap, sync::Arc, vec};
use async_trait::async_trait;
use datafusion::{
arrow::datatypes::{Schema, SchemaRef},
common::Column,
common::{Column, RecursionUnnestOption, UnnestOptions},
config::ConfigOptions,
error::{DataFusionError, Result},
execution::{context::SessionState, TaskContext},
Expand Down Expand Up @@ -208,7 +208,7 @@ fn rewrite_unnest_plan(
));
};

// rewrite aliases in inner projection; columns were rewritten via `rewrite_table_scans_in_expr``
// rewrite aliases in inner projection; columns were rewritten via `rewrite_table_scans_in_expr`
let new_expressions = projection
.expr
.into_iter()
Expand All @@ -227,15 +227,66 @@ fn rewrite_unnest_plan(
let updated_unnest_inner_projection =
Projection::try_new(new_expressions, Arc::clone(&projection.input))?;

let unnest_options = rewrite_unnest_options(&unnest.options, known_rewrites);

// reconstruct the unnest plan with updated projection and rewritten column names
let new_plan =
LogicalPlanBuilder::new(LogicalPlan::Projection(updated_unnest_inner_projection))
.unnest_columns_with_options(unnest_columns, unnest.options.clone())?
.unnest_columns_with_options(unnest_columns, unnest_options)?
.build()?;

Ok(new_plan)
}

/// Produces a copy of `options` in which every recursion entry refers to the
/// original federated table name rather than the local DataFusion one, e.g.
/// `unnest_placeholder(foo.df_table.a,depth=1)` becomes
/// `unnest_placeholder(remote_table.a,depth=1)`.
///
/// Both the input and the output column name of each recursion entry are passed
/// through `rewrite_column_name`; a name is replaced only when that helper
/// reports a rewrite. The `options` argument itself is never modified.
fn rewrite_unnest_options(
    options: &UnnestOptions,
    known_rewrites: &HashMap<TableReference, TableReference>,
) -> UnnestOptions {
    let mut rewritten = options.clone();

    for recursion in rewritten.recursions.iter_mut() {
        // Input column first, then output column; each is updated independently.
        if let Some(name) = rewrite_column_name(&recursion.input_column.name, known_rewrites) {
            recursion.input_column.name = name;
        }

        if let Some(name) = rewrite_column_name(&recursion.output_column.name, known_rewrites) {
            recursion.output_column.name = name;
        }
    }

    rewritten
}

/// Applies every known table rewrite to `col_name`, replacing each matching
/// table reference embedded in the name with its rewritten counterpart.
/// This covers names such as "MAX(foo.df_table.a)" -> "MAX(remote_table.a)".
///
/// Rewrites are applied cumulatively: each entry of `known_rewrites` operates on
/// the result of the entries visited before it, in the map's iteration order.
///
/// Returns `Some(new_name)` if at least one rewrite was applied, otherwise `None`.
fn rewrite_column_name(
    col_name: &str,
    known_rewrites: &HashMap<TableReference, TableReference>,
) -> Option<String> {
    let mut current = col_name.to_string();
    let mut changed = false;

    for (table_ref, rewrite) in known_rewrites.iter() {
        let from = table_ref.to_string();
        let to = rewrite.to_string();

        // A `None` result means this table reference does not occur in the name,
        // so the accumulated name is carried forward unchanged.
        if let Some(updated) = rewrite_column_name_in_expr(&current, &from, &to, 0) {
            current = updated;
            changed = true;
        }
    }

    changed.then_some(current)
}

// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite.
// The name to rewrite should NOT be a substring of another name.
// Supports multiple occurrences of table_ref_str in col_name.
Expand Down Expand Up @@ -344,21 +395,7 @@ fn rewrite_table_scans_in_expr(

// Check if any of the rewrites match any substring in col.name, and replace that part of the string if so.
// This handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)"
let (new_name, was_rewritten) = known_rewrites.iter().fold(
(col.name.to_string(), false),
|(col_name, was_rewritten), (table_ref, rewrite)| {
match rewrite_column_name_in_expr(
&col_name,
&table_ref.to_string(),
&rewrite.to_string(),
0,
) {
Some(new_name) => (new_name, true),
None => (col_name, was_rewritten),
}
},
);
if was_rewritten {
if let Some(new_name) = rewrite_column_name(&col.name, known_rewrites) {
Ok(Expr::Column(Column::new(col.relation.take(), new_name)))
} else {
Ok(Expr::Column(col))
Expand Down Expand Up @@ -1020,11 +1057,11 @@ mod tests {
let tests = vec![
(
"SELECT UNNEST([1, 2, 2, 5, NULL]), b, c from app_table where a > 10 order by b limit 10;",
r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)), remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#,
r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL))", remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#,
),
(
"SELECT UNNEST(app_table.d), b, c from app_table where a > 10 order by b limit 10;",
r#"SELECT UNNEST(remote_table.d), remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#,
r#"SELECT UNNEST(remote_table.d) AS "UNNEST(app_table.d)", remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#,
),
(
"SELECT sum(b.x) AS total FROM (SELECT UNNEST(d) AS x from app_table where a > 0) AS b;",
Expand Down

0 comments on commit 72374f6

Please sign in to comment.