This is an automated email from the ASF dual-hosted git repository.

berkay pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a08dc0af8f Move join type input swapping to pub methods on Joins (#13910)
a08dc0af8f is described below

commit a08dc0af8f798afe73a2fdf170ab36e72ad7e782
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Dec 27 06:53:08 2024 -0500

    Move join type input swapping to pub methods on Joins (#13910)
---
 datafusion/common/src/join_type.rs                 |  34 +++
 .../core/src/physical_optimizer/join_selection.rs  | 263 ++++-----------------
 datafusion/physical-plan/src/joins/cross_join.rs   |  17 +-
 datafusion/physical-plan/src/joins/hash_join.rs    |  81 ++++++-
 datafusion/physical-plan/src/joins/join_filter.rs  | 100 ++++++++
 datafusion/physical-plan/src/joins/mod.rs          |   8 +-
 .../physical-plan/src/joins/nested_loop_join.rs    |  39 ++-
 datafusion/physical-plan/src/joins/utils.rs        | 142 ++++++-----
 8 files changed, 397 insertions(+), 287 deletions(-)
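
The free functions in join_selection.rs (supports_swap, swap_join_type,
swap_hash_join, swap_nl_join, swap_filter) are deprecated in favour of
methods on JoinType and on the join operators themselves. A minimal
migration sketch for a downstream caller, assuming only that the
datafusion_common crate is a dependency:

    use datafusion_common::JoinType;

    // Before: supports_swap(join_type) and swap_join_type(join_type).
    // After: the equivalent methods on JoinType itself.
    fn pick_swapped_type(join_type: JoinType) -> Option<JoinType> {
        if join_type.supports_swap() {
            Some(join_type.swap())
        } else {
            None
        }
    }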

diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs
index bdca253c5f..ac81d977b7 100644
--- a/datafusion/common/src/join_type.rs
+++ b/datafusion/common/src/join_type.rs
@@ -73,6 +73,40 @@ impl JoinType {
     pub fn is_outer(self) -> bool {
         self == JoinType::Left || self == JoinType::Right || self == JoinType::Full
     }
+
+    /// Returns the `JoinType` obtained when the two join inputs are swapped
+    ///
+    /// Panics if [`Self::supports_swap`] returns false
+    pub fn swap(&self) -> JoinType {
+        match self {
+            JoinType::Inner => JoinType::Inner,
+            JoinType::Full => JoinType::Full,
+            JoinType::Left => JoinType::Right,
+            JoinType::Right => JoinType::Left,
+            JoinType::LeftSemi => JoinType::RightSemi,
+            JoinType::RightSemi => JoinType::LeftSemi,
+            JoinType::LeftAnti => JoinType::RightAnti,
+            JoinType::RightAnti => JoinType::LeftAnti,
+            JoinType::LeftMark => {
+                unreachable!("LeftMark join type does not support swapping")
+            }
+        }
+    }
+
+    /// Does the join type support swapping inputs?
+    pub fn supports_swap(&self) -> bool {
+        matches!(
+            self,
+            JoinType::Inner
+                | JoinType::Left
+                | JoinType::Right
+                | JoinType::Full
+                | JoinType::LeftSemi
+                | JoinType::RightSemi
+                | JoinType::LeftAnti
+                | JoinType::RightAnti
+        )
+    }
 }
 
 impl Display for JoinType {
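
For illustration, the mapping above behaves as follows when exercised from
a downstream test (a sketch; it assumes JoinType's derived PartialEq/Debug
impls in datafusion_common, so assert_eq! works on it):

    use datafusion_common::JoinType;

    fn join_type_swap_examples() {
        // Left/right variants flip sides; Inner and Full are symmetric.
        assert_eq!(JoinType::Left.swap(), JoinType::Right);
        assert_eq!(JoinType::LeftSemi.swap(), JoinType::RightSemi);
        assert_eq!(JoinType::Inner.swap(), JoinType::Inner);
        // LeftMark cannot be swapped; swap() panics for it, so check first.
        assert!(!JoinType::LeftMark.supports_swap());
    }
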
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs
index 29c6e00788..d7a2f17401 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -32,15 +32,12 @@ use crate::physical_plan::joins::{
     CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
     StreamJoinPartitionMode, SymmetricHashJoinExec,
 };
-use crate::physical_plan::projection::ProjectionExec;
 use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 
-use arrow_schema::Schema;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{internal_err, JoinSide, JoinType};
 use datafusion_expr::sort_properties::SortProperties;
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_plan::execution_plan::EmissionType;
@@ -108,197 +105,49 @@ fn supports_collect_by_thresholds(
 }
 
 /// Predicate that checks whether the given join type supports input swapping.
+#[deprecated(since = "45.0.0", note = "use JoinType::supports_swap instead")]
+#[allow(dead_code)]
 pub(crate) fn supports_swap(join_type: JoinType) -> bool {
-    matches!(
-        join_type,
-        JoinType::Inner
-            | JoinType::Left
-            | JoinType::Right
-            | JoinType::Full
-            | JoinType::LeftSemi
-            | JoinType::RightSemi
-            | JoinType::LeftAnti
-            | JoinType::RightAnti
-    )
+    join_type.supports_swap()
 }
 
 /// This function returns the new join type we get after swapping the given
 /// join's inputs.
+#[deprecated(since = "45.0.0", note = "use JoinType::swap instead")]
+#[allow(dead_code)]
 pub(crate) fn swap_join_type(join_type: JoinType) -> JoinType {
-    match join_type {
-        JoinType::Inner => JoinType::Inner,
-        JoinType::Full => JoinType::Full,
-        JoinType::Left => JoinType::Right,
-        JoinType::Right => JoinType::Left,
-        JoinType::LeftSemi => JoinType::RightSemi,
-        JoinType::RightSemi => JoinType::LeftSemi,
-        JoinType::LeftAnti => JoinType::RightAnti,
-        JoinType::RightAnti => JoinType::LeftAnti,
-        JoinType::LeftMark => {
-            unreachable!("LeftMark join type does not support swapping")
-        }
-    }
-}
-
-/// This function swaps the given join's projection.
-fn swap_join_projection(
-    left_schema_len: usize,
-    right_schema_len: usize,
-    projection: Option<&Vec<usize>>,
-    join_type: &JoinType,
-) -> Option<Vec<usize>> {
-    match join_type {
-        // For Anti/Semi join types, projection should remain unmodified,
-        // since these joins output schema remains the same after swap
-        JoinType::LeftAnti
-        | JoinType::LeftSemi
-        | JoinType::RightAnti
-        | JoinType::RightSemi => projection.cloned(),
-
-        _ => projection.map(|p| {
-            p.iter()
-                .map(|i| {
-                    // If the index is less than the left schema length, it is from
-                    // the left schema, so we add the right schema length to it.
-                    // Otherwise, it is from the right schema, so we subtract the left
-                    // schema length from it.
-                    if *i < left_schema_len {
-                        *i + right_schema_len
-                    } else {
-                        *i - left_schema_len
-                    }
-                })
-                .collect()
-        }),
-    }
+    join_type.swap()
 }
 
 /// This function swaps the inputs of the given join operator.
 /// This function is public so other downstream projects can use it
 /// to construct `HashJoinExec` with right side as the build side.
+#[deprecated(since = "45.0.0", note = "use HashJoinExec::swap_inputs instead")]
 pub fn swap_hash_join(
     hash_join: &HashJoinExec,
     partition_mode: PartitionMode,
 ) -> Result<Arc<dyn ExecutionPlan>> {
-    let left = hash_join.left();
-    let right = hash_join.right();
-    let new_join = HashJoinExec::try_new(
-        Arc::clone(right),
-        Arc::clone(left),
-        hash_join
-            .on()
-            .iter()
-            .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
-            .collect(),
-        swap_join_filter(hash_join.filter()),
-        &swap_join_type(*hash_join.join_type()),
-        swap_join_projection(
-            left.schema().fields().len(),
-            right.schema().fields().len(),
-            hash_join.projection.as_ref(),
-            hash_join.join_type(),
-        ),
-        partition_mode,
-        hash_join.null_equals_null(),
-    )?;
-    // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again
-    if matches!(
-        hash_join.join_type(),
-        JoinType::LeftSemi
-            | JoinType::RightSemi
-            | JoinType::LeftAnti
-            | JoinType::RightAnti
-    ) || hash_join.projection.is_some()
-    {
-        Ok(Arc::new(new_join))
-    } else {
-        // TODO avoid adding ProjectionExec again and again, only adding Final Projection
-        let proj = ProjectionExec::try_new(
-            swap_reverting_projection(&left.schema(), &right.schema()),
-            Arc::new(new_join),
-        )?;
-        Ok(Arc::new(proj))
-    }
+    hash_join.swap_inputs(partition_mode)
 }
 
 /// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` if required
+#[deprecated(since = "45.0.0", note = "use NestedLoopJoinExec::swap_inputs")]
+#[allow(dead_code)]
 pub(crate) fn swap_nl_join(join: &NestedLoopJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
-    let new_filter = swap_join_filter(join.filter());
-    let new_join_type = &swap_join_type(*join.join_type());
-
-    let new_join = NestedLoopJoinExec::try_new(
-        Arc::clone(join.right()),
-        Arc::clone(join.left()),
-        new_filter,
-        new_join_type,
-    )?;
-
-    // For Semi/Anti joins, swap result will produce same output schema,
-    // no need to wrap them into additional projection
-    let plan: Arc<dyn ExecutionPlan> = if matches!(
-        join.join_type(),
-        JoinType::LeftSemi
-            | JoinType::RightSemi
-            | JoinType::LeftAnti
-            | JoinType::RightAnti
-    ) {
-        Arc::new(new_join)
-    } else {
-        let projection =
-            swap_reverting_projection(&join.left().schema(), &join.right().schema());
-
-        Arc::new(ProjectionExec::try_new(projection, Arc::new(new_join))?)
-    };
-
-    Ok(plan)
+    join.swap_inputs()
 }
 
-/// When the order of the join is changed by the optimizer, the columns in
-/// the output should not be impacted. This function creates the expressions
-/// that will allow to swap back the values from the original left as the first
-/// columns and those on the right next.
-pub(crate) fn swap_reverting_projection(
-    left_schema: &Schema,
-    right_schema: &Schema,
-) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
-    let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| {
-        (
-            Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
-            f.name().to_owned(),
-        )
-    });
-    let right_len = right_cols.len();
-    let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| {
-        (
-            Arc::new(Column::new(f.name(), right_len + i)) as Arc<dyn PhysicalExpr>,
-            f.name().to_owned(),
-        )
-    });
-
-    left_cols.chain(right_cols).collect()
+/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists).
+#[deprecated(since = "45.0.0", note = "use filter.map(JoinFilter::swap) instead")]
+#[allow(dead_code)]
+fn swap_join_filter(filter: Option<&JoinFilter>) -> Option<JoinFilter> {
+    filter.map(JoinFilter::swap)
 }
 
-/// Swaps join sides for filter column indices and produces new JoinFilter
+#[deprecated(since = "45.0.0", note = "use JoinFilter::swap instead")]
+#[allow(dead_code)]
 pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter {
-    let column_indices = filter
-        .column_indices()
-        .iter()
-        .map(|idx| ColumnIndex {
-            index: idx.index,
-            side: idx.side.negate(),
-        })
-        .collect();
-
-    JoinFilter::new(
-        Arc::clone(filter.expression()),
-        column_indices,
-        filter.schema().clone(),
-    )
-}
-
-/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists).
-fn swap_join_filter(filter: Option<&JoinFilter>) -> Option<JoinFilter> {
-    filter.map(swap_filter)
+    filter.swap()
 }
 
 impl PhysicalOptimizerRule for JoinSelection {
@@ -383,10 +232,10 @@ pub(crate) fn try_collect_left(
 
     match (left_can_collect, right_can_collect) {
         (true, true) => {
-            if supports_swap(*hash_join.join_type())
+            if hash_join.join_type().supports_swap()
                 && should_swap_join_order(&**left, &**right)?
             {
-                Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?))
+                Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?))
             } else {
                 Ok(Some(Arc::new(HashJoinExec::try_new(
                     Arc::clone(left),
@@ -411,8 +260,8 @@ pub(crate) fn try_collect_left(
             hash_join.null_equals_null(),
         )?))),
         (false, true) => {
-            if supports_swap(*hash_join.join_type()) {
-                swap_hash_join(hash_join, PartitionMode::CollectLeft).map(Some)
+            if hash_join.join_type().supports_swap() {
+                hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some)
             } else {
                 Ok(None)
             }
@@ -431,9 +280,9 @@ pub(crate) fn partitioned_hash_join(
 ) -> Result<Arc<dyn ExecutionPlan>> {
     let left = hash_join.left();
     let right = hash_join.right();
-    if supports_swap(*hash_join.join_type()) && should_swap_join_order(&**left, &**right)?
+    if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)?
     {
-        swap_hash_join(hash_join, PartitionMode::Partitioned)
+        hash_join.swap_inputs(PartitionMode::Partitioned)
     } else {
         Ok(Arc::new(HashJoinExec::try_new(
             Arc::clone(left),
@@ -476,10 +325,12 @@ fn statistical_join_selection_subrule(
                 PartitionMode::Partitioned => {
                     let left = hash_join.left();
                     let right = hash_join.right();
-                    if supports_swap(*hash_join.join_type())
+                    if hash_join.join_type().supports_swap()
                         && should_swap_join_order(&**left, &**right)?
                     {
-                        swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)?
+                        hash_join
+                            .swap_inputs(PartitionMode::Partitioned)
+                            .map(Some)?
                     } else {
                         None
                     }
@@ -489,23 +340,17 @@ fn statistical_join_selection_subrule(
             let left = cross_join.left();
             let right = cross_join.right();
             if should_swap_join_order(&**left, &**right)? {
-                let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left));
-                // TODO avoid adding ProjectionExec again and again, only adding Final Projection
-                let proj: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
-                    swap_reverting_projection(&left.schema(), &right.schema()),
-                    Arc::new(new_join),
-                )?);
-                Some(proj)
+                cross_join.swap_inputs().map(Some)?
             } else {
                 None
             }
         } else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
             let left = nl_join.left();
             let right = nl_join.right();
-            if supports_swap(*nl_join.join_type())
+            if nl_join.join_type().supports_swap()
                 && should_swap_join_order(&**left, &**right)?
             {
-                swap_nl_join(nl_join).map(Some)?
+                nl_join.swap_inputs().map(Some)?
             } else {
                 None
             }
@@ -718,10 +563,10 @@ fn swap_join_according_to_unboundedness(
            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full,
        ) => internal_err!("{join_type} join cannot be swapped for unbounded input."),
         (PartitionMode::Partitioned, _) => {
-            swap_hash_join(hash_join, PartitionMode::Partitioned)
+            hash_join.swap_inputs(PartitionMode::Partitioned)
         }
         (PartitionMode::CollectLeft, _) => {
-            swap_hash_join(hash_join, PartitionMode::CollectLeft)
+            hash_join.swap_inputs(PartitionMode::CollectLeft)
         }
         (PartitionMode::Auto, _) => {
             internal_err!("Auto is not acceptable for unbounded input here.")
@@ -751,12 +596,15 @@ mod tests_statistical {
     };
 
     use arrow::datatypes::{DataType, Field};
+    use arrow_schema::Schema;
     use datafusion_common::{stats::Precision, JoinType, ScalarValue};
     use datafusion_expr::Operator;
     use datafusion_physical_expr::expressions::col;
     use datafusion_physical_expr::expressions::BinaryExpr;
     use datafusion_physical_expr::PhysicalExprRef;
 
+    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+    use datafusion_physical_plan::projection::ProjectionExec;
     use rstest::rstest;
 
     /// Return statistics for empty table
@@ -1372,7 +1220,8 @@ mod tests_statistical {
             false,
         )?);
 
-        let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned)
+        let swapped = join
+            .swap_inputs(PartitionMode::Partitioned)
             .expect("swap_hash_join must support joins with projections");
         let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
             "ProjectionExec won't be added above if HashJoinExec contains embedded projection",
@@ -1384,32 +1233,6 @@ mod tests_statistical {
         Ok(())
     }
 
-    #[tokio::test]
-    async fn test_swap_reverting_projection() {
-        let left_schema = Schema::new(vec![
-            Field::new("a", DataType::Int32, false),
-            Field::new("b", DataType::Int32, false),
-        ]);
-
-        let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);
-
-        let proj = swap_reverting_projection(&left_schema, &right_schema);
-
-        assert_eq!(proj.len(), 3);
-
-        let (col, name) = &proj[0];
-        assert_eq!(name, "a");
-        assert_col_expr(col, "a", 1);
-
-        let (col, name) = &proj[1];
-        assert_eq!(name, "b");
-        assert_col_expr(col, "b", 2);
-
-        let (col, name) = &proj[2];
-        assert_eq!(name, "c");
-        assert_col_expr(col, "c", 0);
-    }
-
     fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
         let col = expr
             .as_any()
@@ -1643,7 +1466,9 @@ mod hash_join_tests {
 
     use arrow::datatypes::{DataType, Field};
     use arrow::record_batch::RecordBatch;
+    use arrow_schema::Schema;
     use datafusion_physical_expr::expressions::col;
+    use datafusion_physical_plan::projection::ProjectionExec;
 
     struct TestCase {
         case: String,
@@ -1723,7 +1548,7 @@ mod hash_join_tests {
                 initial_join_type: join_type,
                 initial_mode: PartitionMode::CollectLeft,
                 expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
-                expected_join_type: swap_join_type(join_type),
+                expected_join_type: join_type.swap(),
                 expected_mode: PartitionMode::CollectLeft,
                 expecting_swap: true,
             });
@@ -1766,7 +1591,7 @@ mod hash_join_tests {
                 initial_join_type: join_type,
                 initial_mode: PartitionMode::Partitioned,
                 expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
-                expected_join_type: swap_join_type(join_type),
+                expected_join_type: join_type.swap(),
                 expected_mode: PartitionMode::Partitioned,
                 expecting_swap: true,
             });
@@ -1824,7 +1649,7 @@ mod hash_join_tests {
                 initial_join_type: join_type,
                 initial_mode: PartitionMode::Partitioned,
                 expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded),
-                expected_join_type: swap_join_type(join_type),
+                expected_join_type: join_type.swap(),
                 expected_mode: PartitionMode::Partitioned,
                 expecting_swap: true,
             });
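
The optimizer rule now calls the operator methods directly; from a caller's
point of view the pattern looks roughly like the sketch below, where
prefer_right_build is a hypothetical stand-in for the crate-private
should_swap_join_order() statistics check:

    use std::sync::Arc;
    use datafusion_common::Result;
    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
    use datafusion_physical_plan::ExecutionPlan;

    fn maybe_swap_build_side(
        hash_join: &HashJoinExec,
        prefer_right_build: bool,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        if hash_join.join_type().supports_swap() && prefer_right_build {
            // swap_inputs() re-adds a projection when needed, so the output
            // column order is unchanged.
            hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some)
        } else {
            Ok(None)
        }
    }
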
diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs
index b70eeb313b..69300fce77 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -19,8 +19,8 @@
 //! and producing batches in parallel for the right partitions
 
 use super::utils::{
-    adjust_right_output_partitioning, BatchSplitter, BatchTransformer,
-    BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
+    adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter,
+    BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
     StatefulStreamResult,
 };
 use crate::coalesce_partitions::CoalescePartitionsExec;
@@ -168,6 +168,19 @@ impl CrossJoinExec {
             boundedness_from_children([left, right]),
         )
     }
+
+    /// Returns a new `ExecutionPlan` that computes the same join as this one,
+    /// with the left and right inputs swapped.
+    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
+        let new_join =
+            CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
+        reorder_output_after_swap(
+            Arc::new(new_join),
+            &self.left.schema(),
+            &self.right.schema(),
+        )
+    }
 }
 
 /// Asynchronously collect the result of the left child
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index dabe42ee43..a0fe0bd116 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -24,7 +24,7 @@ use std::sync::Arc;
 use std::task::Poll;
 use std::{any::Any, vec};
 
-use super::utils::asymmetric_join_output_partitioning;
+use super::utils::{asymmetric_join_output_partitioning, reorder_output_after_swap};
 use super::{
     utils::{OnceAsync, OnceFut},
     PartitionMode,
@@ -566,8 +566,87 @@ impl HashJoinExec {
             boundedness_from_children([left, right]),
         ))
     }
+
+    /// Returns a new `ExecutionPlan` that computes the same join as this one,
+    /// with the left and right inputs swapped using the specified
+    /// `partition_mode`.
+    ///
+    /// # Notes:
+    ///
+    /// This function is public so other downstream projects can use it to
+    /// construct `HashJoinExec` with right side as the build side.
+    pub fn swap_inputs(
+        &self,
+        partition_mode: PartitionMode,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let left = self.left();
+        let right = self.right();
+        let new_join = HashJoinExec::try_new(
+            Arc::clone(right),
+            Arc::clone(left),
+            self.on()
+                .iter()
+                .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
+                .collect(),
+            self.filter().map(JoinFilter::swap),
+            &self.join_type().swap(),
+            swap_join_projection(
+                left.schema().fields().len(),
+                right.schema().fields().len(),
+                self.projection.as_ref(),
+                self.join_type(),
+            ),
+            partition_mode,
+            self.null_equals_null(),
+        )?;
+        // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again
+        if matches!(
+            self.join_type(),
+            JoinType::LeftSemi
+                | JoinType::RightSemi
+                | JoinType::LeftAnti
+                | JoinType::RightAnti
+        ) || self.projection.is_some()
+        {
+            Ok(Arc::new(new_join))
+        } else {
+            reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
+        }
+    }
 }
 
+/// This function swaps the given join's projection.
+fn swap_join_projection(
+    left_schema_len: usize,
+    right_schema_len: usize,
+    projection: Option<&Vec<usize>>,
+    join_type: &JoinType,
+) -> Option<Vec<usize>> {
+    match join_type {
+        // For Anti/Semi join types, projection should remain unmodified,
+        // since these joins output schema remains the same after swap
+        JoinType::LeftAnti
+        | JoinType::LeftSemi
+        | JoinType::RightAnti
+        | JoinType::RightSemi => projection.cloned(),
+
+        _ => projection.map(|p| {
+            p.iter()
+                .map(|i| {
+                    // If the index is less than the left schema length, it is from
+                    // the left schema, so we add the right schema length to it.
+                    // Otherwise, it is from the right schema, so we subtract the left
+                    // schema length from it.
+                    if *i < left_schema_len {
+                        *i + right_schema_len
+                    } else {
+                        *i - left_schema_len
+                    }
+                })
+                .collect()
+        }),
+    }
+}
 impl DisplayAs for HashJoinExec {
     fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
         match t {
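
The index arithmetic in swap_join_projection can be checked with a small
standalone helper (hypothetical, shown only to illustrate the remapping):

    // With a 3-column left input and a 2-column right input, a projected
    // left index 1 maps to 1 + 2 = 3 after the swap, and a projected right
    // index 4 maps to 4 - 3 = 1.
    fn remap_after_swap(i: usize, left_len: usize, right_len: usize) -> usize {
        if i < left_len {
            i + right_len
        } else {
            i - left_len
        }
    }

    fn remap_examples() {
        assert_eq!(remap_after_swap(1, 3, 2), 3);
        assert_eq!(remap_after_swap(4, 3, 2), 1);
    }
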
diff --git a/datafusion/physical-plan/src/joins/join_filter.rs b/datafusion/physical-plan/src/joins/join_filter.rs
new file mode 100644
index 0000000000..b99afd87c9
--- /dev/null
+++ b/datafusion/physical-plan/src/joins/join_filter.rs
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::joins::utils::ColumnIndex;
+use arrow_schema::Schema;
+use datafusion_common::JoinSide;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::sync::Arc;
+
+/// Filter applied before join output. Fields are crate-public to allow
+/// downstream implementations to experiment with custom joins.
+#[derive(Debug, Clone)]
+pub struct JoinFilter {
+    /// Filter expression
+    pub(crate) expression: Arc<dyn PhysicalExpr>,
+    /// Column indices required to construct intermediate batch for filtering
+    pub(crate) column_indices: Vec<ColumnIndex>,
+    /// Physical schema of intermediate batch
+    pub(crate) schema: Schema,
+}
+
+impl JoinFilter {
+    /// Creates new JoinFilter
+    pub fn new(
+        expression: Arc<dyn PhysicalExpr>,
+        column_indices: Vec<ColumnIndex>,
+        schema: Schema,
+    ) -> JoinFilter {
+        JoinFilter {
+            expression,
+            column_indices,
+            schema,
+        }
+    }
+
+    /// Helper for building ColumnIndex vector from left and right indices
+    pub fn build_column_indices(
+        left_indices: Vec<usize>,
+        right_indices: Vec<usize>,
+    ) -> Vec<ColumnIndex> {
+        left_indices
+            .into_iter()
+            .map(|i| ColumnIndex {
+                index: i,
+                side: JoinSide::Left,
+            })
+            .chain(right_indices.into_iter().map(|i| ColumnIndex {
+                index: i,
+                side: JoinSide::Right,
+            }))
+            .collect()
+    }
+
+    /// Filter expression
+    pub fn expression(&self) -> &Arc<dyn PhysicalExpr> {
+        &self.expression
+    }
+
+    /// Column indices for intermediate batch creation
+    pub fn column_indices(&self) -> &[ColumnIndex] {
+        &self.column_indices
+    }
+
+    /// Intermediate batch schema
+    pub fn schema(&self) -> &Schema {
+        &self.schema
+    }
+
+    /// Rewrites the join filter if the inputs to the join are rewritten
+    pub fn swap(&self) -> JoinFilter {
+        let column_indices = self
+            .column_indices()
+            .iter()
+            .map(|idx| ColumnIndex {
+                index: idx.index,
+                side: idx.side.negate(),
+            })
+            .collect();
+
+        JoinFilter::new(
+            Arc::clone(self.expression()),
+            column_indices,
+            self.schema().clone(),
+        )
+    }
+}
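
A sketch of how a downstream crate might build and swap a JoinFilter; the
filter expression and intermediate schema here are arbitrary examples, not
taken from the commit:

    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::JoinSide;
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{BinaryExpr, Column};
    use datafusion_physical_plan::joins::utils::JoinFilter;

    fn join_filter_swap_example() {
        // Intermediate batch schema: column 0 from the left input,
        // column 1 from the right input.
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Gt,
            Arc::new(Column::new("b", 1)),
        ));
        let filter = JoinFilter::new(
            expr,
            JoinFilter::build_column_indices(vec![0], vec![0]),
            schema,
        );

        // swap() flips the join side of every ColumnIndex; the expression
        // and the intermediate schema stay the same.
        let swapped = filter.swap();
        assert_eq!(swapped.column_indices()[0].side, JoinSide::Right);
    }
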
diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs
index 6ddf19c511..fa077d2008 100644
--- a/datafusion/physical-plan/src/joins/mod.rs
+++ b/datafusion/physical-plan/src/joins/mod.rs
@@ -31,18 +31,20 @@ mod stream_join_utils;
 mod symmetric_hash_join;
 pub mod utils;
 
+mod join_filter;
 #[cfg(test)]
 pub mod test_utils;
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
-/// Partitioning mode to use for hash join
+/// Hash join partitioning mode
 pub enum PartitionMode {
     /// Left/right children are partitioned using the left and right keys
     Partitioned,
     /// Left side will collected into one partition
     CollectLeft,
-    /// When set to Auto, DataFusion optimizer will decide which PartitionMode mode(Partitioned/CollectLeft) is optimal based on statistics.
-    /// It will also consider swapping the left and right inputs for the Join
+    /// DataFusion optimizer decides which PartitionMode
+    /// (Partitioned/CollectLeft) is optimal based on statistics. It will
+    /// also consider swapping the left and right inputs for the join
     Auto,
 }
 
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index 8caf5d9b5d..c69fa28888 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -24,8 +24,9 @@ use std::sync::Arc;
 use std::task::Poll;
 
 use super::utils::{
-    asymmetric_join_output_partitioning, need_produce_result_in_final, BatchSplitter,
-    BatchTransformer, NoopBatchTransformer, StatefulStreamResult,
+    asymmetric_join_output_partitioning, need_produce_result_in_final,
+    reorder_output_after_swap, BatchSplitter, BatchTransformer, NoopBatchTransformer,
+    StatefulStreamResult,
 };
 use crate::coalesce_partitions::CoalescePartitionsExec;
 use crate::execution_plan::{boundedness_from_children, EmissionType};
@@ -296,6 +297,40 @@ impl NestedLoopJoinExec {
             ),
         ]
     }
+
+    /// Returns a new `ExecutionPlan` that runs the nested loop join with the left
+    /// and right inputs swapped.
+    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
+        let new_filter = self.filter().map(JoinFilter::swap);
+        let new_join_type = &self.join_type().swap();
+
+        let new_join = NestedLoopJoinExec::try_new(
+            Arc::clone(self.right()),
+            Arc::clone(self.left()),
+            new_filter,
+            new_join_type,
+        )?;
+
+        // For Semi/Anti joins, swap result will produce same output schema,
+        // no need to wrap them into additional projection
+        let plan: Arc<dyn ExecutionPlan> = if matches!(
+            self.join_type(),
+            JoinType::LeftSemi
+                | JoinType::RightSemi
+                | JoinType::LeftAnti
+                | JoinType::RightAnti
+        ) {
+            Arc::new(new_join)
+        } else {
+            reorder_output_after_swap(
+                Arc::new(new_join),
+                &self.left().schema(),
+                &self.right().schema(),
+            )?
+        };
+
+        Ok(plan)
+    }
 }
 
 impl DisplayAs for NestedLoopJoinExec {
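
CrossJoinExec::swap_inputs and NestedLoopJoinExec::swap_inputs follow the
same pattern; a minimal sketch of calling them from downstream code,
assuming references to already-constructed operators:

    use std::sync::Arc;
    use datafusion_common::Result;
    use datafusion_physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
    use datafusion_physical_plan::ExecutionPlan;

    fn swap_other_joins(
        nl_join: &NestedLoopJoinExec,
        cross_join: &CrossJoinExec,
    ) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
        // Both methods reorder the output columns internally (via
        // reorder_output_after_swap) when the join type requires it.
        Ok((nl_join.swap_inputs()?, cross_join.swap_inputs()?))
    }
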
diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index d792e14304..371949a325 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -29,6 +29,8 @@ use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
 use crate::{
     ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
 };
+// compatibility
+pub use super::join_filter::JoinFilter;
 
 use arrow::array::{
     downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array,
@@ -54,6 +56,7 @@ use datafusion_physical_expr::{
     LexOrdering, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
 };
 
+use crate::projection::ProjectionExec;
 use futures::future::{BoxFuture, Shared};
 use futures::{ready, FutureExt};
 use hashbrown::raw::RawTable;
@@ -549,66 +552,6 @@ pub struct ColumnIndex {
     pub side: JoinSide,
 }
 
-/// Filter applied before join output. Fields are crate-public to allow
-/// downstream implementations to experiment with custom joins.
-#[derive(Debug, Clone)]
-pub struct JoinFilter {
-    /// Filter expression
-    pub(crate) expression: Arc<dyn PhysicalExpr>,
-    /// Column indices required to construct intermediate batch for filtering
-    pub(crate) column_indices: Vec<ColumnIndex>,
-    /// Physical schema of intermediate batch
-    pub(crate) schema: Schema,
-}
-
-impl JoinFilter {
-    /// Creates new JoinFilter
-    pub fn new(
-        expression: Arc<dyn PhysicalExpr>,
-        column_indices: Vec<ColumnIndex>,
-        schema: Schema,
-    ) -> JoinFilter {
-        JoinFilter {
-            expression,
-            column_indices,
-            schema,
-        }
-    }
-
-    /// Helper for building ColumnIndex vector from left and right indices
-    pub fn build_column_indices(
-        left_indices: Vec<usize>,
-        right_indices: Vec<usize>,
-    ) -> Vec<ColumnIndex> {
-        left_indices
-            .into_iter()
-            .map(|i| ColumnIndex {
-                index: i,
-                side: JoinSide::Left,
-            })
-            .chain(right_indices.into_iter().map(|i| ColumnIndex {
-                index: i,
-                side: JoinSide::Right,
-            }))
-            .collect()
-    }
-
-    /// Filter expression
-    pub fn expression(&self) -> &Arc<dyn PhysicalExpr> {
-        &self.expression
-    }
-
-    /// Column indices for intermediate batch creation
-    pub fn column_indices(&self) -> &[ColumnIndex] {
-        &self.column_indices
-    }
-
-    /// Intermediate batch schema
-    pub fn schema(&self) -> &Schema {
-        &self.schema
-    }
-}
-
 /// Returns the output field given the input field. Outer joins may
 /// insert nulls even if the input was not null
 ///
@@ -1788,6 +1731,50 @@ impl BatchTransformer for BatchSplitter {
     }
 }
 
+/// When the order of the join inputs is changed, the output order of columns
+/// must remain the same.
+///
+/// Joins output columns from their left input followed by their right input.
+/// Thus if the inputs are reordered, the output columns must be reordered to
+/// match the original order.
+pub(crate) fn reorder_output_after_swap(
+    plan: Arc<dyn ExecutionPlan>,
+    left_schema: &Schema,
+    right_schema: &Schema,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let proj = ProjectionExec::try_new(
+        swap_reverting_projection(left_schema, right_schema),
+        plan,
+    )?;
+    Ok(Arc::new(proj))
+}
+
+/// When the order of the join is changed, the output order of columns must
+/// remain the same.
+///
+/// Returns the expressions that will allow to swap back the values from the
+/// original left as the first columns and those on the right next.
+fn swap_reverting_projection(
+    left_schema: &Schema,
+    right_schema: &Schema,
+) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
+    let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| {
+        (
+            Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
+            f.name().to_owned(),
+        )
+    });
+    let right_len = right_cols.len();
+    let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| {
+        (
+            Arc::new(Column::new(f.name(), right_len + i)) as Arc<dyn PhysicalExpr>,
+            f.name().to_owned(),
+        )
+    });
+
+    left_cols.chain(right_cols).collect()
+}
+
 #[cfg(test)]
 mod tests {
     use std::pin::Pin;
@@ -2754,4 +2741,39 @@ mod tests {
         assert!(splitter.next().is_none());
         assert_split_batches(batches, batch_size, num_rows);
     }
+
+    #[tokio::test]
+    async fn test_swap_reverting_projection() {
+        let left_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+
+        let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);
+
+        let proj = swap_reverting_projection(&left_schema, &right_schema);
+
+        assert_eq!(proj.len(), 3);
+
+        let (col, name) = &proj[0];
+        assert_eq!(name, "a");
+        assert_col_expr(col, "a", 1);
+
+        let (col, name) = &proj[1];
+        assert_eq!(name, "b");
+        assert_col_expr(col, "b", 2);
+
+        let (col, name) = &proj[2];
+        assert_eq!(name, "c");
+        assert_col_expr(col, "c", 0);
+    }
+
+    fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
+        let col = expr
+            .as_any()
+            .downcast_ref::<Column>()
+            .expect("Projection items should be Column expression");
+        assert_eq!(col.name(), name);
+        assert_eq!(col.index(), index);
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

