This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ee9078d8cf [MINOR] Reduce complexity on SHJ (#7607)
ee9078d8cf is described below

commit ee9078d8cfd99afa2ee2467af7837687770aac43
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Fri Sep 22 15:23:24 2023 +0300

    [MINOR] Reduce complexity on SHJ (#7607)
    
    * Before fix
    
    * Clippy
    
    * Minor changes
    
    * Simplifications
    
    ---------
    
    Co-authored-by: Mustafa Akur <[email protected]>
---
 .../physical-plan/src/joins/hash_join_utils.rs     | 105 +--------------------
 .../physical-plan/src/joins/symmetric_hash_join.rs |  46 +++++----
 datafusion/physical-plan/src/joins/utils.rs        |  87 +++++++++++++++++
 3 files changed, 110 insertions(+), 128 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs 
b/datafusion/physical-plan/src/joins/hash_join_utils.rs
index bb79763458..525c1a7145 100644
--- a/datafusion/physical-plan/src/joins/hash_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/hash_join_utils.rs
@@ -19,13 +19,12 @@
 //! related functionality, used both in join calculations and optimization 
rules.
 
 use std::collections::{HashMap, VecDeque};
-use std::fmt::{Debug, Formatter};
+use std::fmt::Debug;
 use std::ops::IndexMut;
 use std::sync::Arc;
 use std::{fmt, usize};
 
 use crate::joins::utils::{JoinFilter, JoinSide};
-use crate::ExecutionPlan;
 
 use arrow::compute::concat_batches;
 use arrow::datatypes::{ArrowNativeType, SchemaRef};
@@ -34,13 +33,12 @@ use arrow_array::{ArrowPrimitiveType, NativeAdapter, 
PrimitiveArray, RecordBatch
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, 
IntervalBound};
+use datafusion_physical_expr::intervals::{Interval, IntervalBound};
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 
 use hashbrown::raw::RawTable;
 use hashbrown::HashSet;
-use parking_lot::Mutex;
 
 // Maps a `u64` hash value based on the build side ["on" values] to a list of 
indices with this key's value.
 // By allocating a `HashMap` with capacity for *at least* the number of rows 
for entries at the build side,
@@ -446,105 +444,6 @@ fn convert_filter_columns(
     })
 }
 
-#[derive(Default)]
-pub struct IntervalCalculatorInnerState {
-    /// Expression graph for interval calculations
-    graph: Option<ExprIntervalGraph>,
-    sorted_exprs: Vec<Option<SortedFilterExpr>>,
-    calculated: bool,
-}
-
-impl Debug for IntervalCalculatorInnerState {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        write!(f, "Exprs({:?})", self.sorted_exprs)
-    }
-}
-
-pub fn build_filter_expression_graph(
-    interval_state: &Arc<Mutex<IntervalCalculatorInnerState>>,
-    left: &Arc<dyn ExecutionPlan>,
-    right: &Arc<dyn ExecutionPlan>,
-    filter: &JoinFilter,
-) -> Result<(
-    Option<SortedFilterExpr>,
-    Option<SortedFilterExpr>,
-    Option<ExprIntervalGraph>,
-)> {
-    // Lock the mutex of the interval state:
-    let mut filter_state = interval_state.lock();
-    // If this is the first partition to be invoked, then we need to 
initialize our state
-    // (the expression graph for pruning, sorted filter expressions etc.)
-    if !filter_state.calculated {
-        // Interval calculations require each column to exhibit monotonicity
-        // independently. However, a `PhysicalSortExpr` object defines a
-        // lexicographical ordering, so we can only use their first elements.
-        // when deducing column monotonicities.
-        // TODO: Extend the `PhysicalSortExpr` mechanism to express independent
-        //       (i.e. simultaneous) ordering properties of columns.
-
-        // Build sorted filter expressions for the left and right join side:
-        let join_sides = [JoinSide::Left, JoinSide::Right];
-        let children = [left, right];
-        for (join_side, child) in join_sides.iter().zip(children.iter()) {
-            let sorted_expr = child
-                .output_ordering()
-                .and_then(|orders| {
-                    build_filter_input_order(
-                        *join_side,
-                        filter,
-                        &child.schema(),
-                        &orders[0],
-                    )
-                    .transpose()
-                })
-                .transpose()?;
-
-            filter_state.sorted_exprs.push(sorted_expr);
-        }
-
-        // Collect available sorted filter expressions:
-        let sorted_exprs_size = filter_state.sorted_exprs.len();
-        let mut sorted_exprs = filter_state
-            .sorted_exprs
-            .iter_mut()
-            .flatten()
-            .collect::<Vec<_>>();
-
-        // Create the expression graph if we can create sorted filter 
expressions for both children:
-        filter_state.graph = if sorted_exprs.len() == sorted_exprs_size {
-            let mut graph = 
ExprIntervalGraph::try_new(filter.expression().clone())?;
-
-            // Gather filter expressions:
-            let filter_exprs = sorted_exprs
-                .iter()
-                .map(|sorted_expr| sorted_expr.filter_expr().clone())
-                .collect::<Vec<_>>();
-
-            // Gather node indices of converted filter expressions in 
`SortedFilterExpr`s
-            // using the filter columns vector:
-            let child_node_indices = graph.gather_node_indices(&filter_exprs);
-
-            // Update SortedFilterExpr instances with the corresponding node 
indices:
-            for (sorted_expr, (_, index)) in
-                sorted_exprs.iter_mut().zip(child_node_indices.iter())
-            {
-                sorted_expr.set_node_index(*index);
-            }
-
-            Some(graph)
-        } else {
-            None
-        };
-        filter_state.calculated = true;
-    }
-    // Return the sorted filter expressions for both sides along with the 
expression graph:
-    Ok((
-        filter_state.sorted_exprs[0].clone(),
-        filter_state.sorted_exprs[1].clone(),
-        filter_state.graph.as_ref().cloned(),
-    ))
-}
-
 /// The [SortedFilterExpr] object represents a sorted filter expression. It
 /// contains the following information: The origin expression, the filter
 /// expression, an interval encapsulating expression bounds, and a stable
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index a2fd127112..e6eb5dd695 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -35,10 +35,9 @@ use std::{any::Any, usize};
 use crate::common::SharedMemoryReservation;
 use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
 use crate::joins::hash_join_utils::{
-    build_filter_expression_graph, calculate_filter_expr_intervals, 
combine_two_batches,
+    calculate_filter_expr_intervals, combine_two_batches,
     convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
-    get_pruning_semi_indices, record_visited_indices, 
IntervalCalculatorInnerState,
-    PruningJoinHashMap,
+    get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap,
 };
 use crate::joins::StreamJoinPartitionMode;
 use crate::DisplayAs;
@@ -69,6 +68,7 @@ use datafusion_execution::memory_pool::MemoryConsumer;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::intervals::ExprIntervalGraph;
 
+use crate::joins::utils::prepare_sorted_exprs;
 use ahash::RandomState;
 use futures::stream::{select, BoxStream};
 use futures::{Stream, StreamExt};
@@ -174,8 +174,6 @@ pub struct SymmetricHashJoinExec {
     pub(crate) filter: Option<JoinFilter>,
     /// How the join is performed
     pub(crate) join_type: JoinType,
-    /// Expression graph and `SortedFilterExpr`s for interval calculations
-    filter_state: Option<Arc<Mutex<IntervalCalculatorInnerState>>>,
     /// The schema once the join is applied
     schema: SchemaRef,
     /// Shares the `RandomState` for the hashing algorithm
@@ -285,20 +283,12 @@ impl SymmetricHashJoinExec {
         // Initialize the random state for the join operation:
         let random_state = RandomState::with_seeds(0, 0, 0, 0);
 
-        let filter_state = if filter.is_some() {
-            let inner_state = IntervalCalculatorInnerState::default();
-            Some(Arc::new(Mutex::new(inner_state)))
-        } else {
-            None
-        };
-
         Ok(SymmetricHashJoinExec {
             left,
             right,
             on,
             filter,
             join_type: *join_type,
-            filter_state,
             schema: Arc::new(schema),
             random_state,
             metrics: ExecutionPlanMetricsSet::new(),
@@ -496,21 +486,27 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             );
         }
         // If `filter_state` and `filter` are both present, then calculate 
sorted filter expressions
-        // for both sides, and build an expression graph if one is not already 
built.
-        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
-            match (&self.filter_state, &self.filter) {
-                (Some(interval_state), Some(filter)) => 
build_filter_expression_graph(
-                    interval_state,
+        // for both sides, and build an expression graph.
+        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match 
(
+            self.left.output_ordering(),
+            self.right.output_ordering(),
+            &self.filter,
+        ) {
+            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
+                let (left, right, graph) = prepare_sorted_exprs(
+                    filter,
                     &self.left,
                     &self.right,
-                    filter,
-                )?,
-                // If `filter_state` or `filter` is not present, then return 
None for all three values:
-                (_, _) => (None, None, None),
-            };
+                    left_sort_exprs,
+                    right_sort_exprs,
+                )?;
+                (Some(left), Some(right), Some(graph))
+            }
+            // If either child has no output ordering or `filter` is not present, 
then return None for all three values:
+            _ => (None, None, None),
+        };
 
-        let on_left = self.on.iter().map(|on| 
on.0.clone()).collect::<Vec<_>>();
-        let on_right = self.on.iter().map(|on| 
on.1.clone()).collect::<Vec<_>>();
+        let (on_left, on_right) = self.on.iter().cloned().unzip();
 
         let left_side_joiner =
             OneSideHashJoiner::new(JoinSide::Left, on_left, 
self.left.schema());
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index 67f60e57d7..daaa16e055 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -50,6 +50,8 @@ use datafusion_physical_expr::{
     PhysicalSortExpr,
 };
 
+use crate::joins::hash_join_utils::{build_filter_input_order, 
SortedFilterExpr};
+use datafusion_physical_expr::intervals::ExprIntervalGraph;
 use datafusion_physical_expr::utils::merge_vectors;
 use futures::future::{BoxFuture, Shared};
 use futures::{ready, FutureExt};
@@ -1295,6 +1297,91 @@ impl BuildProbeJoinMetrics {
     }
 }
 
+/// Updates sorted filter expressions with corresponding node indices from the
+/// expression interval graph.
+///
+/// This function iterates through the provided sorted filter expressions,
+/// gathers the corresponding node indices from the expression interval graph,
+/// and then updates the sorted expressions with these indices. It ensures
+/// that these sorted expressions are aligned with the structure of the graph.
+fn update_sorted_exprs_with_node_indices(
+    graph: &mut ExprIntervalGraph,
+    sorted_exprs: &mut [SortedFilterExpr],
+) {
+    // Extract filter expressions from the sorted expressions:
+    let filter_exprs = sorted_exprs
+        .iter()
+        .map(|expr| expr.filter_expr().clone())
+        .collect::<Vec<_>>();
+
+    // Gather corresponding node indices for the extracted filter expressions 
from the graph:
+    let child_node_indices = graph.gather_node_indices(&filter_exprs);
+
+    // Iterate through the sorted expressions and the gathered node indices:
+    for (sorted_expr, (_, index)) in 
sorted_exprs.iter_mut().zip(child_node_indices) {
+        // Update each sorted expression with the corresponding node index:
+        sorted_expr.set_node_index(index);
+    }
+}
+
+/// Builds the sorted filter expressions for both join sides, together with 
the interval graph of the filter expression.
+///
+/// # Arguments
+///
+/// * `filter` - The join filter to base the sorting on.
+/// * `left` - The left execution plan.
+/// * `right` - The right execution plan.
+/// * `left_sort_exprs` - Left side ordering; only the first entry is used.
+/// * `right_sort_exprs` - Right side ordering; only the first entry is used.
+///
+/// # Returns
+///
+/// * A tuple consisting of the sorted filter expression for the left and 
right sides, and an expression interval graph.
+pub fn prepare_sorted_exprs(
+    filter: &JoinFilter,
+    left: &Arc<dyn ExecutionPlan>,
+    right: &Arc<dyn ExecutionPlan>,
+    left_sort_exprs: &[PhysicalSortExpr],
+    right_sort_exprs: &[PhysicalSortExpr],
+) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
+    // Build the filter order for the left side
+    let err =
+        || DataFusionError::Plan("Filter does not include the child 
order".to_owned());
+
+    let left_temp_sorted_filter_expr = build_filter_input_order(
+        JoinSide::Left,
+        filter,
+        &left.schema(),
+        &left_sort_exprs[0],
+    )?
+    .ok_or_else(err)?;
+
+    // Build the filter order for the right side
+    let right_temp_sorted_filter_expr = build_filter_input_order(
+        JoinSide::Right,
+        filter,
+        &right.schema(),
+        &right_sort_exprs[0],
+    )?
+    .ok_or_else(err)?;
+
+    // Collect the sorted expressions
+    let mut sorted_exprs =
+        vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];
+
+    // Build the expression interval graph
+    let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?;
+
+    // Update sorted expressions with node indices
+    update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);
+
+    // Swap and remove to get the final sorted filter expressions
+    let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
+    let left_sorted_filter_expr = sorted_exprs.swap_remove(0);
+
+    Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

Reply via email to