This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 9730404028 Add try_new for LogicalPlan::Join (#15757)
9730404028 is described below

commit 9730404028a91a7fe875ea3f88bafdbcb305ae6c
Author: Lokesh <lkuma...@gmail.com>
AuthorDate: Mon Apr 21 17:21:42 2025 +0200

    Add try_new for LogicalPlan::Join (#15757)
---
 datafusion/expr/src/logical_plan/builder.rs |  66 +++--
 datafusion/expr/src/logical_plan/plan.rs    | 416 ++++++++++++++++++++++++++++
 2 files changed, 447 insertions(+), 35 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 64931df5a8..05a43444d4 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -1117,8 +1117,6 @@ impl LogicalPlanBuilder {
             .collect::<Result<_>>()?;
 
         let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys).collect();
-        let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
         let mut join_on: Vec<(Expr, Expr)> = vec![];
         let mut filters: Option<Expr> = None;
         for (l, r) in &on {
@@ -1151,33 +1149,33 @@ impl LogicalPlanBuilder {
                 DataFusionError::Internal("filters should not be None here".to_string())
             })?)
         } else {
-            Ok(Self::new(LogicalPlan::Join(Join {
-                left: self.plan,
-                right: Arc::new(right),
-                on: join_on,
-                filter: filters,
+            let join = Join::try_new(
+                self.plan,
+                Arc::new(right),
+                join_on,
+                filters,
                 join_type,
-                join_constraint: JoinConstraint::Using,
-                schema: DFSchemaRef::new(join_schema),
-                null_equals_null: false,
-            })))
+                JoinConstraint::Using,
+                false,
+            )?;
+
+            Ok(Self::new(LogicalPlan::Join(join)))
         }
     }
 
     /// Apply a cross join
     pub fn cross_join(self, right: LogicalPlan) -> Result<Self> {
-        let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?;
-        Ok(Self::new(LogicalPlan::Join(Join {
-            left: self.plan,
-            right: Arc::new(right),
-            on: vec![],
-            filter: None,
-            join_type: JoinType::Inner,
-            join_constraint: JoinConstraint::On,
-            null_equals_null: false,
-            schema: DFSchemaRef::new(join_schema),
-        })))
+        let join = Join::try_new(
+            self.plan,
+            Arc::new(right),
+            vec![],
+            None,
+            JoinType::Inner,
+            JoinConstraint::On,
+            false,
+        )?;
+
+        Ok(Self::new(LogicalPlan::Join(join)))
     }
 
     /// Repartition
@@ -1338,7 +1336,7 @@ impl LogicalPlanBuilder {
     /// to columns from the existing input. `r`, the second element of the tuple,
     /// must only refer to columns from the right input.
     ///
-    /// `filter` contains any other other filter expression to apply during the
+    /// `filter` contains any other filter expression to apply during the
     /// join. Note that `equi_exprs` predicates are evaluated more efficiently
     /// than the filter expressions, so they are preferred.
     pub fn join_with_expr_keys(
@@ -1388,19 +1386,17 @@ impl LogicalPlanBuilder {
             })
             .collect::<Result<Vec<_>>>()?;
 
-        let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
-
-        Ok(Self::new(LogicalPlan::Join(Join {
-            left: self.plan,
-            right: Arc::new(right),
-            on: join_key_pairs,
+        let join = Join::try_new(
+            self.plan,
+            Arc::new(right),
+            join_key_pairs,
             filter,
             join_type,
-            join_constraint: JoinConstraint::On,
-            schema: DFSchemaRef::new(join_schema),
-            null_equals_null: false,
-        })))
+            JoinConstraint::On,
+            false,
+        )?;
+
+        Ok(Self::new(LogicalPlan::Join(join)))
     }
 
     /// Unnest the given column.
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 76b45d5d72..edf5f1126b 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -3709,6 +3709,47 @@ pub struct Join {
 }
 
 impl Join {
+    /// Creates a new Join operator with automatically computed schema.
+    ///
+    /// This constructor computes the schema based on the join type and inputs,
+    /// removing the need to manually specify the schema or call `recompute_schema`.
+    ///
+    /// # Arguments
+    ///
+    /// * `left` - Left input plan
+    /// * `right` - Right input plan
+    /// * `on` - Join condition as a vector of (left_expr, right_expr) pairs
+    /// * `filter` - Optional filter expression (for non-equijoin conditions)
+    /// * `join_type` - Type of join (Inner, Left, Right, etc.)
+    /// * `join_constraint` - Join constraint (On, Using)
+    /// * `null_equals_null` - Whether NULL = NULL in join comparisons
+    ///
+    /// # Returns
+    ///
+    /// A new Join operator with the computed schema
+    pub fn try_new(
+        left: Arc<LogicalPlan>,
+        right: Arc<LogicalPlan>,
+        on: Vec<(Expr, Expr)>,
+        filter: Option<Expr>,
+        join_type: JoinType,
+        join_constraint: JoinConstraint,
+        null_equals_null: bool,
+    ) -> Result<Self> {
+        let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?;
+
+        Ok(Join {
+            left,
+            right,
+            on,
+            filter,
+            join_type,
+            join_constraint,
+            schema: Arc::new(join_schema),
+            null_equals_null,
+        })
+    }
+
     /// Create Join with input which wrapped with projection, this method is used to help create physical join.
     pub fn try_new_with_project_input(
         original: &LogicalPlan,
@@ -4916,4 +4957,379 @@ digraph {
 
         Ok(())
     }
+
+    #[test]
+    fn test_join_try_new() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+
+        let left_scan = table_scan(Some("t1"), &schema, None)?.build()?;
+
+        let right_scan = table_scan(Some("t2"), &schema, None)?.build()?;
+
+        let join_types = vec![
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftSemi,
+            JoinType::LeftAnti,
+            JoinType::RightSemi,
+            JoinType::RightAnti,
+            JoinType::LeftMark,
+        ];
+
+        for join_type in join_types {
+            let join = Join::try_new(
+                Arc::new(left_scan.clone()),
+                Arc::new(right_scan.clone()),
+                vec![(col("t1.a"), col("t2.a"))],
+                Some(col("t1.b").gt(col("t2.b"))),
+                join_type,
+                JoinConstraint::On,
+                false,
+            )?;
+
+            match join_type {
+                JoinType::LeftSemi | JoinType::LeftAnti => {
+                    assert_eq!(join.schema.fields().len(), 2);
+
+                    let fields = join.schema.fields();
+                    assert_eq!(
+                        fields[0].name(),
+                        "a",
+                        "First field should be 'a' from left table"
+                    );
+                    assert_eq!(
+                        fields[1].name(),
+                        "b",
+                        "Second field should be 'b' from left table"
+                    );
+                }
+                JoinType::RightSemi | JoinType::RightAnti => {
+                    assert_eq!(join.schema.fields().len(), 2);
+
+                    let fields = join.schema.fields();
+                    assert_eq!(
+                        fields[0].name(),
+                        "a",
+                        "First field should be 'a' from right table"
+                    );
+                    assert_eq!(
+                        fields[1].name(),
+                        "b",
+                        "Second field should be 'b' from right table"
+                    );
+                }
+                JoinType::LeftMark => {
+                    assert_eq!(join.schema.fields().len(), 3);
+
+                    let fields = join.schema.fields();
+                    assert_eq!(
+                        fields[0].name(),
+                        "a",
+                        "First field should be 'a' from left table"
+                    );
+                    assert_eq!(
+                        fields[1].name(),
+                        "b",
+                        "Second field should be 'b' from left table"
+                    );
+                    assert_eq!(
+                        fields[2].name(),
+                        "mark",
+                        "Third field should be the mark column"
+                    );
+
+                    assert!(!fields[0].is_nullable());
+                    assert!(!fields[1].is_nullable());
+                    assert!(!fields[2].is_nullable());
+                }
+                _ => {
+                    assert_eq!(join.schema.fields().len(), 4);
+
+                    let fields = join.schema.fields();
+                    assert_eq!(
+                        fields[0].name(),
+                        "a",
+                        "First field should be 'a' from left table"
+                    );
+                    assert_eq!(
+                        fields[1].name(),
+                        "b",
+                        "Second field should be 'b' from left table"
+                    );
+                    assert_eq!(
+                        fields[2].name(),
+                        "a",
+                        "Third field should be 'a' from right table"
+                    );
+                    assert_eq!(
+                        fields[3].name(),
+                        "b",
+                        "Fourth field should be 'b' from right table"
+                    );
+
+                    if join_type == JoinType::Left {
+                        // Left side fields (first two) shouldn't be nullable
+                        assert!(!fields[0].is_nullable());
+                        assert!(!fields[1].is_nullable());
+                        // Right side fields (third and fourth) should be nullable
+                        assert!(fields[2].is_nullable());
+                        assert!(fields[3].is_nullable());
+                    } else if join_type == JoinType::Right {
+                        // Left side fields (first two) should be nullable
+                        assert!(fields[0].is_nullable());
+                        assert!(fields[1].is_nullable());
+                        // Right side fields (third and fourth) shouldn't be nullable
+                        assert!(!fields[2].is_nullable());
+                        assert!(!fields[3].is_nullable());
+                    } else if join_type == JoinType::Full {
+                        assert!(fields[0].is_nullable());
+                        assert!(fields[1].is_nullable());
+                        assert!(fields[2].is_nullable());
+                        assert!(fields[3].is_nullable());
+                    }
+                }
+            }
+
+            assert_eq!(join.on, vec![(col("t1.a"), col("t2.a"))]);
+            assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
+            assert_eq!(join.join_type, join_type);
+            assert_eq!(join.join_constraint, JoinConstraint::On);
+            assert!(!join.null_equals_null);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_join_try_new_with_using_constraint_and_overlapping_columns() -> Result<()> {
+        let left_schema = Schema::new(vec![
+            Field::new("id", DataType::Int32, false), // Common column in both tables
+            Field::new("name", DataType::Utf8, false), // Unique to left
+            Field::new("value", DataType::Int32, false), // Common column, different meaning
+        ]);
+
+        let right_schema = Schema::new(vec![
+            Field::new("id", DataType::Int32, false), // Common column in both tables
+            Field::new("category", DataType::Utf8, false), // Unique to right
+            Field::new("value", DataType::Float64, true), // Common column, different meaning
+        ]);
+
+        let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?;
+
+        let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?;
+
+        // Test 1: USING constraint with a common column
+        {
+            // In the logical plan, both copies of the `id` column are preserved
+            // The USING constraint is handled later during physical execution, where the common column appears once
+            let join = Join::try_new(
+                Arc::new(left_plan.clone()),
+                Arc::new(right_plan.clone()),
+                vec![(col("t1.id"), col("t2.id"))],
+                None,
+                JoinType::Inner,
+                JoinConstraint::Using,
+                false,
+            )?;
+
+            let fields = join.schema.fields();
+
+            assert_eq!(fields.len(), 6);
+
+            assert_eq!(
+                fields[0].name(),
+                "id",
+                "First field should be 'id' from left table"
+            );
+            assert_eq!(
+                fields[1].name(),
+                "name",
+                "Second field should be 'name' from left table"
+            );
+            assert_eq!(
+                fields[2].name(),
+                "value",
+                "Third field should be 'value' from left table"
+            );
+            assert_eq!(
+                fields[3].name(),
+                "id",
+                "Fourth field should be 'id' from right table"
+            );
+            assert_eq!(
+                fields[4].name(),
+                "category",
+                "Fifth field should be 'category' from right table"
+            );
+            assert_eq!(
+                fields[5].name(),
+                "value",
+                "Sixth field should be 'value' from right table"
+            );
+
+            assert_eq!(join.join_constraint, JoinConstraint::Using);
+        }
+
+        // Test 2: Complex join condition with expressions
+        {
+            // Complex condition: join on id equality AND where left.value < right.value
+            let join = Join::try_new(
+                Arc::new(left_plan.clone()),
+                Arc::new(right_plan.clone()),
+                vec![(col("t1.id"), col("t2.id"))], // Equijoin condition
+                Some(col("t1.value").lt(col("t2.value"))), // Non-equi filter condition
+                JoinType::Inner,
+                JoinConstraint::On,
+                false,
+            )?;
+
+            let fields = join.schema.fields();
+            assert_eq!(fields.len(), 6);
+
+            assert_eq!(
+                fields[0].name(),
+                "id",
+                "First field should be 'id' from left table"
+            );
+            assert_eq!(
+                fields[1].name(),
+                "name",
+                "Second field should be 'name' from left table"
+            );
+            assert_eq!(
+                fields[2].name(),
+                "value",
+                "Third field should be 'value' from left table"
+            );
+            assert_eq!(
+                fields[3].name(),
+                "id",
+                "Fourth field should be 'id' from right table"
+            );
+            assert_eq!(
+                fields[4].name(),
+                "category",
+                "Fifth field should be 'category' from right table"
+            );
+            assert_eq!(
+                fields[5].name(),
+                "value",
+                "Sixth field should be 'value' from right table"
+            );
+
+            assert_eq!(join.filter, Some(col("t1.value").lt(col("t2.value"))));
+        }
+
+        // Test 3: Join with null equality behavior set to true
+        {
+            let join = Join::try_new(
+                Arc::new(left_plan.clone()),
+                Arc::new(right_plan.clone()),
+                vec![(col("t1.id"), col("t2.id"))],
+                None,
+                JoinType::Inner,
+                JoinConstraint::On,
+                true,
+            )?;
+
+            assert!(join.null_equals_null);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_join_try_new_schema_validation() -> Result<()> {
+        let left_schema = Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+            Field::new("value", DataType::Float64, true),
+        ]);
+
+        let right_schema = Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("category", DataType::Utf8, true),
+            Field::new("code", DataType::Int16, false),
+        ]);
+
+        let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?;
+
+        let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?;
+
+        let join_types = vec![
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+        ];
+
+        for join_type in join_types {
+            let join = Join::try_new(
+                Arc::new(left_plan.clone()),
+                Arc::new(right_plan.clone()),
+                vec![(col("t1.id"), col("t2.id"))],
+                Some(col("t1.value").gt(lit(5.0))),
+                join_type,
+                JoinConstraint::On,
+                false,
+            )?;
+
+            let fields = join.schema.fields();
+            assert_eq!(
+                fields.len(),
+                6,
+                "Expected 6 fields for {:?} join",
+                join_type
+            );
+
+            for (i, field) in fields.iter().enumerate() {
+                let expected_nullable = match (i, &join_type) {
+                    // Left table fields (indices 0, 1, 2)
+                    (0, JoinType::Right | JoinType::Full) => true, // id becomes nullable in RIGHT/FULL
+                    (1, JoinType::Right | JoinType::Full) => true, // name becomes nullable in RIGHT/FULL
+                    (2, _) => true, // value is already nullable
+
+                    // Right table fields (indices 3, 4, 5)
+                    (3, JoinType::Left | JoinType::Full) => true, // id becomes nullable in LEFT/FULL
+                    (4, _) => true, // category is already nullable
+                    (5, JoinType::Left | JoinType::Full) => true, // code becomes nullable in LEFT/FULL
+
+                    _ => false,
+                };
+
+                assert_eq!(
+                    field.is_nullable(),
+                    expected_nullable,
+                    "Field {} ({}) nullability incorrect for {:?} join",
+                    i,
+                    field.name(),
+                    join_type
+                );
+            }
+        }
+
+        let using_join = Join::try_new(
+            Arc::new(left_plan.clone()),
+            Arc::new(right_plan.clone()),
+            vec![(col("t1.id"), col("t2.id"))],
+            None,
+            JoinType::Inner,
+            JoinConstraint::Using,
+            false,
+        )?;
+
+        assert_eq!(
+            using_join.schema.fields().len(),
+            6,
+            "USING join should have all fields"
+        );
+        assert_eq!(using_join.join_constraint, JoinConstraint::Using);
+
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to