This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 78d33147dc Make benefits_from_input_partitioning Default in SHJ (#8801)
78d33147dc is described below
commit 78d33147dcd81beefd3cdebb91025b1bf2f51343
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Wed Jan 10 22:20:44 2024 +0300
Make benefits_from_input_partitioning Default in SHJ (#8801)
* SHJ order fixing
* Update join_selection.rs
* Change proto type for sort exprs
---------
Co-authored-by: Mustafa Akur <[email protected]>
---
.../core/src/physical_optimizer/join_selection.rs | 88 ++++++++++++++++++++--
.../src/physical_optimizer/projection_pushdown.rs | 4 +
datafusion/core/tests/sql/joins.rs | 68 +++++++++++++++++
.../physical-plan/src/joins/symmetric_hash_join.rs | 70 +++++++++++------
datafusion/physical-plan/src/joins/test_utils.rs | 6 +-
datafusion/proto/proto/datafusion.proto | 2 +
datafusion/proto/src/generated/pbjson.rs | 36 +++++++++
datafusion/proto/src/generated/prost.rs | 4 +
datafusion/proto/src/physical_plan/from_proto.rs | 30 ++++++++
datafusion/proto/src/physical_plan/mod.rs | 65 +++++++++++++++-
.../proto/tests/cases/roundtrip_physical_plan.rs | 38 +++++++---
11 files changed, 370 insertions(+), 41 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs
b/datafusion/core/src/physical_optimizer/join_selection.rs
index 6b2fe24acf..ba66dca55b 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -38,11 +38,12 @@ use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::ExecutionPlan;
use arrow_schema::Schema;
-use datafusion_common::internal_err;
use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{internal_err, JoinSide};
use datafusion_common::{DataFusionError, JoinType};
use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::sort_properties::SortProperties;
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
/// The [`JoinSelection`] rule tries to modify a given plan so that it can
/// accommodate infinite sources and optimize joins in the plan according to
@@ -425,24 +426,97 @@ pub type PipelineFixerSubrule = dyn Fn(
&ConfigOptions,
) -> Option<Result<PipelineStatePropagator>>;
-/// This subrule checks if we can replace a hash join with a symmetric hash
-/// join when we are dealing with infinite inputs on both sides. This change
-/// avoids pipeline breaking and preserves query runnability. If possible,
-/// this subrule makes this replacement; otherwise, it has no effect.
+/// Converts a hash join to a symmetric hash join in the case of infinite inputs on both sides.
+///
+/// This subrule checks if a hash join can be replaced with a symmetric hash join when dealing
+/// with unbounded (infinite) inputs on both sides. This replacement avoids pipeline breaking and
+/// preserves query runnability. If the replacement is applicable, this subrule makes this change;
+/// otherwise, it leaves the input unchanged.
+///
+/// # Arguments
+/// * `input` - The current state of the pipeline, including the execution plan.
+/// * `config_options` - Configuration options that might affect the transformation logic.
+///
+/// # Returns
+/// An `Option` that contains the `Result` of the transformation. If the transformation is not applicable,
+/// it returns `None`. If applicable, it returns `Some(Ok(...))` with the modified pipeline state,
+/// or `Some(Err(...))` if an error occurs during the transformation.
fn hash_join_convert_symmetric_subrule(
mut input: PipelineStatePropagator,
config_options: &ConfigOptions,
) -> Option<Result<PipelineStatePropagator>> {
+ // Check if the current plan node is a HashJoinExec.
if let Some(hash_join) =
input.plan.as_any().downcast_ref::<HashJoinExec>() {
+ // Determine if left and right children are unbounded.
let ub_flags = input.children_unbounded();
let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]);
+ // Update the unbounded flag of the input.
input.unbounded = left_unbounded || right_unbounded;
+ // Process only if both left and right sides are unbounded.
let result = if left_unbounded && right_unbounded {
+ // Determine the partition mode based on configuration.
let mode = if config_options.optimizer.repartition_joins {
StreamJoinPartitionMode::Partitioned
} else {
StreamJoinPartitionMode::SinglePartition
};
+ // A closure to determine the required sort order for each side of the join in the SymmetricHashJoinExec.
+ // This function checks if the columns involved in the filter have any specific ordering requirements.
+ // If the child nodes (left or right side of the join) already have a defined order and the columns used in the
+ // filter predicate are ordered, this function captures that ordering requirement. The identified order is then
+ // used in the SymmetricHashJoinExec to maintain bounded memory during join operations.
+ // However, if the child nodes do not have an inherent order, or if the filter columns are unordered,
+ // the function concludes that no specific order is required for the SymmetricHashJoinExec. This approach
+ // ensures that the symmetric hash join operation only imposes ordering constraints when necessary,
+ // based on the properties of the child nodes and the filter condition.
+ let determine_order = |side: JoinSide| -> Option<Vec<PhysicalSortExpr>> {
+ hash_join
+ .filter()
+ .map(|filter| {
+ filter.column_indices().iter().any(
+ |ColumnIndex {
+ index,
+ side: column_side,
+ }| {
+ // Skip if column side does not match the join side.
+ if *column_side != side {
+ return false;
+ }
+ // Retrieve equivalence properties and schema based on the side.
+ let (equivalence, schema) = match side {
+ JoinSide::Left => (
+
hash_join.left().equivalence_properties(),
+ hash_join.left().schema(),
+ ),
+ JoinSide::Right => (
+
hash_join.right().equivalence_properties(),
+ hash_join.right().schema(),
+ ),
+ };
+
+ let name = schema.field(*index).name();
+ let col = Arc::new(Column::new(name, *index)) as _;
+ // Check if the column is ordered.
+ equivalence.get_expr_ordering(col).state
+ != SortProperties::Unordered
+ },
+ )
+ })
+ .unwrap_or(false)
+ .then(|| {
+ match side {
+ JoinSide::Left =>
hash_join.left().output_ordering(),
+ JoinSide::Right =>
hash_join.right().output_ordering(),
+ }
+ .map(|p| p.to_vec())
+ })
+ .flatten()
+ };
+
+ // Determine the sort order for both left and right sides.
+ let left_order = determine_order(JoinSide::Left);
+ let right_order = determine_order(JoinSide::Right);
+
SymmetricHashJoinExec::try_new(
hash_join.left().clone(),
hash_join.right().clone(),
@@ -450,6 +524,8 @@ fn hash_join_convert_symmetric_subrule(
hash_join.filter().cloned(),
hash_join.join_type(),
hash_join.null_equals_null(),
+ left_order,
+ right_order,
mode,
)
.map(|exec| {
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index d237a3e860..34d1af8556 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -795,6 +795,8 @@ fn try_swapping_with_sym_hash_join(
new_filter,
sym_join.join_type(),
sym_join.null_equals_null(),
+ sym_join.right().output_ordering().map(|p| p.to_vec()),
+ sym_join.left().output_ordering().map(|p| p.to_vec()),
sym_join.partition_mode(),
)?)))
}
@@ -2048,6 +2050,8 @@ mod tests {
)),
&JoinType::Inner,
true,
+ None,
+ None,
StreamJoinPartitionMode::SinglePartition,
)?);
let projection: Arc<dyn ExecutionPlan> =
Arc::new(ProjectionExec::try_new(
diff --git a/datafusion/core/tests/sql/joins.rs
b/datafusion/core/tests/sql/joins.rs
index d1f270b540..0cc102002e 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -124,6 +124,74 @@ async fn join_change_in_planner() -> Result<()> {
[
"SymmetricHashJoinExec: mode=Partitioned, join_type=Full,
on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND
CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10",
" CoalesceBatchesExec: target_batch_size=8192",
+ " RepartitionExec: partitioning=Hash([a2@1], 8),
input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST",
+ " RepartitionExec: partitioning=RoundRobinBatch(8),
input_partitions=1",
+ // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]},
projection=[a1, a2], has_header=false",
+ " CoalesceBatchesExec: target_batch_size=8192",
+ " RepartitionExec: partitioning=Hash([a2@1], 8),
input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST",
+ " RepartitionExec: partitioning=RoundRobinBatch(8),
input_partitions=1",
+ // " CsvExec: file_groups={1 group: [[tempdir/right.csv]]},
projection=[a1, a2], has_header=false"
+ ]
+ };
+ let mut actual: Vec<&str> = formatted.trim().lines().collect();
+ // Remove CSV lines
+ actual.remove(4);
+ actual.remove(7);
+
+ assert_eq!(
+ expected,
+ actual[..],
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ Ok(())
+}
+
+#[tokio::test]
+async fn join_no_order_on_filter() -> Result<()> {
+ let config = SessionConfig::new().with_target_partitions(8);
+ let ctx = SessionContext::new_with_config(config);
+ let tmp_dir = TempDir::new().unwrap();
+ let left_file_path = tmp_dir.path().join("left.csv");
+ File::create(left_file_path.clone()).unwrap();
+ // Create schema
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a1", DataType::UInt32, false),
+ Field::new("a2", DataType::UInt32, false),
+ Field::new("a3", DataType::UInt32, false),
+ ]));
+ // Specify the ordering:
+ let file_sort_order = vec![[datafusion_expr::col("a1")]
+ .into_iter()
+ .map(|e| {
+ let ascending = true;
+ let nulls_first = false;
+ e.sort(ascending, nulls_first)
+ })
+ .collect::<Vec<_>>()];
+ register_unbounded_file_with_ordering(
+ &ctx,
+ schema.clone(),
+ &left_file_path,
+ "left",
+ file_sort_order.clone(),
+ )?;
+ let right_file_path = tmp_dir.path().join("right.csv");
+ File::create(right_file_path.clone()).unwrap();
+ register_unbounded_file_with_ordering(
+ &ctx,
+ schema,
+ &right_file_path,
+ "right",
+ file_sort_order,
+ )?;
+ let sql = "SELECT * FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2
AND t1.a3 > t2.a3 + 3 AND t1.a3 < t2.a3 + 10";
+ let dataframe = ctx.sql(sql).await?;
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted =
displayable(physical_plan.as_ref()).indent(true).to_string();
+ let expected = {
+ [
+ "SymmetricHashJoinExec: mode=Partitioned, join_type=Full,
on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND
CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10",
+ " CoalesceBatchesExec: target_batch_size=8192",
" RepartitionExec: partitioning=Hash([a2@1], 8),
input_partitions=8",
" RepartitionExec: partitioning=RoundRobinBatch(8),
input_partitions=1",
// " CsvExec: file_groups={1 group: [[tempdir/left.csv]]},
projection=[a1, a2], has_header=false",
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 2d38c2bd16..7719c72774 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -68,6 +68,7 @@ use
datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use ahash::RandomState;
+use datafusion_physical_expr::PhysicalSortRequirement;
use futures::Stream;
use hashbrown::HashSet;
use parking_lot::Mutex;
@@ -181,6 +182,10 @@ pub struct SymmetricHashJoinExec {
column_indices: Vec<ColumnIndex>,
/// If null_equals_null is true, null == null else null != null
pub(crate) null_equals_null: bool,
+ /// Left side sort expression(s)
+ pub(crate) left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
+ /// Right side sort expression(s)
+ pub(crate) right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
/// Partition Mode
mode: StreamJoinPartitionMode,
}
@@ -192,6 +197,7 @@ impl SymmetricHashJoinExec {
/// - It is not possible to join the left and right sides on keys `on`, or
/// - It fails to construct `SortedFilterExpr`s, or
/// - It fails to create the [ExprIntervalGraph].
+ #[allow(clippy::too_many_arguments)]
pub fn try_new(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
@@ -199,6 +205,8 @@ impl SymmetricHashJoinExec {
filter: Option<JoinFilter>,
join_type: &JoinType,
null_equals_null: bool,
+ left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
+ right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
mode: StreamJoinPartitionMode,
) -> Result<Self> {
let left_schema = left.schema();
@@ -232,6 +240,8 @@ impl SymmetricHashJoinExec {
metrics: ExecutionPlanMetricsSet::new(),
column_indices,
null_equals_null,
+ left_sort_exprs,
+ right_sort_exprs,
mode,
})
}
@@ -271,6 +281,16 @@ impl SymmetricHashJoinExec {
self.mode
}
+ /// Get left_sort_exprs
+ pub fn left_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
+ self.left_sort_exprs.as_deref()
+ }
+
+ /// Get right_sort_exprs
+ pub fn right_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
+ self.right_sort_exprs.as_deref()
+ }
+
/// Check if order information covers every column in the filter
expression.
pub fn check_if_order_information_available(&self) -> Result<bool> {
if let Some(filter) = self.filter() {
@@ -337,10 +357,6 @@ impl ExecutionPlan for SymmetricHashJoinExec {
Ok(children.iter().any(|u| *u))
}
- fn benefits_from_input_partitioning(&self) -> Vec<bool> {
- vec![false, false]
- }
-
fn required_input_distribution(&self) -> Vec<Distribution> {
match self.mode {
StreamJoinPartitionMode::Partitioned => {
@@ -360,6 +376,17 @@ impl ExecutionPlan for SymmetricHashJoinExec {
}
}
+ fn required_input_ordering(&self) ->
Vec<Option<Vec<PhysicalSortRequirement>>> {
+ vec![
+ self.left_sort_exprs
+ .as_ref()
+ .map(PhysicalSortRequirement::from_sort_exprs),
+ self.right_sort_exprs
+ .as_ref()
+ .map(PhysicalSortRequirement::from_sort_exprs),
+ ]
+ }
+
fn output_partitioning(&self) -> Partitioning {
let left_columns_len = self.left.schema().fields.len();
partitioned_join_output_partitioning(
@@ -403,6 +430,8 @@ impl ExecutionPlan for SymmetricHashJoinExec {
self.filter.clone(),
&self.join_type,
self.null_equals_null,
+ self.left_sort_exprs.clone(),
+ self.right_sort_exprs.clone(),
self.mode,
)?))
}
@@ -431,24 +460,21 @@ impl ExecutionPlan for SymmetricHashJoinExec {
}
// If `filter_state` and `filter` are both present, then calculate
sorted filter expressions
// for both sides, and build an expression graph.
- let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match
(
- self.left.output_ordering(),
- self.right.output_ordering(),
- &self.filter,
- ) {
- (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
- let (left, right, graph) = prepare_sorted_exprs(
- filter,
- &self.left,
- &self.right,
- left_sort_exprs,
- right_sort_exprs,
- )?;
- (Some(left), Some(right), Some(graph))
- }
- // If `filter_state` or `filter` is not present, then return None
for all three values:
- _ => (None, None, None),
- };
+ let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
+ match (&self.left_sort_exprs, &self.right_sort_exprs,
&self.filter) {
+ (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter))
=> {
+ let (left, right, graph) = prepare_sorted_exprs(
+ filter,
+ &self.left,
+ &self.right,
+ left_sort_exprs,
+ right_sort_exprs,
+ )?;
+ (Some(left), Some(right), Some(graph))
+ }
+ // If `filter_state` or `filter` is not present, then return
None for all three values:
+ _ => (None, None, None),
+ };
let (on_left, on_right) = self.on.iter().cloned().unzip();
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs
b/datafusion/physical-plan/src/joins/test_utils.rs
index fbd52ddf0c..477e2de421 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -90,17 +90,19 @@ pub async fn partitioned_sym_join_with_filter(
let join = SymmetricHashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
- left,
+ left.clone(),
Partitioning::Hash(left_expr, partition_count),
)?),
Arc::new(RepartitionExec::try_new(
- right,
+ right.clone(),
Partitioning::Hash(right_expr, partition_count),
)?),
on,
filter,
join_type,
null_equals_null,
+ left.output_ordering().map(|p| p.to_vec()),
+ right.output_ordering().map(|p| p.to_vec()),
StreamJoinPartitionMode::Partitioned,
)?;
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index f4089e83c6..ef08303d74 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1537,6 +1537,8 @@ message SymmetricHashJoinExecNode {
StreamPartitionMode partition_mode = 6;
bool null_equals_null = 7;
JoinFilter filter = 8;
+ repeated PhysicalSortExprNode left_sort_exprs = 9;
+ repeated PhysicalSortExprNode right_sort_exprs = 10;
}
message InterleaveExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 4c9cbafd8f..6325b98d98 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -25673,6 +25673,12 @@ impl serde::Serialize for SymmetricHashJoinExecNode {
if self.filter.is_some() {
len += 1;
}
+ if !self.left_sort_exprs.is_empty() {
+ len += 1;
+ }
+ if !self.right_sort_exprs.is_empty() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?;
if let Some(v) = self.left.as_ref() {
struct_ser.serialize_field("left", v)?;
@@ -25699,6 +25705,12 @@ impl serde::Serialize for SymmetricHashJoinExecNode {
if let Some(v) = self.filter.as_ref() {
struct_ser.serialize_field("filter", v)?;
}
+ if !self.left_sort_exprs.is_empty() {
+ struct_ser.serialize_field("leftSortExprs",
&self.left_sort_exprs)?;
+ }
+ if !self.right_sort_exprs.is_empty() {
+ struct_ser.serialize_field("rightSortExprs",
&self.right_sort_exprs)?;
+ }
struct_ser.end()
}
}
@@ -25719,6 +25731,10 @@ impl<'de> serde::Deserialize<'de> for
SymmetricHashJoinExecNode {
"null_equals_null",
"nullEqualsNull",
"filter",
+ "left_sort_exprs",
+ "leftSortExprs",
+ "right_sort_exprs",
+ "rightSortExprs",
];
#[allow(clippy::enum_variant_names)]
@@ -25730,6 +25746,8 @@ impl<'de> serde::Deserialize<'de> for
SymmetricHashJoinExecNode {
PartitionMode,
NullEqualsNull,
Filter,
+ LeftSortExprs,
+ RightSortExprs,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -25758,6 +25776,8 @@ impl<'de> serde::Deserialize<'de> for
SymmetricHashJoinExecNode {
"partitionMode" | "partition_mode" =>
Ok(GeneratedField::PartitionMode),
"nullEqualsNull" | "null_equals_null" =>
Ok(GeneratedField::NullEqualsNull),
"filter" => Ok(GeneratedField::Filter),
+ "leftSortExprs" | "left_sort_exprs" =>
Ok(GeneratedField::LeftSortExprs),
+ "rightSortExprs" | "right_sort_exprs" =>
Ok(GeneratedField::RightSortExprs),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -25784,6 +25804,8 @@ impl<'de> serde::Deserialize<'de> for
SymmetricHashJoinExecNode {
let mut partition_mode__ = None;
let mut null_equals_null__ = None;
let mut filter__ = None;
+ let mut left_sort_exprs__ = None;
+ let mut right_sort_exprs__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Left => {
@@ -25828,6 +25850,18 @@ impl<'de> serde::Deserialize<'de> for
SymmetricHashJoinExecNode {
}
filter__ = map_.next_value()?;
}
+ GeneratedField::LeftSortExprs => {
+ if left_sort_exprs__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("leftSortExprs"));
+ }
+ left_sort_exprs__ = Some(map_.next_value()?);
+ }
+ GeneratedField::RightSortExprs => {
+ if right_sort_exprs__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("rightSortExprs"));
+ }
+ right_sort_exprs__ = Some(map_.next_value()?);
+ }
}
}
Ok(SymmetricHashJoinExecNode {
@@ -25838,6 +25872,8 @@ impl<'de> serde::Deserialize<'de> for
SymmetricHashJoinExecNode {
partition_mode: partition_mode__.unwrap_or_default(),
null_equals_null: null_equals_null__.unwrap_or_default(),
filter: filter__,
+ left_sort_exprs: left_sort_exprs__.unwrap_or_default(),
+ right_sort_exprs: right_sort_exprs__.unwrap_or_default(),
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 5db5f3cab7..74f6f34b9a 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2178,6 +2178,10 @@ pub struct SymmetricHashJoinExecNode {
pub null_equals_null: bool,
#[prost(message, optional, tag = "8")]
pub filter: ::core::option::Option<JoinFilter>,
+ #[prost(message, repeated, tag = "9")]
+ pub left_sort_exprs: ::prost::alloc::vec::Vec<PhysicalSortExprNode>,
+ #[prost(message, repeated, tag = "10")]
+ pub right_sort_exprs: ::prost::alloc::vec::Vec<PhysicalSortExprNode>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index 193c4dfe03..ea28eeee88 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -93,6 +93,36 @@ pub fn parse_physical_sort_expr(
}
}
+/// Parses physical sort expressions from a protobuf.
+///
+/// # Arguments
+///
+/// * `proto` - Input proto with a vector of physical sort expression nodes
+/// * `registry` - A registry that knows how to build logical expressions out of user-defined functions' names
+/// * `input_schema` - The Arrow schema for the input, used for determining expression data types
+/// when performing type coercion.
+pub fn parse_physical_sort_exprs(
+ proto: &[protobuf::PhysicalSortExprNode],
+ registry: &dyn FunctionRegistry,
+ input_schema: &Schema,
+) -> Result<Vec<PhysicalSortExpr>> {
+ proto
+ .iter()
+ .map(|sort_expr| {
+ if let Some(expr) = &sort_expr.expr {
+ let expr = parse_physical_expr(expr.as_ref(), registry,
input_schema)?;
+ let options = SortOptions {
+ descending: !sort_expr.asc,
+ nulls_first: sort_expr.nulls_first,
+ };
+ Ok(PhysicalSortExpr { expr, options })
+ } else {
+ Err(proto_error("Unexpected empty physical expression"))
+ }
+ })
+ .collect::<Result<Vec<_>>>()
+}
+
/// Parses a physical window expr from a protobuf.
///
/// # Arguments
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 95becb3fe4..f39f885b78 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -65,7 +65,8 @@ use prost::Message;
use crate::common::str_to_byte;
use crate::common::{byte_to_string, proto_error};
use crate::physical_plan::from_proto::{
- parse_physical_expr, parse_physical_sort_expr,
parse_protobuf_file_scan_config,
+ parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
+ parse_protobuf_file_scan_config,
};
use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
use crate::protobuf::physical_expr_node::ExprType;
@@ -646,6 +647,30 @@ impl AsExecutionPlan for PhysicalPlanNode {
})
.map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
+ let left_schema = left.schema();
+ let left_sort_exprs = parse_physical_sort_exprs(
+ &sym_join.left_sort_exprs,
+ registry,
+ &left_schema,
+ )?;
+ let left_sort_exprs = if left_sort_exprs.is_empty() {
+ None
+ } else {
+ Some(left_sort_exprs)
+ };
+
+ let right_schema = right.schema();
+ let right_sort_exprs = parse_physical_sort_exprs(
+ &sym_join.right_sort_exprs,
+ registry,
+ &right_schema,
+ )?;
+ let right_sort_exprs = if right_sort_exprs.is_empty() {
+ None
+ } else {
+ Some(right_sort_exprs)
+ };
+
let partition_mode =
protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| {
proto_error(format!(
@@ -668,6 +693,8 @@ impl AsExecutionPlan for PhysicalPlanNode {
filter,
&join_type.into(),
sym_join.null_equals_null,
+ left_sort_exprs,
+ right_sort_exprs,
partition_mode,
)
.map(|e| Arc::new(e) as _)
@@ -1233,6 +1260,40 @@ impl AsExecutionPlan for PhysicalPlanNode {
}
};
+ let left_sort_exprs = exec
+ .left_sort_exprs()
+ .map(|exprs| {
+ exprs
+ .iter()
+ .map(|expr| {
+ Ok(protobuf::PhysicalSortExprNode {
+ expr:
Some(Box::new(expr.expr.to_owned().try_into()?)),
+ asc: !expr.options.descending,
+ nulls_first: expr.options.nulls_first,
+ })
+ })
+ .collect::<Result<Vec<_>>>()
+ })
+ .transpose()?
+ .unwrap_or(vec![]);
+
+ let right_sort_exprs = exec
+ .right_sort_exprs()
+ .map(|exprs| {
+ exprs
+ .iter()
+ .map(|expr| {
+ Ok(protobuf::PhysicalSortExprNode {
+ expr:
Some(Box::new(expr.expr.to_owned().try_into()?)),
+ asc: !expr.options.descending,
+ nulls_first: expr.options.nulls_first,
+ })
+ })
+ .collect::<Result<Vec<_>>>()
+ })
+ .transpose()?
+ .unwrap_or(vec![]);
+
return Ok(protobuf::PhysicalPlanNode {
physical_plan_type:
Some(PhysicalPlanType::SymmetricHashJoin(Box::new(
protobuf::SymmetricHashJoinExecNode {
@@ -1242,6 +1303,8 @@ impl AsExecutionPlan for PhysicalPlanNode {
join_type: join_type.into(),
partition_mode: partition_mode.into(),
null_equals_null: exec.null_equals_null(),
+ left_sort_exprs,
+ right_sort_exprs,
filter,
},
))),
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index dd5fd73c69..9ee8d0d51d 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -903,17 +903,35 @@ fn roundtrip_sym_hash_join() -> Result<()> {
StreamJoinPartitionMode::Partitioned,
StreamJoinPartitionMode::SinglePartition,
] {
- roundtrip_test(Arc::new(
-
datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new(
- Arc::new(EmptyExec::new(schema_left.clone())),
- Arc::new(EmptyExec::new(schema_right.clone())),
- on.clone(),
+ for left_order in &[
+ None,
+ Some(vec![PhysicalSortExpr {
+ expr: Arc::new(Column::new("col",
schema_left.index_of("col")?)),
+ options: Default::default(),
+ }]),
+ ] {
+ for right_order in &[
None,
- join_type,
- false,
- *partition_mode,
- )?,
- ))?;
+ Some(vec![PhysicalSortExpr {
+ expr: Arc::new(Column::new("col",
schema_right.index_of("col")?)),
+ options: Default::default(),
+ }]),
+ ] {
+ roundtrip_test(Arc::new(
+
datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new(
+ Arc::new(EmptyExec::new(schema_left.clone())),
+ Arc::new(EmptyExec::new(schema_right.clone())),
+ on.clone(),
+ None,
+ join_type,
+ false,
+ left_order.clone(),
+ right_order.clone(),
+ *partition_mode,
+ )?,
+ ))?;
+ }
+ }
}
}
Ok(())