This is an automated email from the ASF dual-hosted git repository.

jonah pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8356c94b83 Handle columns in with_new_exprs with a Join (#15055)
8356c94b83 is described below

commit 8356c94b83727cc21d1bb1364b39a555ce3794cf
Author: delamarch3 <[email protected]>
AuthorDate: Sat Mar 8 01:20:13 2025 +0000

    Handle columns in with_new_exprs with a Join (#15055)
    
    * handle columns in with_new_exprs with Join
    
    * test doesn't return result
    
    * take join from result
    
    * clippy
    
    * make test fallible
    
    * accept any pair of expression for new_on in with_new_exprs for Join
    
    * use with_capacity
---
 datafusion/expr/src/logical_plan/plan.rs | 130 ++++++++++++++++++++++++++++---
 1 file changed, 118 insertions(+), 12 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 72b82fc219..682342d27b 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -903,7 +903,7 @@ impl LogicalPlan {
                 let (left, right) = self.only_two_inputs(inputs)?;
                 let schema = build_join_schema(left.schema(), right.schema(), join_type)?;
 
-                let equi_expr_count = on.len();
+                let equi_expr_count = on.len() * 2;
                 assert!(expr.len() >= equi_expr_count);
 
                 // Assume that the last expr, if any,
@@ -917,17 +917,16 @@ impl LogicalPlan {
                 // The first part of expr is equi-exprs,
                 // and the struct of each equi-expr is like `left-expr = right-expr`.
                 assert_eq!(expr.len(), equi_expr_count);
-                let new_on = expr.into_iter().map(|equi_expr| {
+                let mut new_on = Vec::with_capacity(on.len());
+                let mut iter = expr.into_iter();
+                while let Some(left) = iter.next() {
+                    let Some(right) = iter.next() else {
+                        internal_err!("Expected a pair of expressions to construct the join on expression")?
+                    };
+
                     // SimplifyExpression rule may add alias to the equi_expr.
-                    let unalias_expr = equi_expr.clone().unalias();
-                    if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr {
-                        Ok((*left, *right))
-                    } else {
-                        internal_err!(
-                            "The front part expressions should be an binary equality expression, actual:{equi_expr}"
-                        )
-                    }
-                }).collect::<Result<Vec<(Expr, Expr)>>>()?;
+                    new_on.push((left.unalias(), right.unalias()));
+                }
 
                 Ok(LogicalPlan::Join(Join {
                     left: Arc::new(left),
@@ -3780,7 +3779,8 @@ mod tests {
     use crate::builder::LogicalTableSource;
     use crate::logical_plan::table_scan;
     use crate::{
-        col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet,
+        binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
+        GroupingSet,
     };
 
     use datafusion_common::tree_node::{
@@ -4632,4 +4632,110 @@ digraph {
         let parameter_type = params.clone().get(placeholder_value).unwrap().clone();
         assert_eq!(parameter_type, None);
     }
+
+    #[test]
+    fn test_join_with_new_exprs() -> Result<()> {
+        fn create_test_join(
+            on: Vec<(Expr, Expr)>,
+            filter: Option<Expr>,
+        ) -> Result<LogicalPlan> {
+            let schema = Schema::new(vec![
+                Field::new("a", DataType::Int32, false),
+                Field::new("b", DataType::Int32, false),
+            ]);
+
+            let left_schema = DFSchema::try_from_qualified_schema("t1", &schema)?;
+            let right_schema = DFSchema::try_from_qualified_schema("t2", &schema)?;
+
+            Ok(LogicalPlan::Join(Join {
+                left: Arc::new(
+                    table_scan(Some("t1"), left_schema.as_arrow(), None)?.build()?,
+                ),
+                right: Arc::new(
+                    table_scan(Some("t2"), right_schema.as_arrow(), None)?.build()?,
+                ),
+                on,
+                filter,
+                join_type: JoinType::Inner,
+                join_constraint: JoinConstraint::On,
+                schema: Arc::new(left_schema.join(&right_schema)?),
+                null_equals_null: false,
+            }))
+        }
+
+        {
+            let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None)?;
+            let LogicalPlan::Join(join) = join.with_new_exprs(
+                join.expressions(),
+                join.inputs().into_iter().cloned().collect(),
+            )?
+            else {
+                unreachable!()
+            };
+            assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
+            assert_eq!(join.filter, None);
+        }
+
+        {
+            let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a"))))?;
+            let LogicalPlan::Join(join) = join.with_new_exprs(
+                join.expressions(),
+                join.inputs().into_iter().cloned().collect(),
+            )?
+            else {
+                unreachable!()
+            };
+            assert_eq!(join.on, vec![]);
+            assert_eq!(join.filter, Some(col("t1.a").gt(col("t2.a"))));
+        }
+
+        {
+            let join = create_test_join(
+                vec![(col("t1.a"), (col("t2.a")))],
+                Some(col("t1.b").gt(col("t2.b"))),
+            )?;
+            let LogicalPlan::Join(join) = join.with_new_exprs(
+                join.expressions(),
+                join.inputs().into_iter().cloned().collect(),
+            )?
+            else {
+                unreachable!()
+            };
+            assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
+            assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
+        }
+
+        {
+            let join = create_test_join(
+                vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))],
+                None,
+            )?;
+            let LogicalPlan::Join(join) = join.with_new_exprs(
+                vec![
+                    binary_expr(col("t1.a"), Operator::Plus, lit(1)),
+                    binary_expr(col("t2.a"), Operator::Plus, lit(2)),
+                    col("t1.b"),
+                    col("t2.b"),
+                    lit(true),
+                ],
+                join.inputs().into_iter().cloned().collect(),
+            )?
+            else {
+                unreachable!()
+            };
+            assert_eq!(
+                join.on,
+                vec![
+                    (
+                        binary_expr(col("t1.a"), Operator::Plus, lit(1)),
+                        binary_expr(col("t2.a"), Operator::Plus, lit(2))
+                    ),
+                    (col("t1.b"), (col("t2.b")))
+                ]
+            );
+            assert_eq!(join.filter, Some(lit(true)));
+        }
+
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to