This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 58ad10c07a [MINOR] Code refactor on hash join utils (#6999)
58ad10c07a is described below
commit 58ad10c07a3ae7abc19b3b0899f19f1f069030c8
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Tue Jul 18 23:56:16 2023 +0300
[MINOR] Code refactor on hash join utils (#6999)
* Code refactor
* Remove mode
* Update test_utils.rs
---
.../src/physical_plan/joins/hash_join_utils.rs | 332 ++++-
datafusion/core/src/physical_plan/joins/mod.rs | 3 +
.../src/physical_plan/joins/symmetric_hash_join.rs | 1498 +++++---------------
.../core/src/physical_plan/joins/test_utils.rs | 513 +++++++
4 files changed, 1199 insertions(+), 1147 deletions(-)
diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
index 1b9cbd543d..37790e6bb8 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
@@ -22,17 +22,25 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::{fmt, usize};
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{ArrowNativeType, SchemaRef};
+use arrow::compute::concat_batches;
+use arrow_array::builder::BooleanBufferBuilder;
+use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray,
RecordBatch};
use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::intervals::Interval;
+use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval,
IntervalBound};
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use hashbrown::raw::RawTable;
+use hashbrown::HashSet;
+use parking_lot::Mutex;
use smallvec::SmallVec;
+use std::fmt::{Debug, Formatter};
use crate::physical_plan::joins::utils::{JoinFilter, JoinSide};
+use crate::physical_plan::ExecutionPlan;
use datafusion_common::Result;
// Maps a `u64` hash value based on the build side ["on" values] to a list of
indices with this key's value.
@@ -280,6 +288,105 @@ fn convert_filter_columns(
})
}
+#[derive(Default)]
+pub struct IntervalCalculatorInnerState {
+ /// Expression graph for interval calculations
+ graph: Option<ExprIntervalGraph>,
+ sorted_exprs: Vec<Option<SortedFilterExpr>>,
+ calculated: bool,
+}
+
+impl Debug for IntervalCalculatorInnerState {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ write!(f, "Exprs({:?})", self.sorted_exprs)
+ }
+}
+
+pub fn build_filter_expression_graph(
+ interval_state: &Arc<Mutex<IntervalCalculatorInnerState>>,
+ left: &Arc<dyn ExecutionPlan>,
+ right: &Arc<dyn ExecutionPlan>,
+ filter: &JoinFilter,
+) -> Result<(
+ Option<SortedFilterExpr>,
+ Option<SortedFilterExpr>,
+ Option<ExprIntervalGraph>,
+)> {
+ // Lock the mutex of the interval state:
+ let mut filter_state = interval_state.lock();
+ // If this is the first partition to be invoked, then we need to
initialize our state
+ // (the expression graph for pruning, sorted filter expressions etc.)
+ if !filter_state.calculated {
+ // Interval calculations require each column to exhibit monotonicity
+ // independently. However, a `PhysicalSortExpr` object defines a
+ // lexicographical ordering, so we can only use their first elements.
+ // when deducing column monotonicities.
+ // TODO: Extend the `PhysicalSortExpr` mechanism to express independent
+ // (i.e. simultaneous) ordering properties of columns.
+
+ // Build sorted filter expressions for the left and right join side:
+ let join_sides = [JoinSide::Left, JoinSide::Right];
+ let children = [left, right];
+ for (join_side, child) in join_sides.iter().zip(children.iter()) {
+ let sorted_expr = child
+ .output_ordering()
+ .and_then(|orders| {
+ build_filter_input_order(
+ *join_side,
+ filter,
+ &child.schema(),
+ &orders[0],
+ )
+ .transpose()
+ })
+ .transpose()?;
+
+ filter_state.sorted_exprs.push(sorted_expr);
+ }
+
+ // Collect available sorted filter expressions:
+ let sorted_exprs_size = filter_state.sorted_exprs.len();
+ let mut sorted_exprs = filter_state
+ .sorted_exprs
+ .iter_mut()
+ .flatten()
+ .collect::<Vec<_>>();
+
+ // Create the expression graph if we can create sorted filter
expressions for both children:
+ filter_state.graph = if sorted_exprs.len() == sorted_exprs_size {
+ let mut graph =
ExprIntervalGraph::try_new(filter.expression().clone())?;
+
+ // Gather filter expressions:
+ let filter_exprs = sorted_exprs
+ .iter()
+ .map(|sorted_expr| sorted_expr.filter_expr().clone())
+ .collect::<Vec<_>>();
+
+ // Gather node indices of converted filter expressions in
`SortedFilterExpr`s
+ // using the filter columns vector:
+ let child_node_indices = graph.gather_node_indices(&filter_exprs);
+
+ // Update SortedFilterExpr instances with the corresponding node
indices:
+ for (sorted_expr, (_, index)) in
+ sorted_exprs.iter_mut().zip(child_node_indices.iter())
+ {
+ sorted_expr.set_node_index(*index);
+ }
+
+ Some(graph)
+ } else {
+ None
+ };
+ filter_state.calculated = true;
+ }
+ // Return the sorted filter expressions for both sides along with the
expression graph:
+ Ok((
+ filter_state.sorted_exprs[0].clone(),
+ filter_state.sorted_exprs[1].clone(),
+ filter_state.graph.as_ref().cloned(),
+ ))
+}
+
/// The [SortedFilterExpr] object represents a sorted filter expression. It
/// contains the following information: The origin expression, the filter
/// expression, an interval encapsulating expression bounds, and a stable
@@ -341,6 +448,227 @@ impl SortedFilterExpr {
}
}
+/// Calculate the filter expression intervals.
+///
+/// This function updates the `interval` field of each `SortedFilterExpr` based
+/// on the first or the last value of the expression in `build_input_buffer`
+/// and `probe_batch`.
+///
+/// # Arguments
+///
+/// * `build_input_buffer` - The [RecordBatch] on the build side of the join.
+/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update.
+/// * `probe_batch` - The `RecordBatch` on the probe side of the join.
+/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
+///
+/// ### Note
+/// ```text
+///
+/// Interval arithmetic is used to calculate viable join ranges for build-side
+/// pruning. This is done by first creating an interval for join filter values
in
+/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on
the
+/// ordering (descending/ascending) of the filter expression. Here, FV denotes
the
+/// first value on the build side. This range is then compared with the probe
side
+/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering
+/// (ascending/descending) of the probe side. Here, LV denotes the last value
on
+/// the probe side.
+///
+/// As a concrete example, consider the following query:
+///
+/// SELECT * FROM left_table, right_table
+/// WHERE
+/// left_key = right_key AND
+/// a > b - 3 AND
+/// a < b + 10
+///
+/// where columns "a" and "b" come from tables "left_table" and "right_table",
+/// respectively. When a new `RecordBatch` arrives at the right side, the
+/// condition a > b - 3 will possibly indicate a prunable range for the left
+/// side. Conversely, when a new `RecordBatch` arrives at the left side, the
+/// condition a < b + 10 will possibly indicate prunability for the right side.
+/// Let’s inspect what happens when a new RecordBatch` arrives at the right
+/// side (i.e. when the left side is the build side):
+///
+/// Build Probe
+/// +-------+ +-------+
+/// | a | z | | b | y |
+/// |+--|--+| |+--|--+|
+/// | 1 | 2 | | 4 | 3 |
+/// |+--|--+| |+--|--+|
+/// | 3 | 1 | | 4 | 3 |
+/// |+--|--+| |+--|--+|
+/// | 5 | 7 | | 6 | 1 |
+/// |+--|--+| |+--|--+|
+/// | 7 | 1 | | 6 | 3 |
+/// +-------+ +-------+
+///
+/// In this case, the interval representing viable (i.e. joinable) values for
+/// column "a" is [1, ∞], and the interval representing possible future values
+/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate
+/// intervals for the whole filter expression and propagate join constraint by
+/// traversing the expression graph.
+/// ```
+pub fn calculate_filter_expr_intervals(
+ build_input_buffer: &RecordBatch,
+ build_sorted_filter_expr: &mut SortedFilterExpr,
+ probe_batch: &RecordBatch,
+ probe_sorted_filter_expr: &mut SortedFilterExpr,
+) -> Result<()> {
+ // If either build or probe side has no data, return early:
+ if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
+ return Ok(());
+ }
+ // Calculate the interval for the build side filter expression (if
present):
+ update_filter_expr_interval(
+ &build_input_buffer.slice(0, 1),
+ build_sorted_filter_expr,
+ )?;
+ // Calculate the interval for the probe side filter expression (if
present):
+ update_filter_expr_interval(
+ &probe_batch.slice(probe_batch.num_rows() - 1, 1),
+ probe_sorted_filter_expr,
+ )
+}
+
+/// This is a subroutine of the function [`calculate_filter_expr_intervals`].
+/// It constructs the current interval using the given `batch` and updates
+/// the filter expression (i.e. `sorted_expr`) with this interval.
+pub fn update_filter_expr_interval(
+ batch: &RecordBatch,
+ sorted_expr: &mut SortedFilterExpr,
+) -> Result<()> {
+ // Evaluate the filter expression and convert the result to an array:
+ let array = sorted_expr
+ .origin_sorted_expr()
+ .expr
+ .evaluate(batch)?
+ .into_array(1);
+ // Convert the array to a ScalarValue:
+ let value = ScalarValue::try_from_array(&array, 0)?;
+ // Create a ScalarValue representing positive or negative infinity for the
same data type:
+ let unbounded = IntervalBound::make_unbounded(value.get_datatype())?;
+ // Update the interval with lower and upper bounds based on the sort
option:
+ let interval = if sorted_expr.origin_sorted_expr().options.descending {
+ Interval::new(unbounded, IntervalBound::new(value, false))
+ } else {
+ Interval::new(IntervalBound::new(value, false), unbounded)
+ };
+ // Set the calculated interval for the sorted filter expression:
+ sorted_expr.set_interval(interval);
+ Ok(())
+}
+
+/// Get the anti join indices from the visited hash set.
+///
+/// This method returns the indices from the original input that were not
present in the visited hash set.
+///
+/// # Arguments
+///
+/// * `prune_length` - The length of the pruned record batch.
+/// * `deleted_offset` - The offset to the indices.
+/// * `visited_rows` - The hash set of visited indices.
+///
+/// # Returns
+///
+/// A `PrimitiveArray` of the anti join indices.
+pub fn get_pruning_anti_indices<T: ArrowPrimitiveType>(
+ prune_length: usize,
+ deleted_offset: usize,
+ visited_rows: &HashSet<usize>,
+) -> PrimitiveArray<T>
+where
+ NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
+{
+ let mut bitmap = BooleanBufferBuilder::new(prune_length);
+ bitmap.append_n(prune_length, false);
+ // mark the indices as true if they are present in the visited hash set
+ for v in 0..prune_length {
+ let row = v + deleted_offset;
+ bitmap.set_bit(v, visited_rows.contains(&row));
+ }
+ // get the anti index
+ (0..prune_length)
+ .filter_map(|idx|
(!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
+ .collect()
+}
+
+/// This method creates a boolean buffer from the visited rows hash set
+/// and the indices of the pruned record batch slice.
+///
+/// It gets the indices from the original input that were present in the
visited hash set.
+///
+/// # Arguments
+///
+/// * `prune_length` - The length of the pruned record batch.
+/// * `deleted_offset` - The offset to the indices.
+/// * `visited_rows` - The hash set of visited indices.
+///
+/// # Returns
+///
+/// A [PrimitiveArray] of the specified type T, containing the semi indices.
+pub fn get_pruning_semi_indices<T: ArrowPrimitiveType>(
+ prune_length: usize,
+ deleted_offset: usize,
+ visited_rows: &HashSet<usize>,
+) -> PrimitiveArray<T>
+where
+ NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
+{
+ let mut bitmap = BooleanBufferBuilder::new(prune_length);
+ bitmap.append_n(prune_length, false);
+ // mark the indices as true if they are present in the visited hash set
+ (0..prune_length).for_each(|v| {
+ let row = &(v + deleted_offset);
+ bitmap.set_bit(v, visited_rows.contains(row));
+ });
+ // get the semi index
+ (0..prune_length)
+ .filter_map(|idx|
(bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
+ .collect::<PrimitiveArray<T>>()
+}
+
+pub fn combine_two_batches(
+ output_schema: &SchemaRef,
+ left_batch: Option<RecordBatch>,
+ right_batch: Option<RecordBatch>,
+) -> Result<Option<RecordBatch>> {
+ match (left_batch, right_batch) {
+ (Some(batch), None) | (None, Some(batch)) => {
+ // If only one of the batches are present, return it:
+ Ok(Some(batch))
+ }
+ (Some(left_batch), Some(right_batch)) => {
+ // If both batches are present, concatenate them:
+ concat_batches(output_schema, &[left_batch, right_batch])
+ .map_err(DataFusionError::ArrowError)
+ .map(Some)
+ }
+ (None, None) => {
+ // If neither is present, return an empty batch:
+ Ok(None)
+ }
+ }
+}
+
+/// Records the visited indices from the input `PrimitiveArray` of type `T`
into the given hash set `visited`.
+/// This function will insert the indices (offset by `offset`) into the
`visited` hash set.
+///
+/// # Arguments
+///
+/// * `visited` - A hash set to store the visited indices.
+/// * `offset` - An offset to the indices in the `PrimitiveArray`.
+/// * `indices` - The input `PrimitiveArray` of type `T` which stores the
indices to be recorded.
+///
+pub fn record_visited_indices<T: ArrowPrimitiveType>(
+ visited: &mut HashSet<usize>,
+ offset: usize,
+ indices: &PrimitiveArray<T>,
+) {
+ for i in indices.values() {
+ visited.insert(i.as_usize() + offset);
+ }
+}
+
#[cfg(test)]
pub mod tests {
use super::*;
diff --git a/datafusion/core/src/physical_plan/joins/mod.rs
b/datafusion/core/src/physical_plan/joins/mod.rs
index fd805fa201..19f10d06e1 100644
--- a/datafusion/core/src/physical_plan/joins/mod.rs
+++ b/datafusion/core/src/physical_plan/joins/mod.rs
@@ -31,6 +31,9 @@ mod sort_merge_join;
mod symmetric_hash_join;
pub mod utils;
+#[cfg(test)]
+pub mod test_utils;
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Partitioning mode to use for hash join
pub enum PartitionMode {
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 10c9ae2c08..1818e4b91c 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -27,19 +27,16 @@
use std::collections::{HashMap, VecDeque};
use std::fmt;
-use std::fmt::{Debug, Formatter};
+use std::fmt::Debug;
use std::sync::Arc;
use std::task::Poll;
use std::vec;
use std::{any::Any, usize};
use ahash::RandomState;
-use arrow::array::{
- ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray,
- PrimitiveBuilder,
-};
+use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray,
PrimitiveBuilder};
use arrow::compute::concat_batches;
-use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
+use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_array::builder::{UInt32BufferBuilder, UInt64BufferBuilder};
use arrow_array::{UInt32Array, UInt64Array};
@@ -51,19 +48,22 @@ use hashbrown::HashSet;
use parking_lot::Mutex;
use smallvec::smallvec;
-use datafusion_common::{utils::bisect, ScalarValue};
use datafusion_execution::memory_pool::MemoryConsumer;
-use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval,
IntervalBound};
+use datafusion_physical_expr::intervals::ExprIntervalGraph;
use crate::physical_plan::common::SharedMemoryReservation;
-use
crate::physical_plan::joins::hash_join_utils::convert_sort_expr_with_filter_schema;
+use crate::physical_plan::joins::hash_join_utils::{
+ build_filter_expression_graph, calculate_filter_expr_intervals,
combine_two_batches,
+ convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
+ get_pruning_semi_indices, record_visited_indices,
IntervalCalculatorInnerState,
+};
use crate::physical_plan::joins::StreamJoinPartitionMode;
use crate::physical_plan::DisplayAs;
use crate::physical_plan::{
expressions::Column,
expressions::PhysicalSortExpr,
joins::{
- hash_join_utils::{build_filter_input_order, SortedFilterExpr},
+ hash_join_utils::SortedFilterExpr,
utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
combine_join_equivalence_properties,
partitioned_join_output_partitioning,
@@ -74,6 +74,7 @@ use crate::physical_plan::{
DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan,
Partitioning,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};
+use datafusion_common::utils::bisect;
use datafusion_common::JoinType;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::TaskContext;
@@ -197,19 +198,6 @@ pub struct SymmetricHashJoinExec {
mode: StreamJoinPartitionMode,
}
-struct IntervalCalculatorInnerState {
- /// Expression graph for interval calculations
- graph: Option<ExprIntervalGraph>,
- sorted_exprs: Vec<Option<SortedFilterExpr>>,
- calculated: bool,
-}
-
-impl Debug for IntervalCalculatorInnerState {
- fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
- write!(f, "Exprs({:?})", self.sorted_exprs)
- }
-}
-
#[derive(Debug)]
struct SymmetricHashJoinSideMetrics {
/// Number of batches consumed by this operator
@@ -306,11 +294,7 @@ impl SymmetricHashJoinExec {
let random_state = RandomState::with_seeds(0, 0, 0, 0);
let filter_state = if filter.is_some() {
- let inner_state = IntervalCalculatorInnerState {
- graph: None,
- sorted_exprs: vec![],
- calculated: false,
- };
+ let inner_state = IntervalCalculatorInnerState::default();
Some(Arc::new(Mutex::new(inner_state)))
} else {
None
@@ -523,83 +507,12 @@ impl ExecutionPlan for SymmetricHashJoinExec {
// for both sides, and build an expression graph if one is not already
built.
let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
match (&self.filter_state, &self.filter) {
- (Some(interval_state), Some(filter)) => {
- // Lock the mutex of the interval state:
- let mut filter_state = interval_state.lock();
- // If this is the first partition to be invoked, then we
need to initialize our state
- // (the expression graph for pruning, sorted filter
expressions etc.)
- if !filter_state.calculated {
- // Interval calculations require each column to
exhibit monotonicity
- // independently. However, a `PhysicalSortExpr` object
defines a
- // lexicographical ordering, so we can only use their
first elements.
- // when deducing column monotonicities.
- // TODO: Extend the `PhysicalSortExpr` mechanism to
express independent
- // (i.e. simultaneous) ordering properties of
columns.
-
- // Build sorted filter expressions for the left and
right join side:
- let join_sides = [JoinSide::Left, JoinSide::Right];
- let children = [&self.left, &self.right];
- for (join_side, child) in
join_sides.iter().zip(children.iter()) {
- let sorted_expr = child
- .output_ordering()
- .and_then(|orders| {
- build_filter_input_order(
- *join_side,
- filter,
- &child.schema(),
- &orders[0],
- )
- .transpose()
- })
- .transpose()?;
-
- filter_state.sorted_exprs.push(sorted_expr);
- }
-
- // Collect available sorted filter expressions:
- let sorted_exprs_size =
filter_state.sorted_exprs.len();
- let mut sorted_exprs = filter_state
- .sorted_exprs
- .iter_mut()
- .flatten()
- .collect::<Vec<_>>();
-
- // Create the expression graph if we can create sorted
filter expressions for both children:
- filter_state.graph = if sorted_exprs.len() ==
sorted_exprs_size {
- let mut graph =
-
ExprIntervalGraph::try_new(filter.expression().clone())?;
-
- // Gather filter expressions:
- let filter_exprs = sorted_exprs
- .iter()
- .map(|sorted_expr|
sorted_expr.filter_expr().clone())
- .collect::<Vec<_>>();
-
- // Gather node indices of converted filter
expressions in `SortedFilterExpr`s
- // using the filter columns vector:
- let child_node_indices =
- graph.gather_node_indices(&filter_exprs);
-
- // Update SortedFilterExpr instances with the
corresponding node indices:
- for (sorted_expr, (_, index)) in
-
sorted_exprs.iter_mut().zip(child_node_indices.iter())
- {
- sorted_expr.set_node_index(*index);
- }
-
- Some(graph)
- } else {
- None
- };
- filter_state.calculated = true;
- }
- // Return the sorted filter expressions for both sides
along with the expression graph:
- (
- filter_state.sorted_exprs[0].clone(),
- filter_state.sorted_exprs[1].clone(),
- filter_state.graph.as_ref().cloned(),
- )
- }
+ (Some(interval_state), Some(filter)) =>
build_filter_expression_graph(
+ interval_state,
+ &self.left,
+ &self.right,
+ filter,
+ )?,
// If `filter_state` or `filter` is not present, then return
None for all three values:
(_, _) => (None, None, None),
};
@@ -742,116 +655,6 @@ fn prune_hash_values(
Ok(())
}
-/// Calculate the filter expression intervals.
-///
-/// This function updates the `interval` field of each `SortedFilterExpr` based
-/// on the first or the last value of the expression in `build_input_buffer`
-/// and `probe_batch`.
-///
-/// # Arguments
-///
-/// * `build_input_buffer` - The [RecordBatch] on the build side of the join.
-/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update.
-/// * `probe_batch` - The `RecordBatch` on the probe side of the join.
-/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
-///
-/// ### Note
-/// ```text
-///
-/// Interval arithmetic is used to calculate viable join ranges for build-side
-/// pruning. This is done by first creating an interval for join filter values
in
-/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on
the
-/// ordering (descending/ascending) of the filter expression. Here, FV denotes
the
-/// first value on the build side. This range is then compared with the probe
side
-/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering
-/// (ascending/descending) of the probe side. Here, LV denotes the last value
on
-/// the probe side.
-///
-/// As a concrete example, consider the following query:
-///
-/// SELECT * FROM left_table, right_table
-/// WHERE
-/// left_key = right_key AND
-/// a > b - 3 AND
-/// a < b + 10
-///
-/// where columns "a" and "b" come from tables "left_table" and "right_table",
-/// respectively. When a new `RecordBatch` arrives at the right side, the
-/// condition a > b - 3 will possibly indicate a prunable range for the left
-/// side. Conversely, when a new `RecordBatch` arrives at the left side, the
-/// condition a < b + 10 will possibly indicate prunability for the right side.
-/// Let’s inspect what happens when a new RecordBatch` arrives at the right
-/// side (i.e. when the left side is the build side):
-///
-/// Build Probe
-/// +-------+ +-------+
-/// | a | z | | b | y |
-/// |+--|--+| |+--|--+|
-/// | 1 | 2 | | 4 | 3 |
-/// |+--|--+| |+--|--+|
-/// | 3 | 1 | | 4 | 3 |
-/// |+--|--+| |+--|--+|
-/// | 5 | 7 | | 6 | 1 |
-/// |+--|--+| |+--|--+|
-/// | 7 | 1 | | 6 | 3 |
-/// +-------+ +-------+
-///
-/// In this case, the interval representing viable (i.e. joinable) values for
-/// column "a" is [1, ∞], and the interval representing possible future values
-/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate
-/// intervals for the whole filter expression and propagate join constraint by
-/// traversing the expression graph.
-/// ```
-fn calculate_filter_expr_intervals(
- build_input_buffer: &RecordBatch,
- build_sorted_filter_expr: &mut SortedFilterExpr,
- probe_batch: &RecordBatch,
- probe_sorted_filter_expr: &mut SortedFilterExpr,
-) -> Result<()> {
- // If either build or probe side has no data, return early:
- if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
- return Ok(());
- }
- // Calculate the interval for the build side filter expression (if
present):
- update_filter_expr_interval(
- &build_input_buffer.slice(0, 1),
- build_sorted_filter_expr,
- )?;
- // Calculate the interval for the probe side filter expression (if
present):
- update_filter_expr_interval(
- &probe_batch.slice(probe_batch.num_rows() - 1, 1),
- probe_sorted_filter_expr,
- )
-}
-
-/// This is a subroutine of the function [`calculate_filter_expr_intervals`].
-/// It constructs the current interval using the given `batch` and updates
-/// the filter expression (i.e. `sorted_expr`) with this interval.
-fn update_filter_expr_interval(
- batch: &RecordBatch,
- sorted_expr: &mut SortedFilterExpr,
-) -> Result<()> {
- // Evaluate the filter expression and convert the result to an array:
- let array = sorted_expr
- .origin_sorted_expr()
- .expr
- .evaluate(batch)?
- .into_array(1);
- // Convert the array to a ScalarValue:
- let value = ScalarValue::try_from_array(&array, 0)?;
- // Create a ScalarValue representing positive or negative infinity for the
same data type:
- let unbounded = IntervalBound::make_unbounded(value.get_datatype())?;
- // Update the interval with lower and upper bounds based on the sort
option:
- let interval = if sorted_expr.origin_sorted_expr().options.descending {
- Interval::new(unbounded, IntervalBound::new(value, false))
- } else {
- Interval::new(IntervalBound::new(value, false), unbounded)
- };
- // Set the calculated interval for the sorted filter expression:
- sorted_expr.set_interval(interval);
- Ok(())
-}
-
/// Determine the pruning length for `buffer`.
///
/// This function evaluates the build side filter expression, converts the
@@ -919,93 +722,6 @@ fn need_to_produce_result_in_final(build_side: JoinSide,
join_type: JoinType) ->
}
}
-/// Get the anti join indices from the visited hash set.
-///
-/// This method returns the indices from the original input that were not
present in the visited hash set.
-///
-/// # Arguments
-///
-/// * `prune_length` - The length of the pruned record batch.
-/// * `deleted_offset` - The offset to the indices.
-/// * `visited_rows` - The hash set of visited indices.
-///
-/// # Returns
-///
-/// A `PrimitiveArray` of the anti join indices.
-fn get_anti_indices<T: ArrowPrimitiveType>(
- prune_length: usize,
- deleted_offset: usize,
- visited_rows: &HashSet<usize>,
-) -> PrimitiveArray<T>
-where
- NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
-{
- let mut bitmap = BooleanBufferBuilder::new(prune_length);
- bitmap.append_n(prune_length, false);
- // mark the indices as true if they are present in the visited hash set
- for v in 0..prune_length {
- let row = v + deleted_offset;
- bitmap.set_bit(v, visited_rows.contains(&row));
- }
- // get the anti index
- (0..prune_length)
- .filter_map(|idx|
(!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
- .collect()
-}
-
-/// This method creates a boolean buffer from the visited rows hash set
-/// and the indices of the pruned record batch slice.
-///
-/// It gets the indices from the original input that were present in the
visited hash set.
-///
-/// # Arguments
-///
-/// * `prune_length` - The length of the pruned record batch.
-/// * `deleted_offset` - The offset to the indices.
-/// * `visited_rows` - The hash set of visited indices.
-///
-/// # Returns
-///
-/// A [PrimitiveArray] of the specified type T, containing the semi indices.
-fn get_semi_indices<T: ArrowPrimitiveType>(
- prune_length: usize,
- deleted_offset: usize,
- visited_rows: &HashSet<usize>,
-) -> PrimitiveArray<T>
-where
- NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
-{
- let mut bitmap = BooleanBufferBuilder::new(prune_length);
- bitmap.append_n(prune_length, false);
- // mark the indices as true if they are present in the visited hash set
- (0..prune_length).for_each(|v| {
- let row = &(v + deleted_offset);
- bitmap.set_bit(v, visited_rows.contains(row));
- });
- // get the semi index
- (0..prune_length)
- .filter_map(|idx|
(bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
- .collect::<PrimitiveArray<T>>()
-}
-/// Records the visited indices from the input `PrimitiveArray` of type `T`
into the given hash set `visited`.
-/// This function will insert the indices (offset by `offset`) into the
`visited` hash set.
-///
-/// # Arguments
-///
-/// * `visited` - A hash set to store the visited indices.
-/// * `offset` - An offset to the indices in the `PrimitiveArray`.
-/// * `indices` - The input `PrimitiveArray` of type `T` which stores the
indices to be recorded.
-///
-fn record_visited_indices<T: ArrowPrimitiveType>(
- visited: &mut HashSet<usize>,
- offset: usize,
- indices: &PrimitiveArray<T>,
-) {
- for i in indices.values() {
- visited.insert(i.as_usize() + offset);
- }
-}
-
/// Calculate indices by join type.
///
/// This method returns a tuple of two arrays: build and probe indices.
@@ -1040,7 +756,7 @@ where
| (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
| (_, JoinType::Full) => {
let build_unmatched_indices =
- get_anti_indices(prune_length, deleted_offset, visited_rows);
+ get_pruning_anti_indices(prune_length, deleted_offset,
visited_rows);
let mut builder =
PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
builder.append_nulls(build_unmatched_indices.len());
@@ -1050,7 +766,7 @@ where
// In the case of `LeftSemi` or `RightSemi` join, get the semi indices
(JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right,
JoinType::RightSemi) => {
let build_unmatched_indices =
- get_semi_indices(prune_length, deleted_offset, visited_rows);
+ get_pruning_semi_indices(prune_length, deleted_offset,
visited_rows);
let mut builder =
PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
builder.append_nulls(build_unmatched_indices.len());
@@ -1063,25 +779,301 @@ where
Ok(result)
}
-struct OneSideHashJoiner {
+/// This function produces unmatched record results based on the build side,
+/// join type and other parameters.
+///
+/// The method uses first `prune_length` rows from the build side input buffer
+/// to produce results.
+///
+/// # Arguments
+///
+/// * `output_schema` - The schema of the final output record batch.
+/// * `prune_length` - The length of the determined prune length.
+/// * `probe_schema` - The schema of the probe [RecordBatch].
+/// * `join_type` - The type of join to be performed.
+/// * `column_indices` - Indices of columns that are being joined.
+///
+/// # Returns
+///
+/// * `Option<RecordBatch>` - The final output record batch if required,
otherwise [None].
+pub(crate) fn build_side_determined_results(
+ build_hash_joiner: &OneSideHashJoiner,
+ output_schema: &SchemaRef,
+ prune_length: usize,
+ probe_schema: SchemaRef,
+ join_type: JoinType,
+ column_indices: &[ColumnIndex],
+) -> Result<Option<RecordBatch>> {
+ // Check if we need to produce a result in the final output:
+ if need_to_produce_result_in_final(build_hash_joiner.build_side,
join_type) {
+ // Calculate the indices for build and probe sides based on join type
and build side:
+ let (build_indices, probe_indices) = calculate_indices_by_join_type(
+ build_hash_joiner.build_side,
+ prune_length,
+ &build_hash_joiner.visited_rows,
+ build_hash_joiner.deleted_offset,
+ join_type,
+ )?;
+
+ // Create an empty probe record batch:
+ let empty_probe_batch = RecordBatch::new_empty(probe_schema);
+ // Build the final result from the indices of build and probe sides:
+ build_batch_from_indices(
+ output_schema.as_ref(),
+ &build_hash_joiner.input_buffer,
+ &empty_probe_batch,
+ &build_indices,
+ &probe_indices,
+ column_indices,
+ build_hash_joiner.build_side,
+ )
+ .map(|batch| (batch.num_rows() > 0).then_some(batch))
+ } else {
+ // If we don't need to produce a result, return None
+ Ok(None)
+ }
+}
+
+/// Gets build and probe indices which satisfy the on condition (including
+/// the equality condition and the join filter) in the join.
+#[allow(clippy::too_many_arguments)]
+pub fn build_join_indices(
+ probe_batch: &RecordBatch,
+ build_hashmap: &SymmetricJoinHashMap,
+ build_input_buffer: &RecordBatch,
+ on_build: &[Column],
+ on_probe: &[Column],
+ filter: Option<&JoinFilter>,
+ random_state: &RandomState,
+ null_equals_null: bool,
+ hashes_buffer: &mut Vec<u64>,
+ offset: Option<usize>,
+ build_side: JoinSide,
+) -> Result<(UInt64Array, UInt32Array)> {
+ // Get the indices that satisfy the equality condition, like `left.a1 =
right.a2`
+ let (build_indices, probe_indices) = build_equal_condition_join_indices(
+ build_hashmap,
+ build_input_buffer,
+ probe_batch,
+ on_build,
+ on_probe,
+ random_state,
+ null_equals_null,
+ hashes_buffer,
+ offset,
+ )?;
+ if let Some(filter) = filter {
+ // Filter the indices which satisfy the non-equal join condition, like
`left.b1 = 10`
+ apply_join_filter_to_indices(
+ build_input_buffer,
+ probe_batch,
+ build_indices,
+ probe_indices,
+ filter,
+ build_side,
+ )
+ } else {
+ Ok((build_indices, probe_indices))
+ }
+}
+
+// Returns build/probe indices satisfying the equality condition.
+// On LEFT.b1 = RIGHT.b2
+// LEFT Table:
+// a1 b1 c1
+// 1 1 10
+// 3 3 30
+// 5 5 50
+// 7 7 70
+// 9 8 90
+// 11 8 110
+// 13 10 130
+// RIGHT Table:
+// a2 b2 c2
+// 2 2 20
+// 4 4 40
+// 6 6 60
+// 8 8 80
+// 10 10 100
+// 12 10 120
+// The result is
+// "+----+----+-----+----+----+-----+",
+// "| a1 | b1 | c1 | a2 | b2 | c2 |",
+// "+----+----+-----+----+----+-----+",
+// "| 11 | 8 | 110 | 8 | 8 | 80 |",
+// "| 13 | 10 | 130 | 10 | 10 | 100 |",
+// "| 13 | 10 | 130 | 12 | 10 | 120 |",
+// "| 9 | 8 | 90 | 8 | 8 | 80 |",
+// "+----+----+-----+----+----+-----+"
+// And the result of build and probe indices are:
+// Build indices: 5, 6, 6, 4
+// Probe indices: 3, 4, 5, 3
+#[allow(clippy::too_many_arguments)]
+pub fn build_equal_condition_join_indices(
+ build_hashmap: &SymmetricJoinHashMap,
+ build_input_buffer: &RecordBatch,
+ probe_batch: &RecordBatch,
+ build_on: &[Column],
+ probe_on: &[Column],
+ random_state: &RandomState,
+ null_equals_null: bool,
+ hashes_buffer: &mut Vec<u64>,
+ offset: Option<usize>,
+) -> Result<(UInt64Array, UInt32Array)> {
+ let keys_values = probe_on
+ .iter()
+ .map(|c|
Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())))
+ .collect::<Result<Vec<_>>>()?;
+ let build_join_values = build_on
+ .iter()
+ .map(|c| {
+ Ok(c.evaluate(build_input_buffer)?
+ .into_array(build_input_buffer.num_rows()))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ hashes_buffer.clear();
+ hashes_buffer.resize(probe_batch.num_rows(), 0);
+ let hash_values = create_hashes(&keys_values, random_state,
hashes_buffer)?;
+ // Using a buffer builder to avoid slower normal builder
+ let mut build_indices = UInt64BufferBuilder::new(0);
+ let mut probe_indices = UInt32BufferBuilder::new(0);
+ let offset_value = offset.unwrap_or(0);
+ // Visit all of the probe rows
+ for (row, hash_value) in hash_values.iter().enumerate() {
+ // Get the hash and find it in the build index
+ // For every item on the build and probe we check if it matches
+ // This possibly contains rows with hash collisions,
+ // So we have to check here whether rows are equal or not
+ if let Some((_, indices)) = build_hashmap
+ .0
+ .get(*hash_value, |(hash, _)| *hash_value == *hash)
+ {
+ for &i in indices {
+                // Adjust the build index by the deletion offset
+ let offset_build_index = i as usize - offset_value;
+ // Check hash collisions
+ if equal_rows(
+ offset_build_index,
+ row,
+ &build_join_values,
+ &keys_values,
+ null_equals_null,
+ )? {
+ build_indices.append(offset_build_index as u64);
+ probe_indices.append(row as u32);
+ }
+ }
+ }
+ }
+
+ Ok((
+ PrimitiveArray::new(build_indices.finish().into(), None),
+ PrimitiveArray::new(probe_indices.finish().into(), None),
+ ))
+}
+
+/// This method performs a join between the build side input buffer and the
probe side batch.
+///
+/// # Arguments
+///
+/// * `build_hash_joiner` - Build side hash joiner
+/// * `probe_hash_joiner` - Probe side hash joiner
+/// * `schema` - A reference to the schema of the output record batch.
+/// * `join_type` - The type of join to be performed.
+/// * The join-on columns are taken from the `on` fields of the build and
probe hash joiners.
+/// * `filter` - An optional filter on the join condition.
+/// * `probe_batch` - The second record batch to be joined.
+/// * `column_indices` - An array of columns to be selected for the result of
the join.
+/// * `random_state` - The random state for the join.
+/// * `null_equals_null` - A boolean indicating whether NULL values should be
treated as equal when joining.
+///
+/// # Returns
+///
+/// A [Result] containing an optional record batch if the join type is not one
of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`.
+/// If the join type is one of the above four, the function will return [None].
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn join_with_probe_batch(
+ build_hash_joiner: &mut OneSideHashJoiner,
+ probe_hash_joiner: &mut OneSideHashJoiner,
+ schema: &SchemaRef,
+ join_type: JoinType,
+ filter: Option<&JoinFilter>,
+ probe_batch: &RecordBatch,
+ column_indices: &[ColumnIndex],
+ random_state: &RandomState,
+ null_equals_null: bool,
+) -> Result<Option<RecordBatch>> {
+ if build_hash_joiner.input_buffer.num_rows() == 0 ||
probe_batch.num_rows() == 0 {
+ return Ok(None);
+ }
+ let (build_indices, probe_indices) = build_join_indices(
+ probe_batch,
+ &build_hash_joiner.hashmap,
+ &build_hash_joiner.input_buffer,
+ &build_hash_joiner.on,
+ &probe_hash_joiner.on,
+ filter,
+ random_state,
+ null_equals_null,
+ &mut build_hash_joiner.hashes_buffer,
+ Some(build_hash_joiner.deleted_offset),
+ build_hash_joiner.build_side,
+ )?;
+ if need_to_produce_result_in_final(build_hash_joiner.build_side,
join_type) {
+ record_visited_indices(
+ &mut build_hash_joiner.visited_rows,
+ build_hash_joiner.deleted_offset,
+ &build_indices,
+ );
+ }
+ if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(),
join_type) {
+ record_visited_indices(
+ &mut probe_hash_joiner.visited_rows,
+ probe_hash_joiner.offset,
+ &probe_indices,
+ );
+ }
+ if matches!(
+ join_type,
+ JoinType::LeftAnti
+ | JoinType::RightAnti
+ | JoinType::LeftSemi
+ | JoinType::RightSemi
+ ) {
+ Ok(None)
+ } else {
+ build_batch_from_indices(
+ schema,
+ &build_hash_joiner.input_buffer,
+ probe_batch,
+ &build_indices,
+ &probe_indices,
+ column_indices,
+ build_hash_joiner.build_side,
+ )
+ .map(|batch| (batch.num_rows() > 0).then_some(batch))
+ }
+}
+
+pub struct OneSideHashJoiner {
/// Build side
build_side: JoinSide,
/// Input record batch buffer
- input_buffer: RecordBatch,
+ pub input_buffer: RecordBatch,
/// Columns from the side
- on: Vec<Column>,
+ pub(crate) on: Vec<Column>,
/// Hashmap
- hashmap: SymmetricJoinHashMap,
+ pub(crate) hashmap: SymmetricJoinHashMap,
/// To optimize hash deleting in case of pruning, we hold them in memory
row_hash_values: VecDeque<u64>,
/// Reuse the hashes buffer
- hashes_buffer: Vec<u64>,
+ pub(crate) hashes_buffer: Vec<u64>,
/// Matched rows
- visited_rows: HashSet<usize>,
+ pub(crate) visited_rows: HashSet<usize>,
/// Offset
- offset: usize,
+ pub(crate) offset: usize,
/// Deleted offset
- deleted_offset: usize,
+ pub(crate) deleted_offset: usize,
}
impl OneSideHashJoiner {
@@ -1156,7 +1148,7 @@ impl OneSideHashJoiner {
/// # Returns
///
/// Returns a [Result] encapsulating any intermediate errors.
- fn update_internal_state(
+ pub(crate) fn update_internal_state(
&mut self,
batch: &RecordBatch,
random_state: &RandomState,
@@ -1180,280 +1172,6 @@ impl OneSideHashJoiner {
Ok(())
}
- /// Gets build and probe indices which satisfy the on condition (including
- /// the equality condition and the join filter) in the join.
- #[allow(clippy::too_many_arguments)]
- pub fn build_join_indices(
- probe_batch: &RecordBatch,
- build_hashmap: &SymmetricJoinHashMap,
- build_input_buffer: &RecordBatch,
- on_build: &[Column],
- on_probe: &[Column],
- filter: Option<&JoinFilter>,
- random_state: &RandomState,
- null_equals_null: bool,
- hashes_buffer: &mut Vec<u64>,
- offset: Option<usize>,
- build_side: JoinSide,
- ) -> Result<(UInt64Array, UInt32Array)> {
- // Get the indices that satisfy the equality condition, like `left.a1
= right.a2`
- let (build_indices, probe_indices) =
Self::build_equal_condition_join_indices(
- build_hashmap,
- build_input_buffer,
- probe_batch,
- on_build,
- on_probe,
- random_state,
- null_equals_null,
- hashes_buffer,
- offset,
- )?;
- if let Some(filter) = filter {
- // Filter the indices which satisfy the non-equal join condition,
like `left.b1 = 10`
- apply_join_filter_to_indices(
- build_input_buffer,
- probe_batch,
- build_indices,
- probe_indices,
- filter,
- build_side,
- )
- } else {
- Ok((build_indices, probe_indices))
- }
- }
-
- // Returns build/probe indices satisfying the equality condition.
- // On LEFT.b1 = RIGHT.b2
- // LEFT Table:
- // a1 b1 c1
- // 1 1 10
- // 3 3 30
- // 5 5 50
- // 7 7 70
- // 9 8 90
- // 11 8 110
- // 13 10 130
- // RIGHT Table:
- // a2 b2 c2
- // 2 2 20
- // 4 4 40
- // 6 6 60
- // 8 8 80
- // 10 10 100
- // 12 10 120
- // The result is
- // "+----+----+-----+----+----+-----+",
- // "| a1 | b1 | c1 | a2 | b2 | c2 |",
- // "+----+----+-----+----+----+-----+",
- // "| 11 | 8 | 110 | 8 | 8 | 80 |",
- // "| 13 | 10 | 130 | 10 | 10 | 100 |",
- // "| 13 | 10 | 130 | 12 | 10 | 120 |",
- // "| 9 | 8 | 90 | 8 | 8 | 80 |",
- // "+----+----+-----+----+----+-----+"
- // And the result of build and probe indices are:
- // Build indices: 5, 6, 6, 4
- // Probe indices: 3, 4, 5, 3
- #[allow(clippy::too_many_arguments)]
- pub fn build_equal_condition_join_indices(
- build_hashmap: &SymmetricJoinHashMap,
- build_input_buffer: &RecordBatch,
- probe_batch: &RecordBatch,
- build_on: &[Column],
- probe_on: &[Column],
- random_state: &RandomState,
- null_equals_null: bool,
- hashes_buffer: &mut Vec<u64>,
- offset: Option<usize>,
- ) -> Result<(UInt64Array, UInt32Array)> {
- let keys_values = probe_on
- .iter()
- .map(|c|
Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())))
- .collect::<Result<Vec<_>>>()?;
- let build_join_values = build_on
- .iter()
- .map(|c| {
- Ok(c.evaluate(build_input_buffer)?
- .into_array(build_input_buffer.num_rows()))
- })
- .collect::<Result<Vec<_>>>()?;
- hashes_buffer.clear();
- hashes_buffer.resize(probe_batch.num_rows(), 0);
- let hash_values = create_hashes(&keys_values, random_state,
hashes_buffer)?;
- // Using a buffer builder to avoid slower normal builder
- let mut build_indices = UInt64BufferBuilder::new(0);
- let mut probe_indices = UInt32BufferBuilder::new(0);
- let offset_value = offset.unwrap_or(0);
- // Visit all of the probe rows
- for (row, hash_value) in hash_values.iter().enumerate() {
- // Get the hash and find it in the build index
- // For every item on the build and probe we check if it matches
- // This possibly contains rows with hash collisions,
- // So we have to check here whether rows are equal or not
- if let Some((_, indices)) = build_hashmap
- .0
- .get(*hash_value, |(hash, _)| *hash_value == *hash)
- {
- for &i in indices {
- // Check hash collisions
- let offset_build_index = i as usize - offset_value;
- // Check hash collisions
- if equal_rows(
- offset_build_index,
- row,
- &build_join_values,
- &keys_values,
- null_equals_null,
- )? {
- build_indices.append(offset_build_index as u64);
- probe_indices.append(row as u32);
- }
- }
- }
- }
-
- Ok((
- PrimitiveArray::new(build_indices.finish().into(), None),
- PrimitiveArray::new(probe_indices.finish().into(), None),
- ))
- }
-
- /// This method performs a join between the build side input buffer and
the probe side batch.
- ///
- /// # Arguments
- ///
- /// * `schema` - A reference to the schema of the output record batch.
- /// * `join_type` - The type of join to be performed.
- /// * `on_probe` - An array of columns on which the join will be
performed. The columns are from the probe side of the join.
- /// * `filter` - An optional filter on the join condition.
- /// * `probe_batch` - The second record batch to be joined.
- /// * `probe_visited` - A hash set to store the visited indices from the
probe batch.
- /// * `probe_offset` - The offset of the probe side for visited indices
calculations.
- /// * `column_indices` - An array of columns to be selected for the result
of the join.
- /// * `random_state` - The random state for the join.
- /// * `null_equals_null` - A boolean indicating whether NULL values should
be treated as equal when joining.
- ///
- /// # Returns
- ///
- /// A [Result] containing an optional record batch if the join type is not
one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`.
- /// If the join type is one of the above four, the function will return
[None].
- #[allow(clippy::too_many_arguments)]
- fn join_with_probe_batch(
- &mut self,
- schema: &SchemaRef,
- join_type: JoinType,
- on_probe: &[Column],
- filter: Option<&JoinFilter>,
- probe_batch: &RecordBatch,
- probe_visited: &mut HashSet<usize>,
- probe_offset: usize,
- column_indices: &[ColumnIndex],
- random_state: &RandomState,
- null_equals_null: bool,
- ) -> Result<Option<RecordBatch>> {
- if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
- return Ok(None);
- }
- let (build_indices, probe_indices) = Self::build_join_indices(
- probe_batch,
- &self.hashmap,
- &self.input_buffer,
- &self.on,
- on_probe,
- filter,
- random_state,
- null_equals_null,
- &mut self.hashes_buffer,
- Some(self.deleted_offset),
- self.build_side,
- )?;
- if need_to_produce_result_in_final(self.build_side, join_type) {
- record_visited_indices(
- &mut self.visited_rows,
- self.deleted_offset,
- &build_indices,
- );
- }
- if need_to_produce_result_in_final(self.build_side.negate(),
join_type) {
- record_visited_indices(probe_visited, probe_offset,
&probe_indices);
- }
- if matches!(
- join_type,
- JoinType::LeftAnti
- | JoinType::RightAnti
- | JoinType::LeftSemi
- | JoinType::RightSemi
- ) {
- Ok(None)
- } else {
- build_batch_from_indices(
- schema,
- &self.input_buffer,
- probe_batch,
- &build_indices,
- &probe_indices,
- column_indices,
- self.build_side,
- )
- .map(|batch| (batch.num_rows() > 0).then_some(batch))
- }
- }
-
- /// This function produces unmatched record results based on the build
side,
- /// join type and other parameters.
- ///
- /// The method uses first `prune_length` rows from the build side input
buffer
- /// to produce results.
- ///
- /// # Arguments
- ///
- /// * `output_schema` - The schema of the final output record batch.
- /// * `prune_length` - The length of the determined prune length.
- /// * `probe_schema` - The schema of the probe [RecordBatch].
- /// * `join_type` - The type of join to be performed.
- /// * `column_indices` - Indices of columns that are being joined.
- ///
- /// # Returns
- ///
- /// * `Option<RecordBatch>` - The final output record batch if required,
otherwise [None].
- fn build_side_determined_results(
- &self,
- output_schema: &SchemaRef,
- prune_length: usize,
- probe_schema: SchemaRef,
- join_type: JoinType,
- column_indices: &[ColumnIndex],
- ) -> Result<Option<RecordBatch>> {
- // Check if we need to produce a result in the final output:
- if need_to_produce_result_in_final(self.build_side, join_type) {
- // Calculate the indices for build and probe sides based on join
type and build side:
- let (build_indices, probe_indices) =
calculate_indices_by_join_type(
- self.build_side,
- prune_length,
- &self.visited_rows,
- self.deleted_offset,
- join_type,
- )?;
-
- // Create an empty probe record batch:
- let empty_probe_batch = RecordBatch::new_empty(probe_schema);
- // Build the final result from the indices of build and probe
sides:
- build_batch_from_indices(
- output_schema.as_ref(),
- &self.input_buffer,
- &empty_probe_batch,
- &build_indices,
- &probe_indices,
- column_indices,
- self.build_side,
- )
- .map(|batch| (batch.num_rows() > 0).then_some(batch))
- } else {
- // If we don't need to produce a result, return None
- Ok(None)
- }
- }
-
/// Prunes the internal buffer.
///
/// Argument `probe_batch` is used to update the intervals of the sorted
@@ -1475,7 +1193,7 @@ impl OneSideHashJoiner {
///
/// If there are rows to prune, returns the pruned build side record batch
wrapped in an `Ok` variant.
/// Otherwise, returns `Ok(None)`.
- fn calculate_prune_length_with_probe_batch(
+ pub(crate) fn calculate_prune_length_with_probe_batch(
&mut self,
build_side_sorted_filter_expr: &mut SortedFilterExpr,
probe_side_sorted_filter_expr: &mut SortedFilterExpr,
@@ -1508,22 +1226,7 @@ impl OneSideHashJoiner {
determine_prune_length(&self.input_buffer,
build_side_sorted_filter_expr)
}
- fn prune_internal_state_and_build_anti_result(
- &mut self,
- prune_length: usize,
- schema: &SchemaRef,
- probe_batch: &RecordBatch,
- join_type: JoinType,
- column_indices: &[ColumnIndex],
- ) -> Result<Option<RecordBatch>> {
- // Compute the result and perform pruning if there are rows to prune:
- let result = self.build_side_determined_results(
- schema,
- prune_length,
- probe_batch.schema(),
- join_type,
- column_indices,
- );
+ pub(crate) fn prune_internal_state(&mut self, prune_length: usize) ->
Result<()> {
// Prune the hash values:
prune_hash_values(
prune_length,
@@ -1541,30 +1244,7 @@ impl OneSideHashJoiner {
.slice(prune_length, self.input_buffer.num_rows() - prune_length);
// Increment the deleted offset:
self.deleted_offset += prune_length;
- result
- }
-}
-
-fn combine_two_batches(
- output_schema: &SchemaRef,
- left_batch: Option<RecordBatch>,
- right_batch: Option<RecordBatch>,
-) -> Result<Option<RecordBatch>> {
- match (left_batch, right_batch) {
- (Some(batch), None) | (None, Some(batch)) => {
- // If only one of the batches are present, return it:
- Ok(Some(batch))
- }
- (Some(left_batch), Some(right_batch)) => {
- // If both batches are present, concatenate them:
- concat_batches(output_schema, &[left_batch, right_batch])
- .map_err(DataFusionError::ArrowError)
- .map(Some)
- }
- (None, None) => {
- // If neither is present, return an empty batch:
- Ok(None)
- }
+ Ok(())
}
}
@@ -1634,14 +1314,13 @@ impl SymmetricHashJoinStream {
probe_hash_joiner
.update_internal_state(&probe_batch,
&self.random_state)?;
// Join the two sides:
- let equal_result = build_hash_joiner.join_with_probe_batch(
+ let equal_result = join_with_probe_batch(
+ build_hash_joiner,
+ probe_hash_joiner,
&self.schema,
self.join_type,
- &probe_hash_joiner.on,
self.filter.as_ref(),
&probe_batch,
- &mut probe_hash_joiner.visited_rows,
- probe_hash_joiner.offset,
&self.column_indices,
&self.random_state,
self.null_equals_null,
@@ -1673,13 +1352,16 @@ impl SymmetricHashJoinStream {
)?;
if prune_length > 0 {
-
build_hash_joiner.prune_internal_state_and_build_anti_result(
- prune_length,
+ let res = build_side_determined_results(
+ build_hash_joiner,
&self.schema,
- &probe_batch,
+ prune_length,
+ probe_batch.schema(),
self.join_type,
&self.column_indices,
- )?
+ )?;
+
build_hash_joiner.prune_internal_state(prune_length)?;
+ res
} else {
None
}
@@ -1708,7 +1390,8 @@ impl SymmetricHashJoinStream {
}
self.final_result = true;
// Get the left side results:
- let left_result = self.left.build_side_determined_results(
+ let left_result = build_side_determined_results(
+ &self.left,
&self.schema,
self.left.input_buffer.num_rows(),
self.right.input_buffer.schema(),
@@ -1716,7 +1399,8 @@ impl SymmetricHashJoinStream {
&self.column_indices,
)?;
// Get the right side results:
- let right_result =
self.right.build_side_determined_results(
+ let right_result = build_side_determined_results(
+ &self.right,
&self.schema,
self.right.input_buffer.num_rows(),
self.left.input_buffer.schema(),
@@ -1746,509 +1430,34 @@ impl SymmetricHashJoinStream {
mod tests {
use std::fs::File;
- use arrow::array::{ArrayRef, Float64Array, IntervalDayTimeArray};
- use arrow::array::{Int32Array, TimestampMillisecondArray};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
- use arrow::util::pretty::pretty_format_batches;
use rstest::*;
use tempfile::TempDir;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, col, Column};
- use datafusion_physical_expr::intervals::test_utils::{
- gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr,
- };
- use datafusion_physical_expr::PhysicalExpr;
+ use
datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numerical_expr;
- use crate::physical_plan::joins::{
- hash_join_utils::tests::complicated_filter, HashJoinExec,
PartitionMode,
- };
- use crate::physical_plan::{
- common, displayable, memory::MemoryExec, repartition::RepartitionExec,
- };
+ use crate::physical_plan::displayable;
+ use
crate::physical_plan::joins::hash_join_utils::tests::complicated_filter;
use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
use crate::test_util::register_unbounded_file_with_ordering;
- use super::*;
-
- const TABLE_SIZE: i32 = 100;
-
- fn compare_batches(collected_1: &[RecordBatch], collected_2:
&[RecordBatch]) {
- // compare
- let first_formatted =
pretty_format_batches(collected_1).unwrap().to_string();
- let second_formatted =
pretty_format_batches(collected_2).unwrap().to_string();
-
- let mut first_formatted_sorted: Vec<&str> =
- first_formatted.trim().lines().collect();
- first_formatted_sorted.sort_unstable();
-
- let mut second_formatted_sorted: Vec<&str> =
- second_formatted.trim().lines().collect();
- second_formatted_sorted.sort_unstable();
-
- for (i, (first_line, second_line)) in first_formatted_sorted
- .iter()
- .zip(&second_formatted_sorted)
- .enumerate()
- {
- assert_eq!((i, first_line), (i, second_line));
- }
- }
-
- async fn partitioned_sym_join_with_filter(
- left: Arc<dyn ExecutionPlan>,
- right: Arc<dyn ExecutionPlan>,
- on: JoinOn,
- filter: Option<JoinFilter>,
- join_type: &JoinType,
- null_equals_null: bool,
- context: Arc<TaskContext>,
- ) -> Result<Vec<RecordBatch>> {
- let partition_count = 4;
-
- let left_expr = on
- .iter()
- .map(|(l, _)| Arc::new(l.clone()) as _)
- .collect::<Vec<_>>();
-
- let right_expr = on
- .iter()
- .map(|(_, r)| Arc::new(r.clone()) as _)
- .collect::<Vec<_>>();
-
- let join = SymmetricHashJoinExec::try_new(
- Arc::new(RepartitionExec::try_new(
- left,
- Partitioning::Hash(left_expr, partition_count),
- )?),
- Arc::new(RepartitionExec::try_new(
- right,
- Partitioning::Hash(right_expr, partition_count),
- )?),
- on,
- filter,
- join_type,
- null_equals_null,
- StreamJoinPartitionMode::Partitioned,
- )?;
-
- let mut batches = vec![];
- for i in 0..partition_count {
- let stream = join.execute(i, context.clone())?;
- let more_batches = common::collect(stream).await?;
- batches.extend(
- more_batches
- .into_iter()
- .filter(|b| b.num_rows() > 0)
- .collect::<Vec<_>>(),
- );
- }
-
- Ok(batches)
- }
-
- async fn partitioned_hash_join_with_filter(
- left: Arc<dyn ExecutionPlan>,
- right: Arc<dyn ExecutionPlan>,
- on: JoinOn,
- filter: Option<JoinFilter>,
- join_type: &JoinType,
- null_equals_null: bool,
- context: Arc<TaskContext>,
- ) -> Result<Vec<RecordBatch>> {
- let partition_count = 4;
-
- let (left_expr, right_expr) = on
- .iter()
- .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
- .unzip();
-
- let join = HashJoinExec::try_new(
- Arc::new(RepartitionExec::try_new(
- left,
- Partitioning::Hash(left_expr, partition_count),
- )?),
- Arc::new(RepartitionExec::try_new(
- right,
- Partitioning::Hash(right_expr, partition_count),
- )?),
- on,
- filter,
- join_type,
- PartitionMode::Partitioned,
- null_equals_null,
- )?;
-
- let mut batches = vec![];
- for i in 0..partition_count {
- let stream = join.execute(i, context.clone())?;
- let more_batches = common::collect(stream).await?;
- batches.extend(
- more_batches
- .into_iter()
- .filter(|b| b.num_rows() > 0)
- .collect::<Vec<_>>(),
- );
- }
-
- Ok(batches)
- }
-
- pub fn split_record_batches(
- batch: &RecordBatch,
- batch_size: usize,
- ) -> Result<Vec<RecordBatch>> {
- let row_num = batch.num_rows();
- let number_of_batch = row_num / batch_size;
- let mut sizes = vec![batch_size; number_of_batch];
- sizes.push(row_num - (batch_size * number_of_batch));
- let mut result = vec![];
- for (i, size) in sizes.iter().enumerate() {
- result.push(batch.slice(i * batch_size, *size));
- }
- Ok(result)
- }
-
- // It creates join filters for different type of fields for testing.
- macro_rules! join_expr_tests {
- ($func_name:ident, $type:ty, $SCALAR:ident) => {
- fn $func_name(
- expr_id: usize,
- left_col: Arc<dyn PhysicalExpr>,
- right_col: Arc<dyn PhysicalExpr>,
- ) -> Arc<dyn PhysicalExpr> {
- match expr_id {
- // left_col + 1 > right_col + 5 AND left_col + 3 <
right_col + 10
- 0 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- ),
- ScalarValue::$SCALAR(Some(1 as $type)),
- ScalarValue::$SCALAR(Some(5 as $type)),
- ScalarValue::$SCALAR(Some(3 as $type)),
- ScalarValue::$SCALAR(Some(10 as $type)),
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 1 > right_col + 5 AND left_col + 3 <
right_col + 10
- 1 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- ),
- ScalarValue::$SCALAR(Some(1 as $type)),
- ScalarValue::$SCALAR(Some(5 as $type)),
- ScalarValue::$SCALAR(Some(3 as $type)),
- ScalarValue::$SCALAR(Some(10 as $type)),
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 1 > right_col + 5 AND left_col - 3 <
right_col + 10
- 2 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Minus,
- Operator::Plus,
- Operator::Minus,
- Operator::Plus,
- ),
- ScalarValue::$SCALAR(Some(1 as $type)),
- ScalarValue::$SCALAR(Some(5 as $type)),
- ScalarValue::$SCALAR(Some(3 as $type)),
- ScalarValue::$SCALAR(Some(10 as $type)),
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 10 > right_col - 5 AND left_col - 3 <
right_col + 10
- 3 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Plus,
- ),
- ScalarValue::$SCALAR(Some(10 as $type)),
- ScalarValue::$SCALAR(Some(5 as $type)),
- ScalarValue::$SCALAR(Some(3 as $type)),
- ScalarValue::$SCALAR(Some(10 as $type)),
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 10 > right_col - 5 AND left_col - 30 <
right_col - 3
- 4 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- ),
- ScalarValue::$SCALAR(Some(10 as $type)),
- ScalarValue::$SCALAR(Some(5 as $type)),
- ScalarValue::$SCALAR(Some(30 as $type)),
- ScalarValue::$SCALAR(Some(3 as $type)),
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 2 >= right_col - 5 AND left_col - 7 <=
right_col - 3
- 5 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- Operator::Minus,
- ),
- ScalarValue::$SCALAR(Some(2 as $type)),
- ScalarValue::$SCALAR(Some(5 as $type)),
- ScalarValue::$SCALAR(Some(7 as $type)),
- ScalarValue::$SCALAR(Some(3 as $type)),
- (Operator::GtEq, Operator::LtEq),
- ),
- // left_col - 28 >= right_col - 11 AND left_col - 21 <=
right_col - 39
- 6 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Plus,
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- ),
- ScalarValue::$SCALAR(Some(28 as $type)),
- ScalarValue::$SCALAR(Some(11 as $type)),
- ScalarValue::$SCALAR(Some(21 as $type)),
- ScalarValue::$SCALAR(Some(39 as $type)),
- (Operator::Gt, Operator::LtEq),
- ),
- // left_col - 28 >= right_col - 11 AND left_col - 21 <=
right_col - 39
- 7 => gen_conjunctive_numerical_expr(
- left_col,
- right_col,
- (
- Operator::Plus,
- Operator::Minus,
- Operator::Minus,
- Operator::Plus,
- ),
- ScalarValue::$SCALAR(Some(28 as $type)),
- ScalarValue::$SCALAR(Some(11 as $type)),
- ScalarValue::$SCALAR(Some(21 as $type)),
- ScalarValue::$SCALAR(Some(39 as $type)),
- (Operator::GtEq, Operator::Lt),
- ),
- _ => panic!("No case"),
- }
- }
- };
- }
-
- join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32);
- join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64);
-
- use rand::rngs::StdRng;
- use rand::{Rng, SeedableRng};
+ use crate::physical_plan::joins::test_utils::{
+ build_sides_record_batches, compare_batches, create_memory_table,
+ join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
+ join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
+ partitioned_sym_join_with_filter,
+ };
+ use datafusion_common::ScalarValue;
use std::iter::Iterator;
- struct AscendingRandomFloatIterator {
- prev: f64,
- max: f64,
- rng: StdRng,
- }
-
- impl AscendingRandomFloatIterator {
- fn new(min: f64, max: f64) -> Self {
- let mut rng = StdRng::seed_from_u64(42);
- let initial = rng.gen_range(min..max);
- AscendingRandomFloatIterator {
- prev: initial,
- max,
- rng,
- }
- }
- }
-
- impl Iterator for AscendingRandomFloatIterator {
- type Item = f64;
-
- fn next(&mut self) -> Option<Self::Item> {
- let value = self.rng.gen_range(self.prev..self.max);
- self.prev = value;
- Some(value)
- }
- }
-
- fn join_expr_tests_fixture_temporal(
- expr_id: usize,
- left_col: Arc<dyn PhysicalExpr>,
- right_col: Arc<dyn PhysicalExpr>,
- schema: &Schema,
- ) -> Result<Arc<dyn PhysicalExpr>> {
- match expr_id {
- // constructs ((left_col - INTERVAL '100ms') > (right_col -
INTERVAL '200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL
'300ms'))
- 0 => gen_conjunctive_temporal_expr(
- left_col,
- right_col,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- ScalarValue::new_interval_dt(0, 100), // 100 ms
- ScalarValue::new_interval_dt(0, 200), // 200 ms
- ScalarValue::new_interval_dt(0, 450), // 450 ms
- ScalarValue::new_interval_dt(0, 300), // 300 ms
- schema,
- ),
- // constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') >
(right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP
'2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02'))
- 1 => gen_conjunctive_temporal_expr(
- left_col,
- right_col,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- ScalarValue::TimestampMillisecond(Some(1672574403000), None),
// 2023-01-01:12.00.03
- ScalarValue::TimestampMillisecond(Some(1672574401000), None),
// 2023-01-01:12.00.01
- ScalarValue::TimestampMillisecond(Some(1672574400000), None),
// 2023-01-01:12.00.00
- ScalarValue::TimestampMillisecond(Some(1672574402000), None),
// 2023-01-01:12.00.02
- schema,
- ),
- _ => unreachable!(),
- }
- }
- fn build_sides_record_batches(
- table_size: i32,
- key_cardinality: (i32, i32),
- ) -> Result<(RecordBatch, RecordBatch)> {
- let null_ratio: f64 = 0.4;
- let initial_range = 0..table_size;
- let index = (table_size as f64 * null_ratio).round() as i32;
- let rest_of = index..table_size;
- let ordered: ArrayRef = Arc::new(Int32Array::from_iter(
- initial_range.clone().collect::<Vec<i32>>(),
- ));
- let ordered_des = Arc::new(Int32Array::from_iter(
- initial_range.clone().rev().collect::<Vec<i32>>(),
- ));
- let cardinality = Arc::new(Int32Array::from_iter(
- initial_range.clone().map(|x| x % 4).collect::<Vec<i32>>(),
- ));
- let cardinality_key_left = Arc::new(Int32Array::from_iter(
- initial_range
- .clone()
- .map(|x| x % key_cardinality.0)
- .collect::<Vec<i32>>(),
- ));
- let cardinality_key_right = Arc::new(Int32Array::from_iter(
- initial_range
- .clone()
- .map(|x| x % key_cardinality.1)
- .collect::<Vec<i32>>(),
- ));
- let ordered_asc_null_first = Arc::new(Int32Array::from_iter({
- std::iter::repeat(None)
- .take(index as usize)
- .chain(rest_of.clone().map(Some))
- .collect::<Vec<Option<i32>>>()
- }));
- let ordered_asc_null_last = Arc::new(Int32Array::from_iter({
- rest_of
- .clone()
- .map(Some)
- .chain(std::iter::repeat(None).take(index as usize))
- .collect::<Vec<Option<i32>>>()
- }));
-
- let ordered_desc_null_first = Arc::new(Int32Array::from_iter({
- std::iter::repeat(None)
- .take(index as usize)
- .chain(rest_of.rev().map(Some))
- .collect::<Vec<Option<i32>>>()
- }));
-
- let time = Arc::new(TimestampMillisecondArray::from(
- initial_range
- .clone()
- .map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00
- .collect::<Vec<i64>>(),
- ));
- let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from(
- initial_range
- .map(|x| x as i64 * 100) // x * 100ms
- .collect::<Vec<i64>>(),
- ));
-
- let float_asc = Arc::new(Float64Array::from_iter_values(
- AscendingRandomFloatIterator::new(0., table_size as f64)
- .take(table_size as usize),
- ));
-
- let left = RecordBatch::try_from_iter(vec![
- ("la1", ordered.clone()),
- ("lb1", cardinality.clone()),
- ("lc1", cardinality_key_left),
- ("lt1", time.clone()),
- ("la2", ordered.clone()),
- ("la1_des", ordered_des.clone()),
- ("l_asc_null_first", ordered_asc_null_first.clone()),
- ("l_asc_null_last", ordered_asc_null_last.clone()),
- ("l_desc_null_first", ordered_desc_null_first.clone()),
- ("li1", interval_time.clone()),
- ("l_float", float_asc.clone()),
- ])?;
- let right = RecordBatch::try_from_iter(vec![
- ("ra1", ordered.clone()),
- ("rb1", cardinality),
- ("rc1", cardinality_key_right),
- ("rt1", time),
- ("ra2", ordered),
- ("ra1_des", ordered_des),
- ("r_asc_null_first", ordered_asc_null_first),
- ("r_asc_null_last", ordered_asc_null_last),
- ("r_desc_null_first", ordered_desc_null_first),
- ("ri1", interval_time),
- ("r_float", float_asc),
- ])?;
- Ok((left, right))
- }
+ use super::*;
- fn create_memory_table(
- left_batch: RecordBatch,
- right_batch: RecordBatch,
- left_sorted: Option<Vec<PhysicalSortExpr>>,
- right_sorted: Option<Vec<PhysicalSortExpr>>,
- batch_size: usize,
- ) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
- let mut left = MemoryExec::try_new(
- &[split_record_batches(&left_batch, batch_size)?],
- left_batch.schema(),
- None,
- )?;
- if let Some(sorted) = left_sorted {
- left = left.with_sort_information(sorted);
- }
- let mut right = MemoryExec::try_new(
- &[split_record_batches(&right_batch, batch_size)?],
- right_batch.schema(),
- None,
- )?;
- if let Some(sorted) = right_sorted {
- right = right.with_sort_information(sorted);
- }
- Ok((Arc::new(left), Arc::new(right)))
- }
+ const TABLE_SIZE: i32 = 100;
- async fn experiment(
+ pub async fn experiment(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
filter: Option<JoinFilter>,
@@ -2289,10 +1498,10 @@ mod tests {
)]
join_type: JoinType,
#[values(
- (4, 5),
- (11, 21),
- (31, 71),
- (99, 12),
+ (4, 5),
+ (11, 21),
+ (31, 71),
+ (99, 12),
)]
cardinality: (i32, i32),
) -> Result<()> {
@@ -2370,10 +1579,10 @@ mod tests {
)]
join_type: JoinType,
#[values(
- (4, 5),
- (11, 21),
- (31, 71),
- (99, 12),
+ (4, 5),
+ (11, 21),
+ (31, 71),
+ (99, 12),
)]
cardinality: (i32, i32),
#[values(0, 1, 2, 3, 4, 5, 6, 7)] case_expr: usize,
@@ -2536,10 +1745,10 @@ mod tests {
)]
join_type: JoinType,
#[values(
- (4, 5),
- (11, 21),
- (31, 71),
- (99, 12),
+ (4, 5),
+ (11, 21),
+ (31, 71),
+ (99, 12),
)]
cardinality: (i32, i32),
#[values(0, 1, 2, 3, 4, 5, 6)] case_expr: usize,
@@ -3125,14 +2334,13 @@ mod tests {
initial_right_batch.num_rows()
);
- left_side_joiner.join_with_probe_batch(
+ join_with_probe_batch(
+ &mut left_side_joiner,
+ &mut right_side_joiner,
&join_schema,
join_type,
- &right_side_joiner.on,
Some(&filter),
&initial_right_batch,
- &mut right_side_joiner.visited_rows,
- right_side_joiner.offset,
&join_column_indices,
&random_state,
false,
@@ -3155,9 +2363,9 @@ mod tests {
)]
join_type: JoinType,
#[values(
- (4, 5),
- (99, 12),
- )]
+ (4, 5),
+ (99, 12),
+ )]
cardinality: (i32, i32),
#[values(0, 1)] case_expr: usize,
) -> Result<()> {
diff --git a/datafusion/core/src/physical_plan/joins/test_utils.rs
b/datafusion/core/src/physical_plan/joins/test_utils.rs
new file mode 100644
index 0000000000..e786fb5eb5
--- /dev/null
+++ b/datafusion/core/src/physical_plan/joins/test_utils.rs
@@ -0,0 +1,513 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This file has test utils for hash joins
+
+use crate::physical_plan::joins::utils::{JoinFilter, JoinOn};
+use crate::physical_plan::joins::{
+ HashJoinExec, PartitionMode, StreamJoinPartitionMode,
SymmetricHashJoinExec,
+};
+use crate::physical_plan::memory::MemoryExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::{common, ExecutionPlan, Partitioning};
+use arrow::util::pretty::pretty_format_batches;
+use arrow_array::{
+ ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch,
+ TimestampMillisecondArray,
+};
+use arrow_schema::Schema;
+use datafusion_common::Result;
+use datafusion_common::ScalarValue;
+use datafusion_execution::TaskContext;
+use datafusion_expr::{JoinType, Operator};
+use datafusion_physical_expr::intervals::test_utils::{
+ gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr,
+};
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::sync::Arc;
+use std::usize;
+
+pub fn compare_batches(collected_1: &[RecordBatch], collected_2:
&[RecordBatch]) {
+ // compare
+ let first_formatted =
pretty_format_batches(collected_1).unwrap().to_string();
+ let second_formatted =
pretty_format_batches(collected_2).unwrap().to_string();
+
+ let mut first_formatted_sorted: Vec<&str> =
first_formatted.trim().lines().collect();
+ first_formatted_sorted.sort_unstable();
+
+ let mut second_formatted_sorted: Vec<&str> =
+ second_formatted.trim().lines().collect();
+ second_formatted_sorted.sort_unstable();
+
+ for (i, (first_line, second_line)) in first_formatted_sorted
+ .iter()
+ .zip(&second_formatted_sorted)
+ .enumerate()
+ {
+ assert_eq!((i, first_line), (i, second_line));
+ }
+}
+
+pub async fn partitioned_sym_join_with_filter(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ filter: Option<JoinFilter>,
+ join_type: &JoinType,
+ null_equals_null: bool,
+ context: Arc<TaskContext>,
+) -> Result<Vec<RecordBatch>> {
+ let partition_count = 4;
+
+ let left_expr = on
+ .iter()
+ .map(|(l, _)| Arc::new(l.clone()) as _)
+ .collect::<Vec<_>>();
+
+ let right_expr = on
+ .iter()
+ .map(|(_, r)| Arc::new(r.clone()) as _)
+ .collect::<Vec<_>>();
+
+ let join = SymmetricHashJoinExec::try_new(
+ Arc::new(RepartitionExec::try_new(
+ left,
+ Partitioning::Hash(left_expr, partition_count),
+ )?),
+ Arc::new(RepartitionExec::try_new(
+ right,
+ Partitioning::Hash(right_expr, partition_count),
+ )?),
+ on,
+ filter,
+ join_type,
+ null_equals_null,
+ StreamJoinPartitionMode::Partitioned,
+ )?;
+
+ let mut batches = vec![];
+ for i in 0..partition_count {
+ let stream = join.execute(i, context.clone())?;
+ let more_batches = common::collect(stream).await?;
+ batches.extend(
+ more_batches
+ .into_iter()
+ .filter(|b| b.num_rows() > 0)
+ .collect::<Vec<_>>(),
+ );
+ }
+
+ Ok(batches)
+}
+
+pub async fn partitioned_hash_join_with_filter(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ filter: Option<JoinFilter>,
+ join_type: &JoinType,
+ null_equals_null: bool,
+ context: Arc<TaskContext>,
+) -> Result<Vec<RecordBatch>> {
+ let partition_count = 4;
+ let (left_expr, right_expr) = on
+ .iter()
+ .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
+ .unzip();
+
+ let join = Arc::new(HashJoinExec::try_new(
+ Arc::new(RepartitionExec::try_new(
+ left,
+ Partitioning::Hash(left_expr, partition_count),
+ )?),
+ Arc::new(RepartitionExec::try_new(
+ right,
+ Partitioning::Hash(right_expr, partition_count),
+ )?),
+ on,
+ filter,
+ join_type,
+ PartitionMode::Partitioned,
+ null_equals_null,
+ )?);
+
+ let mut batches = vec![];
+ for i in 0..partition_count {
+ let stream = join.execute(i, context.clone())?;
+ let more_batches = common::collect(stream).await?;
+ batches.extend(
+ more_batches
+ .into_iter()
+ .filter(|b| b.num_rows() > 0)
+ .collect::<Vec<_>>(),
+ );
+ }
+
+ Ok(batches)
+}
+
+pub fn split_record_batches(
+ batch: &RecordBatch,
+ batch_size: usize,
+) -> Result<Vec<RecordBatch>> {
+ let row_num = batch.num_rows();
+ let number_of_batch = row_num / batch_size;
+ let mut sizes = vec![batch_size; number_of_batch];
+ sizes.push(row_num - (batch_size * number_of_batch));
+ let mut result = vec![];
+ for (i, size) in sizes.iter().enumerate() {
+ result.push(batch.slice(i * batch_size, *size));
+ }
+ Ok(result)
+}
+
+struct AscendingRandomFloatIterator {
+ prev: f64,
+ max: f64,
+ rng: StdRng,
+}
+
+impl AscendingRandomFloatIterator {
+ fn new(min: f64, max: f64) -> Self {
+ let mut rng = StdRng::seed_from_u64(42);
+ let initial = rng.gen_range(min..max);
+ AscendingRandomFloatIterator {
+ prev: initial,
+ max,
+ rng,
+ }
+ }
+}
+
+impl Iterator for AscendingRandomFloatIterator {
+ type Item = f64;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let value = self.rng.gen_range(self.prev..self.max);
+ self.prev = value;
+ Some(value)
+ }
+}
+
+pub fn join_expr_tests_fixture_temporal(
+ expr_id: usize,
+ left_col: Arc<dyn PhysicalExpr>,
+ right_col: Arc<dyn PhysicalExpr>,
+ schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+ match expr_id {
+ // constructs ((left_col - INTERVAL '100ms') > (right_col - INTERVAL
'200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL '300ms'))
+ 0 => gen_conjunctive_temporal_expr(
+ left_col,
+ right_col,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ ScalarValue::new_interval_dt(0, 100), // 100 ms
+ ScalarValue::new_interval_dt(0, 200), // 200 ms
+ ScalarValue::new_interval_dt(0, 450), // 450 ms
+ ScalarValue::new_interval_dt(0, 300), // 300 ms
+ schema,
+ ),
+ // constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') >
(right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP
'2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02'))
+ 1 => gen_conjunctive_temporal_expr(
+ left_col,
+ right_col,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ ScalarValue::TimestampMillisecond(Some(1672574403000), None), //
2023-01-01:12.00.03
+ ScalarValue::TimestampMillisecond(Some(1672574401000), None), //
2023-01-01:12.00.01
+ ScalarValue::TimestampMillisecond(Some(1672574400000), None), //
2023-01-01:12.00.00
+ ScalarValue::TimestampMillisecond(Some(1672574402000), None), //
2023-01-01:12.00.02
+ schema,
+ ),
+ _ => unreachable!(),
+ }
+}
+
+// It creates join filters for different type of fields for testing.
+macro_rules! join_expr_tests {
+ ($func_name:ident, $type:ty, $SCALAR:ident) => {
+ pub fn $func_name(
+ expr_id: usize,
+ left_col: Arc<dyn PhysicalExpr>,
+ right_col: Arc<dyn PhysicalExpr>,
+ ) -> Arc<dyn PhysicalExpr> {
+ match expr_id {
+ // left_col + 1 > right_col + 5 AND left_col + 3 < right_col +
10
+ 0 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 1 > right_col + 5 AND left_col + 3 < right_col +
10
+ 1 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 1 > right_col + 5 AND left_col - 3 < right_col +
10
+ 2 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 10 > right_col - 5 AND left_col - 3 < right_col
+ 10
+ 3 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 10 > right_col - 5 AND left_col - 30 < right_col
- 3
+ 4 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ ),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(30 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+            // left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col
- 3
+ 5 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Minus,
+ ),
+ ScalarValue::$SCALAR(Some(2 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(7 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ (Operator::GtEq, Operator::LtEq),
+ ),
+            // left_col + 28 > right_col - 11 AND left_col + 21 <=
right_col + 39
+ 6 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(28 as $type)),
+ ScalarValue::$SCALAR(Some(11 as $type)),
+ ScalarValue::$SCALAR(Some(21 as $type)),
+ ScalarValue::$SCALAR(Some(39 as $type)),
+ (Operator::Gt, Operator::LtEq),
+ ),
+            // left_col + 28 >= right_col - 11 AND left_col - 21 <
right_col + 39
+ 7 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(28 as $type)),
+ ScalarValue::$SCALAR(Some(11 as $type)),
+ ScalarValue::$SCALAR(Some(21 as $type)),
+ ScalarValue::$SCALAR(Some(39 as $type)),
+ (Operator::GtEq, Operator::Lt),
+ ),
+ _ => panic!("No case"),
+ }
+ }
+ };
+}
+
+join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32);
+join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64);
+
+pub fn build_sides_record_batches(
+ table_size: i32,
+ key_cardinality: (i32, i32),
+) -> Result<(RecordBatch, RecordBatch)> {
+ let null_ratio: f64 = 0.4;
+ let initial_range = 0..table_size;
+ let index = (table_size as f64 * null_ratio).round() as i32;
+ let rest_of = index..table_size;
+ let ordered: ArrayRef = Arc::new(Int32Array::from_iter(
+ initial_range.clone().collect::<Vec<i32>>(),
+ ));
+ let ordered_des = Arc::new(Int32Array::from_iter(
+ initial_range.clone().rev().collect::<Vec<i32>>(),
+ ));
+ let cardinality = Arc::new(Int32Array::from_iter(
+ initial_range.clone().map(|x| x % 4).collect::<Vec<i32>>(),
+ ));
+ let cardinality_key_left = Arc::new(Int32Array::from_iter(
+ initial_range
+ .clone()
+ .map(|x| x % key_cardinality.0)
+ .collect::<Vec<i32>>(),
+ ));
+ let cardinality_key_right = Arc::new(Int32Array::from_iter(
+ initial_range
+ .clone()
+ .map(|x| x % key_cardinality.1)
+ .collect::<Vec<i32>>(),
+ ));
+ let ordered_asc_null_first = Arc::new(Int32Array::from_iter({
+ std::iter::repeat(None)
+ .take(index as usize)
+ .chain(rest_of.clone().map(Some))
+ .collect::<Vec<Option<i32>>>()
+ }));
+ let ordered_asc_null_last = Arc::new(Int32Array::from_iter({
+ rest_of
+ .clone()
+ .map(Some)
+ .chain(std::iter::repeat(None).take(index as usize))
+ .collect::<Vec<Option<i32>>>()
+ }));
+
+ let ordered_desc_null_first = Arc::new(Int32Array::from_iter({
+ std::iter::repeat(None)
+ .take(index as usize)
+ .chain(rest_of.rev().map(Some))
+ .collect::<Vec<Option<i32>>>()
+ }));
+
+ let time = Arc::new(TimestampMillisecondArray::from(
+ initial_range
+ .clone()
+ .map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00
+ .collect::<Vec<i64>>(),
+ ));
+ let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from(
+ initial_range
+ .map(|x| x as i64 * 100) // x * 100ms
+ .collect::<Vec<i64>>(),
+ ));
+
+ let float_asc = Arc::new(Float64Array::from_iter_values(
+ AscendingRandomFloatIterator::new(0., table_size as f64)
+ .take(table_size as usize),
+ ));
+
+ let left = RecordBatch::try_from_iter(vec![
+ ("la1", ordered.clone()),
+ ("lb1", cardinality.clone()),
+ ("lc1", cardinality_key_left),
+ ("lt1", time.clone()),
+ ("la2", ordered.clone()),
+ ("la1_des", ordered_des.clone()),
+ ("l_asc_null_first", ordered_asc_null_first.clone()),
+ ("l_asc_null_last", ordered_asc_null_last.clone()),
+ ("l_desc_null_first", ordered_desc_null_first.clone()),
+ ("li1", interval_time.clone()),
+ ("l_float", float_asc.clone()),
+ ])?;
+ let right = RecordBatch::try_from_iter(vec![
+ ("ra1", ordered.clone()),
+ ("rb1", cardinality),
+ ("rc1", cardinality_key_right),
+ ("rt1", time),
+ ("ra2", ordered),
+ ("ra1_des", ordered_des),
+ ("r_asc_null_first", ordered_asc_null_first),
+ ("r_asc_null_last", ordered_asc_null_last),
+ ("r_desc_null_first", ordered_desc_null_first),
+ ("ri1", interval_time),
+ ("r_float", float_asc),
+ ])?;
+ Ok((left, right))
+}
+
+pub fn create_memory_table(
+ left_batch: RecordBatch,
+ right_batch: RecordBatch,
+ left_sorted: Option<Vec<PhysicalSortExpr>>,
+ right_sorted: Option<Vec<PhysicalSortExpr>>,
+ batch_size: usize,
+) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
+ let mut left = MemoryExec::try_new(
+ &[split_record_batches(&left_batch, batch_size)?],
+ left_batch.schema(),
+ None,
+ )?;
+ if let Some(sorted) = left_sorted {
+ left = left.with_sort_information(sorted);
+ }
+ let mut right = MemoryExec::try_new(
+ &[split_record_batches(&right_batch, batch_size)?],
+ right_batch.schema(),
+ None,
+ )?;
+ if let Some(sorted) = right_sorted {
+ right = right.with_sort_information(sorted);
+ }
+ Ok((Arc::new(left), Arc::new(right)))
+}