This is an automated email from the ASF dual-hosted git repository.
akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ee9078d8cf [MINOR] Reduce complexity on SHJ (#7607)
ee9078d8cf is described below
commit ee9078d8cfd99afa2ee2467af7837687770aac43
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Fri Sep 22 15:23:24 2023 +0300
[MINOR] Reduce complexity on SHJ (#7607)
* Before fix
* Clippy
* Minor changes
* Simplifications
---------
Co-authored-by: Mustafa Akur <[email protected]>
---
.../physical-plan/src/joins/hash_join_utils.rs | 105 +--------------------
.../physical-plan/src/joins/symmetric_hash_join.rs | 46 +++++----
datafusion/physical-plan/src/joins/utils.rs | 87 +++++++++++++++++
3 files changed, 110 insertions(+), 128 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/hash_join_utils.rs
index bb79763458..525c1a7145 100644
--- a/datafusion/physical-plan/src/joins/hash_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/hash_join_utils.rs
@@ -19,13 +19,12 @@
//! related functionality, used both in join calculations and optimization
rules.
use std::collections::{HashMap, VecDeque};
-use std::fmt::{Debug, Formatter};
+use std::fmt::Debug;
use std::ops::IndexMut;
use std::sync::Arc;
use std::{fmt, usize};
use crate::joins::utils::{JoinFilter, JoinSide};
-use crate::ExecutionPlan;
use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, SchemaRef};
@@ -34,13 +33,12 @@ use arrow_array::{ArrowPrimitiveType, NativeAdapter,
PrimitiveArray, RecordBatch
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval,
IntervalBound};
+use datafusion_physical_expr::intervals::{Interval, IntervalBound};
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use hashbrown::raw::RawTable;
use hashbrown::HashSet;
-use parking_lot::Mutex;
// Maps a `u64` hash value based on the build side ["on" values] to a list of
indices with this key's value.
// By allocating a `HashMap` with capacity for *at least* the number of rows
for entries at the build side,
@@ -446,105 +444,6 @@ fn convert_filter_columns(
})
}
-#[derive(Default)]
-pub struct IntervalCalculatorInnerState {
- /// Expression graph for interval calculations
- graph: Option<ExprIntervalGraph>,
- sorted_exprs: Vec<Option<SortedFilterExpr>>,
- calculated: bool,
-}
-
-impl Debug for IntervalCalculatorInnerState {
- fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
- write!(f, "Exprs({:?})", self.sorted_exprs)
- }
-}
-
-pub fn build_filter_expression_graph(
- interval_state: &Arc<Mutex<IntervalCalculatorInnerState>>,
- left: &Arc<dyn ExecutionPlan>,
- right: &Arc<dyn ExecutionPlan>,
- filter: &JoinFilter,
-) -> Result<(
- Option<SortedFilterExpr>,
- Option<SortedFilterExpr>,
- Option<ExprIntervalGraph>,
-)> {
- // Lock the mutex of the interval state:
- let mut filter_state = interval_state.lock();
- // If this is the first partition to be invoked, then we need to
initialize our state
- // (the expression graph for pruning, sorted filter expressions etc.)
- if !filter_state.calculated {
- // Interval calculations require each column to exhibit monotonicity
- // independently. However, a `PhysicalSortExpr` object defines a
- // lexicographical ordering, so we can only use their first elements.
- // when deducing column monotonicities.
- // TODO: Extend the `PhysicalSortExpr` mechanism to express independent
- // (i.e. simultaneous) ordering properties of columns.
-
- // Build sorted filter expressions for the left and right join side:
- let join_sides = [JoinSide::Left, JoinSide::Right];
- let children = [left, right];
- for (join_side, child) in join_sides.iter().zip(children.iter()) {
- let sorted_expr = child
- .output_ordering()
- .and_then(|orders| {
- build_filter_input_order(
- *join_side,
- filter,
- &child.schema(),
- &orders[0],
- )
- .transpose()
- })
- .transpose()?;
-
- filter_state.sorted_exprs.push(sorted_expr);
- }
-
- // Collect available sorted filter expressions:
- let sorted_exprs_size = filter_state.sorted_exprs.len();
- let mut sorted_exprs = filter_state
- .sorted_exprs
- .iter_mut()
- .flatten()
- .collect::<Vec<_>>();
-
- // Create the expression graph if we can create sorted filter
expressions for both children:
- filter_state.graph = if sorted_exprs.len() == sorted_exprs_size {
- let mut graph =
ExprIntervalGraph::try_new(filter.expression().clone())?;
-
- // Gather filter expressions:
- let filter_exprs = sorted_exprs
- .iter()
- .map(|sorted_expr| sorted_expr.filter_expr().clone())
- .collect::<Vec<_>>();
-
- // Gather node indices of converted filter expressions in
`SortedFilterExpr`s
- // using the filter columns vector:
- let child_node_indices = graph.gather_node_indices(&filter_exprs);
-
- // Update SortedFilterExpr instances with the corresponding node
indices:
- for (sorted_expr, (_, index)) in
- sorted_exprs.iter_mut().zip(child_node_indices.iter())
- {
- sorted_expr.set_node_index(*index);
- }
-
- Some(graph)
- } else {
- None
- };
- filter_state.calculated = true;
- }
- // Return the sorted filter expressions for both sides along with the
expression graph:
- Ok((
- filter_state.sorted_exprs[0].clone(),
- filter_state.sorted_exprs[1].clone(),
- filter_state.graph.as_ref().cloned(),
- ))
-}
-
/// The [SortedFilterExpr] object represents a sorted filter expression. It
/// contains the following information: The origin expression, the filter
/// expression, an interval encapsulating expression bounds, and a stable
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index a2fd127112..e6eb5dd695 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -35,10 +35,9 @@ use std::{any::Any, usize};
use crate::common::SharedMemoryReservation;
use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
use crate::joins::hash_join_utils::{
- build_filter_expression_graph, calculate_filter_expr_intervals,
combine_two_batches,
+ calculate_filter_expr_intervals, combine_two_batches,
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
- get_pruning_semi_indices, record_visited_indices,
IntervalCalculatorInnerState,
- PruningJoinHashMap,
+ get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap,
};
use crate::joins::StreamJoinPartitionMode;
use crate::DisplayAs;
@@ -69,6 +68,7 @@ use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::intervals::ExprIntervalGraph;
+use crate::joins::utils::prepare_sorted_exprs;
use ahash::RandomState;
use futures::stream::{select, BoxStream};
use futures::{Stream, StreamExt};
@@ -174,8 +174,6 @@ pub struct SymmetricHashJoinExec {
pub(crate) filter: Option<JoinFilter>,
/// How the join is performed
pub(crate) join_type: JoinType,
- /// Expression graph and `SortedFilterExpr`s for interval calculations
- filter_state: Option<Arc<Mutex<IntervalCalculatorInnerState>>>,
/// The schema once the join is applied
schema: SchemaRef,
/// Shares the `RandomState` for the hashing algorithm
@@ -285,20 +283,12 @@ impl SymmetricHashJoinExec {
// Initialize the random state for the join operation:
let random_state = RandomState::with_seeds(0, 0, 0, 0);
- let filter_state = if filter.is_some() {
- let inner_state = IntervalCalculatorInnerState::default();
- Some(Arc::new(Mutex::new(inner_state)))
- } else {
- None
- };
-
Ok(SymmetricHashJoinExec {
left,
right,
on,
filter,
join_type: *join_type,
- filter_state,
schema: Arc::new(schema),
random_state,
metrics: ExecutionPlanMetricsSet::new(),
@@ -496,21 +486,27 @@ impl ExecutionPlan for SymmetricHashJoinExec {
);
}
// If `filter_state` and `filter` are both present, then calculate
sorted filter expressions
- // for both sides, and build an expression graph if one is not already
built.
- let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
- match (&self.filter_state, &self.filter) {
- (Some(interval_state), Some(filter)) =>
build_filter_expression_graph(
- interval_state,
+ // for both sides, and build an expression graph.
+ let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
+ self.left.output_ordering(),
+ self.right.output_ordering(),
+ &self.filter,
+ ) {
+ (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
+ let (left, right, graph) = prepare_sorted_exprs(
+ filter,
&self.left,
&self.right,
- filter,
- )?,
- // If `filter_state` or `filter` is not present, then return
None for all three values:
- (_, _) => (None, None, None),
- };
+ left_sort_exprs,
+ right_sort_exprs,
+ )?;
+ (Some(left), Some(right), Some(graph))
+ }
+ // If `filter_state` or `filter` is not present, then return None for all three values:
+ _ => (None, None, None),
+ };
- let on_left = self.on.iter().map(|on|
on.0.clone()).collect::<Vec<_>>();
- let on_right = self.on.iter().map(|on|
on.1.clone()).collect::<Vec<_>>();
+ let (on_left, on_right) = self.on.iter().cloned().unzip();
let left_side_joiner =
OneSideHashJoiner::new(JoinSide::Left, on_left,
self.left.schema());
diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index 67f60e57d7..daaa16e055 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -50,6 +50,8 @@ use datafusion_physical_expr::{
PhysicalSortExpr,
};
+use crate::joins::hash_join_utils::{build_filter_input_order, SortedFilterExpr};
+use datafusion_physical_expr::intervals::ExprIntervalGraph;
use datafusion_physical_expr::utils::merge_vectors;
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
@@ -1295,6 +1297,91 @@ impl BuildProbeJoinMetrics {
}
}
+/// Updates sorted filter expressions with corresponding node indices from the
+/// expression interval graph.
+///
+/// This function iterates through the provided sorted filter expressions,
+/// gathers the corresponding node indices from the expression interval graph,
+/// and then updates the sorted expressions with these indices. It ensures
+/// that these sorted expressions are aligned with the structure of the graph.
+fn update_sorted_exprs_with_node_indices(
+ graph: &mut ExprIntervalGraph,
+ sorted_exprs: &mut [SortedFilterExpr],
+) {
+ // Extract filter expressions from the sorted expressions:
+ let filter_exprs = sorted_exprs
+ .iter()
+ .map(|expr| expr.filter_expr().clone())
+ .collect::<Vec<_>>();
+
+ // Gather corresponding node indices for the extracted filter expressions from the graph:
+ let child_node_indices = graph.gather_node_indices(&filter_exprs);
+
+ // Iterate through the sorted expressions and the gathered node indices:
+ for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) {
+ // Update each sorted expression with the corresponding node index:
+ sorted_expr.set_node_index(index);
+ }
+}
+
+/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions.
+///
+/// # Arguments
+///
+/// * `filter` - The join filter to base the sorting on.
+/// * `left` - The left execution plan.
+/// * `right` - The right execution plan.
+/// * `left_sort_exprs` - The expressions to sort on the left side.
+/// * `right_sort_exprs` - The expressions to sort on the right side.
+///
+/// # Returns
+///
+/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph.
+pub fn prepare_sorted_exprs(
+ filter: &JoinFilter,
+ left: &Arc<dyn ExecutionPlan>,
+ right: &Arc<dyn ExecutionPlan>,
+ left_sort_exprs: &[PhysicalSortExpr],
+ right_sort_exprs: &[PhysicalSortExpr],
+) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
+ // Build the filter order for the left side
+ let err =
+ || DataFusionError::Plan("Filter does not include the child order".to_owned());
+
+ let left_temp_sorted_filter_expr = build_filter_input_order(
+ JoinSide::Left,
+ filter,
+ &left.schema(),
+ &left_sort_exprs[0],
+ )?
+ .ok_or_else(err)?;
+
+ // Build the filter order for the right side
+ let right_temp_sorted_filter_expr = build_filter_input_order(
+ JoinSide::Right,
+ filter,
+ &right.schema(),
+ &right_sort_exprs[0],
+ )?
+ .ok_or_else(err)?;
+
+ // Collect the sorted expressions
+ let mut sorted_exprs =
+ vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];
+
+ // Build the expression interval graph
+ let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?;
+
+ // Update sorted expressions with node indices
+ update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);
+
+ // Swap and remove to get the final sorted filter expressions
+ let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
+ let left_sorted_filter_expr = sorted_exprs.swap_remove(0);
+
+ Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
+}
+
#[cfg(test)]
mod tests {
use super::*;