This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 78d33147dc Make benefits_from_input_partitioning Default in SHJ (#8801)
78d33147dc is described below

commit 78d33147dcd81beefd3cdebb91025b1bf2f51343
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Wed Jan 10 22:20:44 2024 +0300

    Make benefits_from_input_partitioning Default in SHJ (#8801)
    
    * SHJ order fixing
    
    * Update join_selection.rs
    
    * Change proto type for sort exprs
    
    ---------
    
    Co-authored-by: Mustafa Akur <[email protected]>
---
 .../core/src/physical_optimizer/join_selection.rs  | 88 ++++++++++++++++++++--
 .../src/physical_optimizer/projection_pushdown.rs  |  4 +
 datafusion/core/tests/sql/joins.rs                 | 68 +++++++++++++++++
 .../physical-plan/src/joins/symmetric_hash_join.rs | 70 +++++++++++------
 datafusion/physical-plan/src/joins/test_utils.rs   |  6 +-
 datafusion/proto/proto/datafusion.proto            |  2 +
 datafusion/proto/src/generated/pbjson.rs           | 36 +++++++++
 datafusion/proto/src/generated/prost.rs            |  4 +
 datafusion/proto/src/physical_plan/from_proto.rs   | 30 ++++++++
 datafusion/proto/src/physical_plan/mod.rs          | 65 +++++++++++++++-
 .../proto/tests/cases/roundtrip_physical_plan.rs   | 38 +++++++---
 11 files changed, 370 insertions(+), 41 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs 
b/datafusion/core/src/physical_optimizer/join_selection.rs
index 6b2fe24acf..ba66dca55b 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -38,11 +38,12 @@ use crate::physical_plan::projection::ProjectionExec;
 use crate::physical_plan::ExecutionPlan;
 
 use arrow_schema::Schema;
-use datafusion_common::internal_err;
 use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{internal_err, JoinSide};
 use datafusion_common::{DataFusionError, JoinType};
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::sort_properties::SortProperties;
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 
 /// The [`JoinSelection`] rule tries to modify a given plan so that it can
 /// accommodate infinite sources and optimize joins in the plan according to
@@ -425,24 +426,97 @@ pub type PipelineFixerSubrule = dyn Fn(
     &ConfigOptions,
 ) -> Option<Result<PipelineStatePropagator>>;
 
-/// This subrule checks if we can replace a hash join with a symmetric hash
-/// join when we are dealing with infinite inputs on both sides. This change
-/// avoids pipeline breaking and preserves query runnability. If possible,
-/// this subrule makes this replacement; otherwise, it has no effect.
+/// Converts a hash join to a symmetric hash join in the case of infinite 
inputs on both sides.
+///
+/// This subrule checks if a hash join can be replaced with a symmetric hash 
join when dealing
+/// with unbounded (infinite) inputs on both sides. This replacement avoids 
pipeline breaking and
+/// preserves query runnability. If the replacement is applicable, this 
subrule makes this change;
+/// otherwise, it leaves the input unchanged.
+///
+/// # Arguments
+/// * `input` - The current state of the pipeline, including the execution 
plan.
+/// * `config_options` - Configuration options that might affect the 
transformation logic.
+///
+/// # Returns
+/// An `Option` that contains the `Result` of the transformation. If the 
transformation is not applicable,
+/// it returns `None`. If applicable, it returns `Some(Ok(...))` with the 
modified pipeline state,
+/// or `Some(Err(...))` if an error occurs during the transformation.
 fn hash_join_convert_symmetric_subrule(
     mut input: PipelineStatePropagator,
     config_options: &ConfigOptions,
 ) -> Option<Result<PipelineStatePropagator>> {
+    // Check if the current plan node is a HashJoinExec.
     if let Some(hash_join) = 
input.plan.as_any().downcast_ref::<HashJoinExec>() {
+        // Determine if left and right children are unbounded.
         let ub_flags = input.children_unbounded();
         let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]);
+        // Update the unbounded flag of the input.
         input.unbounded = left_unbounded || right_unbounded;
+        // Process only if both left and right sides are unbounded.
         let result = if left_unbounded && right_unbounded {
+            // Determine the partition mode based on configuration.
             let mode = if config_options.optimizer.repartition_joins {
                 StreamJoinPartitionMode::Partitioned
             } else {
                 StreamJoinPartitionMode::SinglePartition
             };
+            // A closure to determine the required sort order for each side of 
the join in the SymmetricHashJoinExec.
+            // This function checks if the columns involved in the filter have 
any specific ordering requirements.
+            // If the child nodes (left or right side of the join) already 
have a defined order and the columns used in the
+            // filter predicate are ordered, this function captures that 
ordering requirement. The identified order is then
+            // used in the SymmetricHashJoinExec to maintain bounded memory 
during join operations.
+            // However, if the child nodes do not have an inherent order, or 
if the filter columns are unordered,
+            // the function concludes that no specific order is required for 
the SymmetricHashJoinExec. This approach
+            // ensures that the symmetric hash join operation only imposes 
ordering constraints when necessary,
+            // based on the properties of the child nodes and the filter 
condition.
+            let determine_order = |side: JoinSide| -> 
Option<Vec<PhysicalSortExpr>> {
+                hash_join
+                    .filter()
+                    .map(|filter| {
+                        filter.column_indices().iter().any(
+                            |ColumnIndex {
+                                 index,
+                                 side: column_side,
+                             }| {
+                                // Skip if column side does not match the join 
side.
+                                if *column_side != side {
+                                    return false;
+                                }
+                                // Retrieve equivalence properties and schema 
based on the side.
+                                let (equivalence, schema) = match side {
+                                    JoinSide::Left => (
+                                        
hash_join.left().equivalence_properties(),
+                                        hash_join.left().schema(),
+                                    ),
+                                    JoinSide::Right => (
+                                        
hash_join.right().equivalence_properties(),
+                                        hash_join.right().schema(),
+                                    ),
+                                };
+
+                                let name = schema.field(*index).name();
+                                let col = Arc::new(Column::new(name, *index)) 
as _;
+                                // Check if the column is ordered.
+                                equivalence.get_expr_ordering(col).state
+                                    != SortProperties::Unordered
+                            },
+                        )
+                    })
+                    .unwrap_or(false)
+                    .then(|| {
+                        match side {
+                            JoinSide::Left => 
hash_join.left().output_ordering(),
+                            JoinSide::Right => 
hash_join.right().output_ordering(),
+                        }
+                        .map(|p| p.to_vec())
+                    })
+                    .flatten()
+            };
+
+            // Determine the sort order for both left and right sides.
+            let left_order = determine_order(JoinSide::Left);
+            let right_order = determine_order(JoinSide::Right);
+
             SymmetricHashJoinExec::try_new(
                 hash_join.left().clone(),
                 hash_join.right().clone(),
@@ -450,6 +524,8 @@ fn hash_join_convert_symmetric_subrule(
                 hash_join.filter().cloned(),
                 hash_join.join_type(),
                 hash_join.null_equals_null(),
+                left_order,
+                right_order,
                 mode,
             )
             .map(|exec| {
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs 
b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index d237a3e860..34d1af8556 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -795,6 +795,8 @@ fn try_swapping_with_sym_hash_join(
         new_filter,
         sym_join.join_type(),
         sym_join.null_equals_null(),
+        sym_join.right().output_ordering().map(|p| p.to_vec()),
+        sym_join.left().output_ordering().map(|p| p.to_vec()),
         sym_join.partition_mode(),
     )?)))
 }
@@ -2048,6 +2050,8 @@ mod tests {
             )),
             &JoinType::Inner,
             true,
+            None,
+            None,
             StreamJoinPartitionMode::SinglePartition,
         )?);
         let projection: Arc<dyn ExecutionPlan> = 
Arc::new(ProjectionExec::try_new(
diff --git a/datafusion/core/tests/sql/joins.rs 
b/datafusion/core/tests/sql/joins.rs
index d1f270b540..0cc102002e 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -124,6 +124,74 @@ async fn join_change_in_planner() -> Result<()> {
         [
             "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, 
on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND 
CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10",
             "  CoalesceBatchesExec: target_batch_size=8192",
+            "    RepartitionExec: partitioning=Hash([a2@1], 8), 
input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST",
+            "      RepartitionExec: partitioning=RoundRobinBatch(8), 
input_partitions=1",
+            // "     CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, 
projection=[a1, a2], has_header=false",
+            "  CoalesceBatchesExec: target_batch_size=8192",
+            "    RepartitionExec: partitioning=Hash([a2@1], 8), 
input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST",
+            "      RepartitionExec: partitioning=RoundRobinBatch(8), 
input_partitions=1",
+            // "     CsvExec: file_groups={1 group: [[tempdir/right.csv]]}, 
projection=[a1, a2], has_header=false"
+        ]
+    };
+    let mut actual: Vec<&str> = formatted.trim().lines().collect();
+    // Remove CSV lines
+    actual.remove(4);
+    actual.remove(7);
+
+    assert_eq!(
+        expected,
+        actual[..],
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn join_no_order_on_filter() -> Result<()> {
+    let config = SessionConfig::new().with_target_partitions(8);
+    let ctx = SessionContext::new_with_config(config);
+    let tmp_dir = TempDir::new().unwrap();
+    let left_file_path = tmp_dir.path().join("left.csv");
+    File::create(left_file_path.clone()).unwrap();
+    // Create schema
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a1", DataType::UInt32, false),
+        Field::new("a2", DataType::UInt32, false),
+        Field::new("a3", DataType::UInt32, false),
+    ]));
+    // Specify the ordering:
+    let file_sort_order = vec![[datafusion_expr::col("a1")]
+        .into_iter()
+        .map(|e| {
+            let ascending = true;
+            let nulls_first = false;
+            e.sort(ascending, nulls_first)
+        })
+        .collect::<Vec<_>>()];
+    register_unbounded_file_with_ordering(
+        &ctx,
+        schema.clone(),
+        &left_file_path,
+        "left",
+        file_sort_order.clone(),
+    )?;
+    let right_file_path = tmp_dir.path().join("right.csv");
+    File::create(right_file_path.clone()).unwrap();
+    register_unbounded_file_with_ordering(
+        &ctx,
+        schema,
+        &right_file_path,
+        "right",
+        file_sort_order,
+    )?;
+    let sql = "SELECT * FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 
AND t1.a3 > t2.a3 + 3 AND t1.a3 < t2.a3 + 10";
+    let dataframe = ctx.sql(sql).await?;
+    let physical_plan = dataframe.create_physical_plan().await?;
+    let formatted = 
displayable(physical_plan.as_ref()).indent(true).to_string();
+    let expected = {
+        [
+            "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, 
on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND 
CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10",
+            "  CoalesceBatchesExec: target_batch_size=8192",
             "    RepartitionExec: partitioning=Hash([a2@1], 8), 
input_partitions=8",
             "      RepartitionExec: partitioning=RoundRobinBatch(8), 
input_partitions=1",
             // "     CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, 
projection=[a1, a2], has_header=false",
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 2d38c2bd16..7719c72774 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -68,6 +68,7 @@ use 
datafusion_physical_expr::equivalence::join_equivalence_properties;
 use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
 
 use ahash::RandomState;
+use datafusion_physical_expr::PhysicalSortRequirement;
 use futures::Stream;
 use hashbrown::HashSet;
 use parking_lot::Mutex;
@@ -181,6 +182,10 @@ pub struct SymmetricHashJoinExec {
     column_indices: Vec<ColumnIndex>,
     /// If null_equals_null is true, null == null else null != null
     pub(crate) null_equals_null: bool,
+    /// Left side sort expression(s)
+    pub(crate) left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
+    /// Right side sort expression(s)
+    pub(crate) right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
     /// Partition Mode
     mode: StreamJoinPartitionMode,
 }
@@ -192,6 +197,7 @@ impl SymmetricHashJoinExec {
     /// - It is not possible to join the left and right sides on keys `on`, or
     /// - It fails to construct `SortedFilterExpr`s, or
     /// - It fails to create the [ExprIntervalGraph].
+    #[allow(clippy::too_many_arguments)]
     pub fn try_new(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
@@ -199,6 +205,8 @@ impl SymmetricHashJoinExec {
         filter: Option<JoinFilter>,
         join_type: &JoinType,
         null_equals_null: bool,
+        left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
+        right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
         mode: StreamJoinPartitionMode,
     ) -> Result<Self> {
         let left_schema = left.schema();
@@ -232,6 +240,8 @@ impl SymmetricHashJoinExec {
             metrics: ExecutionPlanMetricsSet::new(),
             column_indices,
             null_equals_null,
+            left_sort_exprs,
+            right_sort_exprs,
             mode,
         })
     }
@@ -271,6 +281,16 @@ impl SymmetricHashJoinExec {
         self.mode
     }
 
+    /// Get left_sort_exprs
+    pub fn left_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
+        self.left_sort_exprs.as_deref()
+    }
+
+    /// Get right_sort_exprs
+    pub fn right_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
+        self.right_sort_exprs.as_deref()
+    }
+
     /// Check if order information covers every column in the filter 
expression.
     pub fn check_if_order_information_available(&self) -> Result<bool> {
         if let Some(filter) = self.filter() {
@@ -337,10 +357,6 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         Ok(children.iter().any(|u| *u))
     }
 
-    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
-        vec![false, false]
-    }
-
     fn required_input_distribution(&self) -> Vec<Distribution> {
         match self.mode {
             StreamJoinPartitionMode::Partitioned => {
@@ -360,6 +376,17 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         }
     }
 
+    fn required_input_ordering(&self) -> 
Vec<Option<Vec<PhysicalSortRequirement>>> {
+        vec![
+            self.left_sort_exprs
+                .as_ref()
+                .map(PhysicalSortRequirement::from_sort_exprs),
+            self.right_sort_exprs
+                .as_ref()
+                .map(PhysicalSortRequirement::from_sort_exprs),
+        ]
+    }
+
     fn output_partitioning(&self) -> Partitioning {
         let left_columns_len = self.left.schema().fields.len();
         partitioned_join_output_partitioning(
@@ -403,6 +430,8 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             self.filter.clone(),
             &self.join_type,
             self.null_equals_null,
+            self.left_sort_exprs.clone(),
+            self.right_sort_exprs.clone(),
             self.mode,
         )?))
     }
@@ -431,24 +460,21 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         }
         // If `filter_state` and `filter` are both present, then calculate 
sorted filter expressions
         // for both sides, and build an expression graph.
-        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match 
(
-            self.left.output_ordering(),
-            self.right.output_ordering(),
-            &self.filter,
-        ) {
-            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
-                let (left, right, graph) = prepare_sorted_exprs(
-                    filter,
-                    &self.left,
-                    &self.right,
-                    left_sort_exprs,
-                    right_sort_exprs,
-                )?;
-                (Some(left), Some(right), Some(graph))
-            }
-            // If `filter_state` or `filter` is not present, then return None 
for all three values:
-            _ => (None, None, None),
-        };
+        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
+            match (&self.left_sort_exprs, &self.right_sort_exprs, 
&self.filter) {
+                (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) 
=> {
+                    let (left, right, graph) = prepare_sorted_exprs(
+                        filter,
+                        &self.left,
+                        &self.right,
+                        left_sort_exprs,
+                        right_sort_exprs,
+                    )?;
+                    (Some(left), Some(right), Some(graph))
+                }
+                // If `filter_state` or `filter` is not present, then return 
None for all three values:
+                _ => (None, None, None),
+            };
 
         let (on_left, on_right) = self.on.iter().cloned().unzip();
 
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs 
b/datafusion/physical-plan/src/joins/test_utils.rs
index fbd52ddf0c..477e2de421 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -90,17 +90,19 @@ pub async fn partitioned_sym_join_with_filter(
 
     let join = SymmetricHashJoinExec::try_new(
         Arc::new(RepartitionExec::try_new(
-            left,
+            left.clone(),
             Partitioning::Hash(left_expr, partition_count),
         )?),
         Arc::new(RepartitionExec::try_new(
-            right,
+            right.clone(),
             Partitioning::Hash(right_expr, partition_count),
         )?),
         on,
         filter,
         join_type,
         null_equals_null,
+        left.output_ordering().map(|p| p.to_vec()),
+        right.output_ordering().map(|p| p.to_vec()),
         StreamJoinPartitionMode::Partitioned,
     )?;
 
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index f4089e83c6..ef08303d74 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1537,6 +1537,8 @@ message SymmetricHashJoinExecNode {
   StreamPartitionMode partition_mode = 6;
   bool null_equals_null = 7;
   JoinFilter filter = 8;
+  repeated PhysicalSortExprNode left_sort_exprs = 9;
+  repeated PhysicalSortExprNode right_sort_exprs = 10;
 }
 
 message InterleaveExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 4c9cbafd8f..6325b98d98 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -25673,6 +25673,12 @@ impl serde::Serialize for SymmetricHashJoinExecNode {
         if self.filter.is_some() {
             len += 1;
         }
+        if !self.left_sort_exprs.is_empty() {
+            len += 1;
+        }
+        if !self.right_sort_exprs.is_empty() {
+            len += 1;
+        }
         let mut struct_ser = 
serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?;
         if let Some(v) = self.left.as_ref() {
             struct_ser.serialize_field("left", v)?;
@@ -25699,6 +25705,12 @@ impl serde::Serialize for SymmetricHashJoinExecNode {
         if let Some(v) = self.filter.as_ref() {
             struct_ser.serialize_field("filter", v)?;
         }
+        if !self.left_sort_exprs.is_empty() {
+            struct_ser.serialize_field("leftSortExprs", 
&self.left_sort_exprs)?;
+        }
+        if !self.right_sort_exprs.is_empty() {
+            struct_ser.serialize_field("rightSortExprs", 
&self.right_sort_exprs)?;
+        }
         struct_ser.end()
     }
 }
@@ -25719,6 +25731,10 @@ impl<'de> serde::Deserialize<'de> for 
SymmetricHashJoinExecNode {
             "null_equals_null",
             "nullEqualsNull",
             "filter",
+            "left_sort_exprs",
+            "leftSortExprs",
+            "right_sort_exprs",
+            "rightSortExprs",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -25730,6 +25746,8 @@ impl<'de> serde::Deserialize<'de> for 
SymmetricHashJoinExecNode {
             PartitionMode,
             NullEqualsNull,
             Filter,
+            LeftSortExprs,
+            RightSortExprs,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -25758,6 +25776,8 @@ impl<'de> serde::Deserialize<'de> for 
SymmetricHashJoinExecNode {
                             "partitionMode" | "partition_mode" => 
Ok(GeneratedField::PartitionMode),
                             "nullEqualsNull" | "null_equals_null" => 
Ok(GeneratedField::NullEqualsNull),
                             "filter" => Ok(GeneratedField::Filter),
+                            "leftSortExprs" | "left_sort_exprs" => 
Ok(GeneratedField::LeftSortExprs),
+                            "rightSortExprs" | "right_sort_exprs" => 
Ok(GeneratedField::RightSortExprs),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -25784,6 +25804,8 @@ impl<'de> serde::Deserialize<'de> for 
SymmetricHashJoinExecNode {
                 let mut partition_mode__ = None;
                 let mut null_equals_null__ = None;
                 let mut filter__ = None;
+                let mut left_sort_exprs__ = None;
+                let mut right_sort_exprs__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::Left => {
@@ -25828,6 +25850,18 @@ impl<'de> serde::Deserialize<'de> for 
SymmetricHashJoinExecNode {
                             }
                             filter__ = map_.next_value()?;
                         }
+                        GeneratedField::LeftSortExprs => {
+                            if left_sort_exprs__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("leftSortExprs"));
+                            }
+                            left_sort_exprs__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::RightSortExprs => {
+                            if right_sort_exprs__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("rightSortExprs"));
+                            }
+                            right_sort_exprs__ = Some(map_.next_value()?);
+                        }
                     }
                 }
                 Ok(SymmetricHashJoinExecNode {
@@ -25838,6 +25872,8 @@ impl<'de> serde::Deserialize<'de> for 
SymmetricHashJoinExecNode {
                     partition_mode: partition_mode__.unwrap_or_default(),
                     null_equals_null: null_equals_null__.unwrap_or_default(),
                     filter: filter__,
+                    left_sort_exprs: left_sort_exprs__.unwrap_or_default(),
+                    right_sort_exprs: right_sort_exprs__.unwrap_or_default(),
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 5db5f3cab7..74f6f34b9a 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2178,6 +2178,10 @@ pub struct SymmetricHashJoinExecNode {
     pub null_equals_null: bool,
     #[prost(message, optional, tag = "8")]
     pub filter: ::core::option::Option<JoinFilter>,
+    #[prost(message, repeated, tag = "9")]
+    pub left_sort_exprs: ::prost::alloc::vec::Vec<PhysicalSortExprNode>,
+    #[prost(message, repeated, tag = "10")]
+    pub right_sort_exprs: ::prost::alloc::vec::Vec<PhysicalSortExprNode>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs 
b/datafusion/proto/src/physical_plan/from_proto.rs
index 193c4dfe03..ea28eeee88 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -93,6 +93,36 @@ pub fn parse_physical_sort_expr(
     }
 }
 
+/// Parses physical sort expressions from a protobuf.
+///
+/// # Arguments
+///
+/// * `proto` - Input proto with a vector of physical sort expression nodes
+/// * `registry` - A registry that knows how to build logical expressions out of 
user-defined functions' names
+/// * `input_schema` - The Arrow schema for the input, used for determining 
expression data types
+///                    when performing type coercion.
+pub fn parse_physical_sort_exprs(
+    proto: &[protobuf::PhysicalSortExprNode],
+    registry: &dyn FunctionRegistry,
+    input_schema: &Schema,
+) -> Result<Vec<PhysicalSortExpr>> {
+    proto
+        .iter()
+        .map(|sort_expr| {
+            if let Some(expr) = &sort_expr.expr {
+                let expr = parse_physical_expr(expr.as_ref(), registry, 
input_schema)?;
+                let options = SortOptions {
+                    descending: !sort_expr.asc,
+                    nulls_first: sort_expr.nulls_first,
+                };
+                Ok(PhysicalSortExpr { expr, options })
+            } else {
+                Err(proto_error("Unexpected empty physical expression"))
+            }
+        })
+        .collect::<Result<Vec<_>>>()
+}
+
 /// Parses a physical window expr from a protobuf.
 ///
 /// # Arguments
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 95becb3fe4..f39f885b78 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -65,7 +65,8 @@ use prost::Message;
 use crate::common::str_to_byte;
 use crate::common::{byte_to_string, proto_error};
 use crate::physical_plan::from_proto::{
-    parse_physical_expr, parse_physical_sort_expr, 
parse_protobuf_file_scan_config,
+    parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
+    parse_protobuf_file_scan_config,
 };
 use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
 use crate::protobuf::physical_expr_node::ExprType;
@@ -646,6 +647,30 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     })
                     .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
 
+                let left_schema = left.schema();
+                let left_sort_exprs = parse_physical_sort_exprs(
+                    &sym_join.left_sort_exprs,
+                    registry,
+                    &left_schema,
+                )?;
+                let left_sort_exprs = if left_sort_exprs.is_empty() {
+                    None
+                } else {
+                    Some(left_sort_exprs)
+                };
+
+                let right_schema = right.schema();
+                let right_sort_exprs = parse_physical_sort_exprs(
+                    &sym_join.right_sort_exprs,
+                    registry,
+                    &right_schema,
+                )?;
+                let right_sort_exprs = if right_sort_exprs.is_empty() {
+                    None
+                } else {
+                    Some(right_sort_exprs)
+                };
+
                 let partition_mode =
                     
protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| {
                         proto_error(format!(
@@ -668,6 +693,8 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     filter,
                     &join_type.into(),
                     sym_join.null_equals_null,
+                    left_sort_exprs,
+                    right_sort_exprs,
                     partition_mode,
                 )
                 .map(|e| Arc::new(e) as _)
@@ -1233,6 +1260,40 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 }
             };
 
+            let left_sort_exprs = exec
+                .left_sort_exprs()
+                .map(|exprs| {
+                    exprs
+                        .iter()
+                        .map(|expr| {
+                            Ok(protobuf::PhysicalSortExprNode {
+                                expr: 
Some(Box::new(expr.expr.to_owned().try_into()?)),
+                                asc: !expr.options.descending,
+                                nulls_first: expr.options.nulls_first,
+                            })
+                        })
+                        .collect::<Result<Vec<_>>>()
+                })
+                .transpose()?
+                .unwrap_or(vec![]);
+
+            let right_sort_exprs = exec
+                .right_sort_exprs()
+                .map(|exprs| {
+                    exprs
+                        .iter()
+                        .map(|expr| {
+                            Ok(protobuf::PhysicalSortExprNode {
+                                expr: 
Some(Box::new(expr.expr.to_owned().try_into()?)),
+                                asc: !expr.options.descending,
+                                nulls_first: expr.options.nulls_first,
+                            })
+                        })
+                        .collect::<Result<Vec<_>>>()
+                })
+                .transpose()?
+                .unwrap_or(vec![]);
+
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: 
Some(PhysicalPlanType::SymmetricHashJoin(Box::new(
                     protobuf::SymmetricHashJoinExecNode {
@@ -1242,6 +1303,8 @@ impl AsExecutionPlan for PhysicalPlanNode {
                         join_type: join_type.into(),
                         partition_mode: partition_mode.into(),
                         null_equals_null: exec.null_equals_null(),
+                        left_sort_exprs,
+                        right_sort_exprs,
                         filter,
                     },
                 ))),
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index dd5fd73c69..9ee8d0d51d 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -903,17 +903,35 @@ fn roundtrip_sym_hash_join() -> Result<()> {
             StreamJoinPartitionMode::Partitioned,
             StreamJoinPartitionMode::SinglePartition,
         ] {
-            roundtrip_test(Arc::new(
-                
datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new(
-                    Arc::new(EmptyExec::new(schema_left.clone())),
-                    Arc::new(EmptyExec::new(schema_right.clone())),
-                    on.clone(),
+            for left_order in &[
+                None,
+                Some(vec![PhysicalSortExpr {
+                    expr: Arc::new(Column::new("col", 
schema_left.index_of("col")?)),
+                    options: Default::default(),
+                }]),
+            ] {
+                for right_order in &[
                     None,
-                    join_type,
-                    false,
-                    *partition_mode,
-                )?,
-            ))?;
+                    Some(vec![PhysicalSortExpr {
+                        expr: Arc::new(Column::new("col", 
schema_right.index_of("col")?)),
+                        options: Default::default(),
+                    }]),
+                ] {
+                    roundtrip_test(Arc::new(
+                        
datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new(
+                            Arc::new(EmptyExec::new(schema_left.clone())),
+                            Arc::new(EmptyExec::new(schema_right.clone())),
+                            on.clone(),
+                            None,
+                            join_type,
+                            false,
+                            left_order.clone(),
+                            right_order.clone(),
+                            *partition_mode,
+                        )?,
+                    ))?;
+                }
+            }
         }
     }
     Ok(())

Reply via email to