diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 460ef7400e90..66bd6b75123e 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -268,8 +268,8 @@ fn join_keys_in_subquery_alias_1() { fn push_down_filter_groupby_expr_contains_alias() { let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3"; let plan = test_sql(sql).unwrap(); - let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(*)\ - \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[count(Int64(1)) AS count(*)]]\ + let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(Int64(1)) AS count(*)\ + \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[count(Int64(1))]]\ \n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\ \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan}")); @@ -312,10 +312,9 @@ fn eliminate_redundant_null_check_on_count() { GROUP BY col_int32 HAVING c IS NOT NULL"; let plan = test_sql(sql).unwrap(); - let expected = "\ - Projection: test.col_int32, count(*) AS c\ - \n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1)) AS count(*)]]\ - \n TableScan: test projection=[col_int32]"; + let expected = "Projection: test.col_int32, count(Int64(1)) AS count(*) AS c\ + \n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1))]]\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan}")); } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 07189fbaecfc..b83effbf7f87 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -684,40 +684,50 @@ async fn roundtrip_union_all() -> Result<()> { #[tokio::test] async fn simple_intersect() -> Result<()> { - // Substrait treats both count(*) and count(1) the same - assert_expected_plan( - "SELECT count(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\ - \n Projection: \ - \n LeftSemi Join: data.a = data2.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", - true - ).await?; - - assert_expected_plan( - "SELECT count() FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]]\ - \n Projection: \ - \n LeftSemi Join: data.a = data2.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", - true - ).await?; + async fn check_wildcard(syntax: &str) -> Result<()> { + let expected_plan_str = format!( + "Projection: count(Int64(1)) AS {syntax}\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ + \n Projection: \ + \n LeftSemi Join: data.a = data2.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]" + ); + + assert_expected_plan( + &format!("SELECT {syntax} FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);"), + &expected_plan_str, + true + ).await + } - assert_expected_plan( - "SELECT count(1) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n Projection: \ - \n LeftSemi Join: data.a = data2.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", - true - ).await?; + async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { + let expected_plan_str = format!( + "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ + \n Projection: \ + \n LeftSemi Join: data.a = data2.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]" + ); + + assert_expected_plan( + &format!("SELECT {sql_syntax} FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);"), + &expected_plan_str, + true + ).await + } + check_wildcard("count(*)").await?; + check_wildcard("count()").await?; + check_constant("count(1)", "count(Int64(1))").await?; + check_constant("count(2)", "count(Int64(2))").await?; + check_constant( + "count(1 + 2)", + "count(Int64(3)) AS count(Int64(1) + Int64(2))", + ) + .await?; Ok(()) } @@ -843,44 +853,55 @@ async fn simple_intersect_table_reuse() -> Result<()> { // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. // In this case the aliasing happens at a different point in the plan, so we cannot use roundtrip. // Schema check works because we set aliases to what the Substrait consumer will generate. - assert_expected_plan( - "SELECT count(*) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);", - "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\ - \n Projection: \ - \n LeftSemi Join: left.a = right.a\ - \n SubqueryAlias: left\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n SubqueryAlias: right\ - \n TableScan: data projection=[a]", - true - ).await?; - assert_expected_plan( - "SELECT count() FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);", - "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]]\ - \n Projection: \ - \n LeftSemi Join: left.a = right.a\ - \n SubqueryAlias: left\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n SubqueryAlias: right\ - \n TableScan: data projection=[a]", - true - ).await?; + async fn check_wildcard(syntax: &str) -> Result<()> { + let expected_plan_str = format!( + "Projection: count(Int64(1)) AS {syntax}\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ + \n Projection: \ + \n LeftSemi Join: left.a = right.a\ + \n SubqueryAlias: left\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n SubqueryAlias: right\ + \n TableScan: data projection=[a]" + ); + + assert_expected_plan( + &format!("SELECT {syntax} FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);"), + &expected_plan_str, + true + ).await + } - assert_expected_plan( - "SELECT count(1) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);", - "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ + async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { + let expected_plan_str = format!( + "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ \n Projection: \ \n LeftSemi Join: left.a = right.a\ \n SubqueryAlias: left\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ \n SubqueryAlias: right\ - \n TableScan: data projection=[a]", - true - ).await?; + \n TableScan: data projection=[a]" + ); + + assert_expected_plan( + &format!("SELECT {sql_syntax} FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);"), + &expected_plan_str, + true + ).await + } + + check_wildcard("count(*)").await?; + check_wildcard("count()").await?; + check_constant("count(1)", "count(Int64(1))").await?; + check_constant("count(2)", "count(Int64(2))").await?; + check_constant( + "count(1 + 2)", + "count(Int64(3)) AS count(Int64(1) + Int64(2))", + ) + .await?; Ok(()) }