This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 18c75669e1 [MINOR] Remove duplicate test utility and move one utility function for better organization (#8652)
18c75669e1 is described below

commit 18c75669e18929ca095c47af4ebf285b14d2c814
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Mon Dec 25 23:12:51 2023 +0300

    [MINOR] Remove duplicate test utility and move one utility function for better organization (#8652)
    
    * Code rearrange
    
    * Update stream_join_utils.rs
---
 .../physical-plan/src/joins/stream_join_utils.rs   | 156 +++++++++++++--------
 .../physical-plan/src/joins/symmetric_hash_join.rs |  11 +-
 datafusion/physical-plan/src/joins/utils.rs        |  90 +-----------
 3 files changed, 104 insertions(+), 153 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs 
b/datafusion/physical-plan/src/joins/stream_join_utils.rs
index 50b1618a35..9a4c989276 100644
--- a/datafusion/physical-plan/src/joins/stream_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs
@@ -25,23 +25,25 @@ use std::usize;
 
 use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult};
 use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder};
-use crate::{handle_async_state, handle_state, metrics};
+use crate::{handle_async_state, handle_state, metrics, ExecutionPlan};
 
 use arrow::compute::concat_batches;
 use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, 
RecordBatch};
 use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder};
 use arrow_schema::{Schema, SchemaRef};
-use async_trait::async_trait;
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{
-    arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue,
+    arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, 
Result,
+    ScalarValue,
 };
 use datafusion_execution::SendableRecordBatchStream;
 use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 
+use async_trait::async_trait;
 use futures::{ready, FutureExt, StreamExt};
 use hashbrown::raw::RawTable;
 use hashbrown::HashSet;
@@ -175,7 +177,7 @@ impl PruningJoinHashMap {
         prune_length: usize,
         deleting_offset: u64,
         shrink_factor: usize,
-    ) -> Result<()> {
+    ) {
         // Remove elements from the list based on the pruning length.
         self.next.drain(0..prune_length);
 
@@ -198,11 +200,10 @@ impl PruningJoinHashMap {
 
         // Shrink the map if necessary.
         self.shrink_if_necessary(shrink_factor);
-        Ok(())
     }
 }
 
-pub fn check_filter_expr_contains_sort_information(
+fn check_filter_expr_contains_sort_information(
     expr: &Arc<dyn PhysicalExpr>,
     reference: &Arc<dyn PhysicalExpr>,
 ) -> bool {
@@ -227,7 +228,7 @@ pub fn map_origin_col_to_filter_col(
     side: &JoinSide,
 ) -> Result<HashMap<Column, Column>> {
     let filter_schema = filter.schema();
-    let mut col_to_col_map: HashMap<Column, Column> = HashMap::new();
+    let mut col_to_col_map = HashMap::<Column, Column>::new();
     for (filter_schema_index, index) in 
filter.column_indices().iter().enumerate() {
         if index.side.eq(side) {
             // Get the main field from column index:
@@ -581,7 +582,7 @@ where
     // get the semi index
     (0..prune_length)
         .filter_map(|idx| 
(bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
-        .collect::<PrimitiveArray<T>>()
+        .collect()
 }
 
 pub fn combine_two_batches(
@@ -763,7 +764,6 @@ pub trait EagerJoinStream {
                 if batch.num_rows() == 0 {
                     return Ok(StatefulStreamResult::Continue);
                 }
-
                 self.set_state(EagerJoinStreamState::PullLeft);
                 self.process_batch_from_right(batch)
             }
@@ -1032,6 +1032,91 @@ impl StreamJoinMetrics {
     }
 }
 
+/// Updates sorted filter expressions with corresponding node indices from the
+/// expression interval graph.
+///
+/// This function iterates through the provided sorted filter expressions,
+/// gathers the corresponding node indices from the expression interval graph,
+/// and then updates the sorted expressions with these indices. It ensures
+/// that these sorted expressions are aligned with the structure of the graph.
+fn update_sorted_exprs_with_node_indices(
+    graph: &mut ExprIntervalGraph,
+    sorted_exprs: &mut [SortedFilterExpr],
+) {
+    // Extract filter expressions from the sorted expressions:
+    let filter_exprs = sorted_exprs
+        .iter()
+        .map(|expr| expr.filter_expr().clone())
+        .collect::<Vec<_>>();
+
+    // Gather corresponding node indices for the extracted filter expressions 
from the graph:
+    let child_node_indices = graph.gather_node_indices(&filter_exprs);
+
+    // Iterate through the sorted expressions and the gathered node indices:
+    for (sorted_expr, (_, index)) in 
sorted_exprs.iter_mut().zip(child_node_indices) {
+        // Update each sorted expression with the corresponding node index:
+        sorted_expr.set_node_index(index);
+    }
+}
+
+/// Prepares and sorts expressions based on a given filter, left and right 
execution plans, and sort expressions.
+///
+/// # Arguments
+///
+/// * `filter` - The join filter to base the sorting on.
+/// * `left` - The left execution plan.
+/// * `right` - The right execution plan.
+/// * `left_sort_exprs` - The expressions to sort on the left side.
+/// * `right_sort_exprs` - The expressions to sort on the right side.
+///
+/// # Returns
+///
+/// * A tuple consisting of the sorted filter expression for the left and 
right sides, and an expression interval graph.
+pub fn prepare_sorted_exprs(
+    filter: &JoinFilter,
+    left: &Arc<dyn ExecutionPlan>,
+    right: &Arc<dyn ExecutionPlan>,
+    left_sort_exprs: &[PhysicalSortExpr],
+    right_sort_exprs: &[PhysicalSortExpr],
+) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
+    // Build the filter order for the left side
+    let err = || plan_datafusion_err!("Filter does not include the child 
order");
+
+    let left_temp_sorted_filter_expr = build_filter_input_order(
+        JoinSide::Left,
+        filter,
+        &left.schema(),
+        &left_sort_exprs[0],
+    )?
+    .ok_or_else(err)?;
+
+    // Build the filter order for the right side
+    let right_temp_sorted_filter_expr = build_filter_input_order(
+        JoinSide::Right,
+        filter,
+        &right.schema(),
+        &right_sort_exprs[0],
+    )?
+    .ok_or_else(err)?;
+
+    // Collect the sorted expressions
+    let mut sorted_exprs =
+        vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];
+
+    // Build the expression interval graph
+    let mut graph =
+        ExprIntervalGraph::try_new(filter.expression().clone(), 
filter.schema())?;
+
+    // Update sorted expressions with node indices
+    update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);
+
+    // Swap and remove to get the final sorted filter expressions
+    let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
+    let left_sorted_filter_expr = sorted_exprs.swap_remove(0);
+
+    Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
+}
+
 #[cfg(test)]
 pub mod tests {
     use std::sync::Arc;
@@ -1043,62 +1128,15 @@ pub mod tests {
     };
     use crate::{
         expressions::{Column, PhysicalSortExpr},
+        joins::test_utils::complicated_filter,
         joins::utils::{ColumnIndex, JoinFilter},
     };
 
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
-    use datafusion_common::{JoinSide, ScalarValue};
+    use datafusion_common::JoinSide;
     use datafusion_expr::Operator;
-    use datafusion_physical_expr::expressions::{binary, cast, col, lit};
-
-    /// Filter expr for a + b > c + 10 AND a + b < c + 100
-    pub(crate) fn complicated_filter(
-        filter_schema: &Schema,
-    ) -> Result<Arc<dyn PhysicalExpr>> {
-        let left_expr = binary(
-            cast(
-                binary(
-                    col("0", filter_schema)?,
-                    Operator::Plus,
-                    col("1", filter_schema)?,
-                    filter_schema,
-                )?,
-                filter_schema,
-                DataType::Int64,
-            )?,
-            Operator::Gt,
-            binary(
-                cast(col("2", filter_schema)?, filter_schema, 
DataType::Int64)?,
-                Operator::Plus,
-                lit(ScalarValue::Int64(Some(10))),
-                filter_schema,
-            )?,
-            filter_schema,
-        )?;
-
-        let right_expr = binary(
-            cast(
-                binary(
-                    col("0", filter_schema)?,
-                    Operator::Plus,
-                    col("1", filter_schema)?,
-                    filter_schema,
-                )?,
-                filter_schema,
-                DataType::Int64,
-            )?,
-            Operator::Lt,
-            binary(
-                cast(col("2", filter_schema)?, filter_schema, 
DataType::Int64)?,
-                Operator::Plus,
-                lit(ScalarValue::Int64(Some(100))),
-                filter_schema,
-            )?,
-            filter_schema,
-        )?;
-        binary(left_expr, Operator::And, right_expr, filter_schema)
-    }
+    use datafusion_physical_expr::expressions::{binary, cast, col};
 
     #[test]
     fn test_column_exchange() -> Result<()> {
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index b9101b57c3..f071a7f601 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -36,13 +36,14 @@ use 
crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
 use crate::joins::stream_join_utils::{
     calculate_filter_expr_intervals, combine_two_batches,
     convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
-    get_pruning_semi_indices, record_visited_indices, EagerJoinStream,
-    EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, 
StreamJoinMetrics,
+    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
+    EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, 
SortedFilterExpr,
+    StreamJoinMetrics,
 };
 use crate::joins::utils::{
     build_batch_from_indices, build_join_schema, check_join_is_valid,
-    partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, 
JoinFilter,
-    JoinOn, StatefulStreamResult,
+    partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn,
+    StatefulStreamResult,
 };
 use crate::{
     expressions::{Column, PhysicalSortExpr},
@@ -936,7 +937,7 @@ impl OneSideHashJoiner {
             prune_length,
             self.deleted_offset as u64,
             HASHMAP_SHRINK_SCALE_FACTOR,
-        )?;
+        );
         // Remove pruned rows from the visited rows set:
         for row in self.deleted_offset..(self.deleted_offset + prune_length) {
             self.visited_rows.remove(&row);
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index c902ba85f2..ac805b50e6 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -25,7 +25,6 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::usize;
 
-use crate::joins::stream_join_utils::{build_filter_input_order, 
SortedFilterExpr};
 use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
 use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};
 
@@ -39,13 +38,11 @@ use arrow::record_batch::{RecordBatch, RecordBatchOptions};
 use datafusion_common::cast::as_boolean_array;
 use datafusion_common::stats::Precision;
 use datafusion_common::{
-    plan_datafusion_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
-    SharedResult,
+    plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult,
 };
 use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_physical_expr::equivalence::add_offset_to_expr;
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
 use datafusion_physical_expr::utils::merge_vectors;
 use datafusion_physical_expr::{
     LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr,
@@ -1208,91 +1205,6 @@ impl BuildProbeJoinMetrics {
     }
 }
 
-/// Updates sorted filter expressions with corresponding node indices from the
-/// expression interval graph.
-///
-/// This function iterates through the provided sorted filter expressions,
-/// gathers the corresponding node indices from the expression interval graph,
-/// and then updates the sorted expressions with these indices. It ensures
-/// that these sorted expressions are aligned with the structure of the graph.
-fn update_sorted_exprs_with_node_indices(
-    graph: &mut ExprIntervalGraph,
-    sorted_exprs: &mut [SortedFilterExpr],
-) {
-    // Extract filter expressions from the sorted expressions:
-    let filter_exprs = sorted_exprs
-        .iter()
-        .map(|expr| expr.filter_expr().clone())
-        .collect::<Vec<_>>();
-
-    // Gather corresponding node indices for the extracted filter expressions 
from the graph:
-    let child_node_indices = graph.gather_node_indices(&filter_exprs);
-
-    // Iterate through the sorted expressions and the gathered node indices:
-    for (sorted_expr, (_, index)) in 
sorted_exprs.iter_mut().zip(child_node_indices) {
-        // Update each sorted expression with the corresponding node index:
-        sorted_expr.set_node_index(index);
-    }
-}
-
-/// Prepares and sorts expressions based on a given filter, left and right 
execution plans, and sort expressions.
-///
-/// # Arguments
-///
-/// * `filter` - The join filter to base the sorting on.
-/// * `left` - The left execution plan.
-/// * `right` - The right execution plan.
-/// * `left_sort_exprs` - The expressions to sort on the left side.
-/// * `right_sort_exprs` - The expressions to sort on the right side.
-///
-/// # Returns
-///
-/// * A tuple consisting of the sorted filter expression for the left and 
right sides, and an expression interval graph.
-pub fn prepare_sorted_exprs(
-    filter: &JoinFilter,
-    left: &Arc<dyn ExecutionPlan>,
-    right: &Arc<dyn ExecutionPlan>,
-    left_sort_exprs: &[PhysicalSortExpr],
-    right_sort_exprs: &[PhysicalSortExpr],
-) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
-    // Build the filter order for the left side
-    let err = || plan_datafusion_err!("Filter does not include the child 
order");
-
-    let left_temp_sorted_filter_expr = build_filter_input_order(
-        JoinSide::Left,
-        filter,
-        &left.schema(),
-        &left_sort_exprs[0],
-    )?
-    .ok_or_else(err)?;
-
-    // Build the filter order for the right side
-    let right_temp_sorted_filter_expr = build_filter_input_order(
-        JoinSide::Right,
-        filter,
-        &right.schema(),
-        &right_sort_exprs[0],
-    )?
-    .ok_or_else(err)?;
-
-    // Collect the sorted expressions
-    let mut sorted_exprs =
-        vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];
-
-    // Build the expression interval graph
-    let mut graph =
-        ExprIntervalGraph::try_new(filter.expression().clone(), 
filter.schema())?;
-
-    // Update sorted expressions with node indices
-    update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);
-
-    // Swap and remove to get the final sorted filter expressions
-    let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
-    let left_sorted_filter_expr = sorted_exprs.swap_remove(0);
-
-    Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
-}
-
 /// The `handle_state` macro is designed to process the result of a 
state-changing
 /// operation, encountered e.g. in implementations of `EagerJoinStream`. It
 /// operates on a `StatefulStreamResult` by matching its variants and executing

Reply via email to