From 040224b0aa77244aa33a0cfcd5b2c9d5b11c9bcc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 1 Feb 2025 12:14:06 -0500 Subject: [PATCH] Fix join type coercion (#14387) --- datafusion/core/tests/dataframe/mod.rs | 37 ++++++++++++++++++- .../optimizer/src/analyzer/type_coercion.rs | 23 +++++++++--- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1ebbf92c736e..eed11f634c9d 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -30,8 +30,8 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_array::{ - Array, BooleanArray, DictionaryArray, Float32Array, Float64Array, Int8Array, - UnionArray, + record_batch, Array, BooleanArray, DictionaryArray, Float32Array, Float64Array, + Int8Array, UnionArray, }; use arrow_buffer::ScalarBuffer; use arrow_schema::{ArrowError, SchemaRef, UnionFields, UnionMode}; @@ -1121,6 +1121,39 @@ async fn join() -> Result<()> { Ok(()) } +#[tokio::test] +async fn join_coercion_unnnamed() -> Result<()> { + let ctx = SessionContext::new(); + + // Test that join will coerce column types when necessary + // even when the relations don't have unique names + let left = ctx.read_batch(record_batch!( + ("id", Int32, [1, 2, 3]), + ("name", Utf8, ["a", "b", "c"]) + )?)?; + let right = ctx.read_batch(record_batch!( + ("id", Int32, [10, 3]), + ("name", Utf8View, ["d", "c"]) // Utf8View is a different type + )?)?; + let cols = vec!["name", "id"]; + + let filter = None; + let join = right.join(left, JoinType::LeftAnti, &cols, &cols, filter)?; + let results = join.collect().await?; + + assert_batches_sorted_eq!( + [ + "+----+------+", + "| id | name |", + "+----+------+", + "| 10 | d |", + "+----+------+", + ], + &results + ); + Ok(()) +} + #[tokio::test] async fn join_on() -> Result<()> { let left = test_table_with_name("a") diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 48a5e2f9a07c..7a41f54c56e1 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -190,7 +190,15 @@ impl<'a> TypeCoercionRewriter<'a> { .map(|(lhs, rhs)| { // coerce the arguments as though they were a single binary equality // expression - let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?; + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); + let (lhs, rhs) = self.coerce_binary_op( + lhs, + left_schema, + Operator::Eq, + rhs, + right_schema, + )?; Ok((lhs, rhs)) }) .collect::>>()?; @@ -275,17 +283,19 @@ impl<'a> TypeCoercionRewriter<'a> { fn coerce_binary_op( &self, left: Expr, + left_schema: &DFSchema, op: Operator, right: Expr, + right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { let (left_type, right_type) = get_input_types( - &left.get_type(self.schema)?, + &left.get_type(left_schema)?, &op, - &right.get_type(self.schema)?, + &right.get_type(right_schema)?, )?; Ok(( - left.cast_to(&left_type, self.schema)?, - right.cast_to(&right_type, self.schema)?, + left.cast_to(&left_type, left_schema)?, + right.cast_to(&right_type, right_schema)?, )) } } @@ -404,7 +414,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (left, right) = self.coerce_binary_op(*left, op, *right)?; + let (left, right) = + self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op,