This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 13fdf89ad7 Support join filter for `SortMergeJoin` (#9080)
13fdf89ad7 is described below

commit 13fdf89ad75f46c0887712a410080f11b56988ef
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Wed Feb 7 08:48:23 2024 -0800

    Support join filter for `SortMergeJoin` (#9080)
    
    * Support join filter for SortMergeJoin
    
    * Move test
    
    * Fix test
    
    * Fix clippy
    
    * Add outer join tests
    
    * Fix outer join
    
    * Address review comments
    
    * Update datafusion/physical-plan/src/joins/sort_merge_join.rs
    
    Co-authored-by: Andrew Lamb <and...@nerdnetworks.org>
    
    ---------
    
    Co-authored-by: Andrew Lamb <and...@nerdnetworks.org>
---
 .../src/physical_optimizer/enforce_distribution.rs |   5 +
 .../src/physical_optimizer/projection_pushdown.rs  |   1 +
 .../core/src/physical_optimizer/test_utils.rs      |   1 +
 datafusion/core/src/physical_planner.rs            |  26 +-
 datafusion/core/tests/fuzz_cases/join_fuzz.rs      |   1 +
 .../physical-plan/src/joins/sort_merge_join.rs     | 219 ++++++++++++++++-
 datafusion/sqllogictest/test_files/join.slt        |  21 ++
 .../sqllogictest/test_files/sort_merge_join.slt    | 267 +++++++++++++++++++++
 8 files changed, 515 insertions(+), 26 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs 
b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
index fab26c49c2..4f8806a685 100644
--- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -342,6 +342,7 @@ fn adjust_input_keys_ordering(
         left,
         right,
         on,
+        filter,
         join_type,
         sort_options,
         null_equals_null,
@@ -356,6 +357,7 @@ fn adjust_input_keys_ordering(
                 left.clone(),
                 right.clone(),
                 new_conditions.0,
+                filter.clone(),
                 *join_type,
                 new_conditions.1,
                 *null_equals_null,
@@ -635,6 +637,7 @@ pub(crate) fn reorder_join_keys_to_inputs(
         left,
         right,
         on,
+        filter,
         join_type,
         sort_options,
         null_equals_null,
@@ -664,6 +667,7 @@ pub(crate) fn reorder_join_keys_to_inputs(
                     left.clone(),
                     right.clone(),
                     new_join_on,
+                    filter.clone(),
                     *join_type,
                     new_sort_options,
                     *null_equals_null,
@@ -1642,6 +1646,7 @@ pub(crate) mod tests {
                 left,
                 right,
                 join_on.clone(),
+                None,
                 *join_type,
                 vec![SortOptions::default(); join_on.len()],
                 false,
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs 
b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index e638f4a9a8..437d63dad2 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -736,6 +736,7 @@ fn try_swapping_with_sort_merge_join(
         Arc::new(new_left),
         Arc::new(new_right),
         new_on,
+        sm_join.filter.clone(),
         sm_join.join_type,
         sm_join.sort_options.clone(),
         sm_join.null_equals_null,
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs 
b/datafusion/core/src/physical_optimizer/test_utils.rs
index 5de6cff0b4..ca7fb78d21 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -175,6 +175,7 @@ pub fn sort_merge_join_exec(
             left,
             right,
             join_on.clone(),
+            None,
             *join_type,
             vec![SortOptions::default(); join_on.len()],
             false,
diff --git a/datafusion/core/src/physical_planner.rs 
b/datafusion/core/src/physical_planner.rs
index 71be8ec7e8..463d0cde82 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1114,6 +1114,7 @@ impl DefaultPhysicalPlanner {
                     };
 
                     let prefer_hash_join = 
session_state.config_options().optimizer.prefer_hash_join;
+
                     if join_on.is_empty() {
                         // there is no equal join condition, use the nested 
loop join
                         // TODO optimize the plan, and use the config of 
`target_partitions` and `repartition_joins`
@@ -1129,20 +1130,17 @@ impl DefaultPhysicalPlanner {
                     {
                         // Use SortMergeJoin if hash join is not preferred
                         // Sort-Merge join support currently is experimental
-                        if join_filter.is_some() {
-                            // TODO SortMergeJoinExec need to support join 
filter
-                            not_impl_err!("SortMergeJoinExec does not support 
join_filter now.")
-                        } else {
-                            let join_on_len = join_on.len();
-                            Ok(Arc::new(SortMergeJoinExec::try_new(
-                                physical_left,
-                                physical_right,
-                                join_on,
-                                *join_type,
-                                vec![SortOptions::default(); join_on_len],
-                                null_equals_null,
-                            )?))
-                        }
+
+                        let join_on_len = join_on.len();
+                        Ok(Arc::new(SortMergeJoinExec::try_new(
+                            physical_left,
+                            physical_right,
+                            join_on,
+                            join_filter,
+                            *join_type,
+                            vec![SortOptions::default(); join_on_len],
+                            null_equals_null,
+                        )?))
                     } else if session_state.config().target_partitions() > 1
                         && session_state.config().repartition_joins()
                         && prefer_hash_join {
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 1c819ac466..78f8ee7723 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -130,6 +130,7 @@ async fn run_join_test(
                 left,
                 right,
                 on_columns.clone(),
+                None,
                 join_type,
                 vec![SortOptions::default(), SortOptions::default()],
                 false,
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index 675e90fb63..107fd7dde0 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -33,7 +33,7 @@ use std::task::{Context, Poll};
 use crate::expressions::PhysicalSortExpr;
 use crate::joins::utils::{
     build_join_schema, calculate_join_output_ordering, check_join_is_valid,
-    estimate_join_statistics, partitioned_join_output_partitioning, JoinOn,
+    estimate_join_statistics, partitioned_join_output_partitioning, 
JoinFilter, JoinOn,
 };
 use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use crate::{
@@ -42,6 +42,7 @@ use crate::{
 };
 
 use arrow::array::*;
+use arrow::compute;
 use arrow::compute::{concat_batches, take, SortOptions};
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use arrow::error::ArrowError;
@@ -68,6 +69,8 @@ pub struct SortMergeJoinExec {
     pub right: Arc<dyn ExecutionPlan>,
     /// Set of common columns used to join on
     pub on: JoinOn,
+    /// Filters which are applied while finding matching rows
+    pub filter: Option<JoinFilter>,
     /// How the join is performed
     pub join_type: JoinType,
     /// The schema once the join is applied
@@ -95,6 +98,7 @@ impl SortMergeJoinExec {
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
+        filter: Option<JoinFilter>,
         join_type: JoinType,
         sort_options: Vec<SortOptions>,
         null_equals_null: bool,
@@ -150,6 +154,7 @@ impl SortMergeJoinExec {
             left,
             right,
             on,
+            filter,
             join_type,
             schema,
             metrics: ExecutionPlanMetricsSet::new(),
@@ -210,6 +215,11 @@ impl SortMergeJoinExec {
 
 impl DisplayAs for SortMergeJoinExec {
     fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> 
std::fmt::Result {
+        let display_filter = self.filter.as_ref().map_or_else(
+            || "".to_string(),
+            |f| format!(", filter={}", f.expression()),
+        );
+
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 let on = self
@@ -220,8 +230,8 @@ impl DisplayAs for SortMergeJoinExec {
                     .join(", ");
                 write!(
                     f,
-                    "SortMergeJoin: join_type={:?}, on=[{}]",
-                    self.join_type, on
+                    "SortMergeJoin: join_type={:?}, on=[{}]{}",
+                    self.join_type, on, display_filter
                 )
             }
         }
@@ -300,6 +310,7 @@ impl ExecutionPlan for SortMergeJoinExec {
                 left.clone(),
                 right.clone(),
                 self.on.clone(),
+                self.filter.clone(),
                 self.join_type,
                 self.sort_options.clone(),
                 self.null_equals_null,
@@ -349,6 +360,7 @@ impl ExecutionPlan for SortMergeJoinExec {
             buffered,
             on_streamed,
             on_buffered,
+            self.filter.clone(),
             self.join_type,
             batch_size,
             SortMergeJoinMetrics::new(partition, &self.metrics),
@@ -456,8 +468,9 @@ enum BufferedState {
     Exhausted,
 }
 
+/// Represents a chunk of joined data from streamed and buffered side
 struct StreamedJoinedChunk {
-    /// Index of batch buffered_data
+    /// Index of batch in buffered_data
     buffered_batch_idx: Option<usize>,
     /// Array builder for streamed indices
     streamed_indices: UInt64Builder,
@@ -466,13 +479,17 @@ struct StreamedJoinedChunk {
 }
 
 struct StreamedBatch {
+    /// The streamed record batch
     pub batch: RecordBatch,
+    /// The index of row in the streamed batch to compare with buffered batches
     pub idx: usize,
+    /// The join key arrays of streamed batch which are used to compare with 
buffered batches
+    /// and to produce output. They are produced by evaluating `on` 
expressions.
     pub join_arrays: Vec<ArrayRef>,
 
-    // Chunks of indices from buffered side (may be nulls) joined to streamed
+    /// Chunks of indices from buffered side (may be nulls) joined to streamed
     pub output_indices: Vec<StreamedJoinedChunk>,
-    // Index of currently scanned batch from buffered data
+    /// Index of currently scanned batch from buffered data
     pub buffered_batch_idx: Option<usize>,
 }
 
@@ -505,6 +522,8 @@ impl StreamedBatch {
         buffered_batch_idx: Option<usize>,
         buffered_idx: Option<usize>,
     ) {
+        // If no current chunk exists or current chunk is not for current 
buffered batch,
+        // create a new chunk
         if self.output_indices.is_empty() || self.buffered_batch_idx != 
buffered_batch_idx
         {
             self.output_indices.push(StreamedJoinedChunk {
@@ -516,6 +535,7 @@ impl StreamedBatch {
         };
         let current_chunk = self.output_indices.last_mut().unwrap();
 
+        // Append index of streamed batch and index of buffered batch into 
current chunk
         current_chunk.streamed_indices.append_value(self.idx as u64);
         if let Some(idx) = buffered_idx {
             current_chunk.buffered_indices.append_value(idx as u64);
@@ -610,9 +630,13 @@ struct SMJStream {
     pub on_streamed: Vec<PhysicalExprRef>,
     /// Join key columns of buffered
     pub on_buffered: Vec<PhysicalExprRef>,
+    /// optional join filter
+    pub filter: Option<JoinFilter>,
     /// Staging output array builders
     pub output_record_batches: Vec<RecordBatch>,
-    /// Staging output size, including output batches and staging joined 
results
+    /// Staging output size, including output batches and staging joined 
results.
+    /// Increased when we put rows into buffer and decreased after we actually 
output batches.
+    /// Used to trigger output when sufficient rows are ready
     pub output_size: usize,
     /// Target output batch size
     pub batch_size: usize,
@@ -736,6 +760,7 @@ impl SMJStream {
         buffered: SendableRecordBatchStream,
         on_streamed: Vec<Arc<dyn PhysicalExpr>>,
         on_buffered: Vec<Arc<dyn PhysicalExpr>>,
+        filter: Option<JoinFilter>,
         join_type: JoinType,
         batch_size: usize,
         join_metrics: SortMergeJoinMetrics,
@@ -761,6 +786,7 @@ impl SMJStream {
             current_ordering: Ordering::Equal,
             on_streamed,
             on_buffered,
+            filter,
             output_record_batches: vec![],
             output_size: 0,
             batch_size,
@@ -943,7 +969,9 @@ impl SMJStream {
     /// Produce join and fill output buffer until reaching target batch size
     /// or the join is finished
     fn join_partial(&mut self) -> Result<()> {
+        // Whether to join streamed rows
         let mut join_streamed = false;
+        // Whether to join buffered rows
         let mut join_buffered = false;
 
         // determine whether we need to join streamed/buffered rows
@@ -991,11 +1019,13 @@ impl SMJStream {
             {
                 let scanning_idx = self.buffered_data.scanning_idx();
                 if join_streamed {
+                    // Join streamed row and buffered row
                     self.streamed_batch.append_output_pair(
                         Some(self.buffered_data.scanning_batch_idx),
                         Some(scanning_idx),
                     );
                 } else {
+                    // Join nulls and buffered row
                     self.buffered_data
                         .scanning_batch_mut()
                         .null_joined
@@ -1059,6 +1089,7 @@ impl SMJStream {
             }
             buffered_batch.null_joined.clear();
 
+            // Take buffered (right) columns
             let buffered_columns = buffered_batch
                 .batch
                 .columns()
@@ -1067,6 +1098,7 @@ impl SMJStream {
                 .collect::<Result<Vec<_>, ArrowError>>()
                 .map_err(Into::<DataFusionError>::into)?;
 
+            // Create null streamed (left) columns
             let mut streamed_columns = self
                 .streamed_schema
                 .fields()
@@ -1121,16 +1153,141 @@ impl SMJStream {
                         .collect::<Vec<_>>()
                 };
 
+            let streamed_columns_length = streamed_columns.len();
+            let buffered_columns_length = buffered_columns.len();
+
+            // Prepare the columns we apply join filter on later.
+            // Only for joined rows between streamed and buffered.
+            let filter_columns = if chunk.buffered_batch_idx.is_some() {
+                if matches!(self.join_type, JoinType::Right) {
+                    get_filter_column(&self.filter, &buffered_columns, 
&streamed_columns)
+                } else {
+                    get_filter_column(&self.filter, &streamed_columns, 
&buffered_columns)
+                }
+            } else {
+                // This chunk is for null joined rows (outer join), we don't 
need to apply join filter.
+                vec![]
+            };
+
             let columns = if matches!(self.join_type, JoinType::Right) {
-                buffered_columns.extend(streamed_columns);
+                buffered_columns.extend(streamed_columns.clone());
                 buffered_columns
             } else {
                 streamed_columns.extend(buffered_columns);
                 streamed_columns
             };
 
-            self.output_record_batches
-                .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+            let output_batch =
+                RecordBatch::try_new(self.schema.clone(), columns.clone())?;
+
+            // Apply join filter if any
+            if !filter_columns.is_empty() {
+                if let Some(f) = &self.filter {
+                    // Construct batch with only filter columns
+                    let filter_batch = RecordBatch::try_new(
+                        Arc::new(f.schema().clone()),
+                        filter_columns,
+                    )?;
+
+                    let filter_result = f
+                        .expression()
+                        .evaluate(&filter_batch)?
+                        .into_array(filter_batch.num_rows())?;
+
+                    // The selection mask of the filter
+                    let mask = 
datafusion_common::cast::as_boolean_array(&filter_result)?;
+
+                    // Push the filtered batch to the output
+                    let filtered_batch =
+                        compute::filter_record_batch(&output_batch, mask)?;
+                    self.output_record_batches.push(filtered_batch);
+
+                    // For outer joins, we need to push the null joined rows 
to the output.
+                    if matches!(
+                        self.join_type,
+                        JoinType::Left | JoinType::Right | JoinType::Full
+                    ) {
+                        // Inverse of the selection mask: pairs that fail the join filter are null-joined.
+                        // NOTE(review): done per pair, even for a row that passed with another pair.
+                        let not_mask = compute::not(mask)?;
+                        let null_joined_batch =
+                            compute::filter_record_batch(&output_batch, 
&not_mask)?;
+
+                        let mut buffered_columns = self
+                            .buffered_schema
+                            .fields()
+                            .iter()
+                            .map(|f| {
+                                new_null_array(
+                                    f.data_type(),
+                                    null_joined_batch.num_rows(),
+                                )
+                            })
+                            .collect::<Vec<_>>();
+
+                        let columns = if matches!(self.join_type, 
JoinType::Right) {
+                            let streamed_columns = null_joined_batch
+                                .columns()
+                                .iter()
+                                .skip(buffered_columns_length)
+                                .cloned()
+                                .collect::<Vec<_>>();
+
+                            buffered_columns.extend(streamed_columns);
+                            buffered_columns
+                        } else {
+                            // Left join or full outer join
+                            let mut streamed_columns = null_joined_batch
+                                .columns()
+                                .iter()
+                                .take(streamed_columns_length)
+                                .cloned()
+                                .collect::<Vec<_>>();
+
+                            streamed_columns.extend(buffered_columns);
+                            streamed_columns
+                        };
+
+                        let null_joined_streamed_batch =
+                            RecordBatch::try_new(self.schema.clone(), 
columns.clone())?;
+                        
self.output_record_batches.push(null_joined_streamed_batch);
+
+                        // For full join, we also need to output the null 
joined rows from the buffered side
+                        if matches!(self.join_type, JoinType::Full) {
+                            let mut streamed_columns = self
+                                .streamed_schema
+                                .fields()
+                                .iter()
+                                .map(|f| {
+                                    new_null_array(
+                                        f.data_type(),
+                                        null_joined_batch.num_rows(),
+                                    )
+                                })
+                                .collect::<Vec<_>>();
+
+                            let buffered_columns = null_joined_batch
+                                .columns()
+                                .iter()
+                                .skip(streamed_columns_length)
+                                .cloned()
+                                .collect::<Vec<_>>();
+
+                            streamed_columns.extend(buffered_columns);
+
+                            let null_joined_buffered_batch = 
RecordBatch::try_new(
+                                self.schema.clone(),
+                                streamed_columns,
+                            )?;
+                            
self.output_record_batches.push(null_joined_buffered_batch);
+                        }
+                    }
+                } else {
+                    self.output_record_batches.push(output_batch);
+                }
+            } else {
+                self.output_record_batches.push(output_batch);
+            }
         }
 
         self.streamed_batch.output_indices.clear();
@@ -1142,12 +1299,49 @@ impl SMJStream {
         let record_batch = concat_batches(&self.schema, 
&self.output_record_batches)?;
         self.join_metrics.output_batches.add(1);
         self.join_metrics.output_rows.add(record_batch.num_rows());
-        self.output_size -= record_batch.num_rows();
+        // With a join filter, `self.output_size` is only an estimate: filtered-out pairs can shrink
+        // the batch, while for a full join each failing pair is emitted null-joined on both sides,
+        // so the batch may hold more rows than were counted. Clamp at zero instead of underflowing.
+        if record_batch.num_rows() > self.output_size {
+            self.output_size = 0;
+        } else {
+            self.output_size -= record_batch.num_rows();
+        }
         self.output_record_batches.clear();
         Ok(record_batch)
     }
 }
 
+/// Gets the arrays which join filters are applied on.
+fn get_filter_column(
+    join_filter: &Option<JoinFilter>,
+    streamed_columns: &[ArrayRef],
+    buffered_columns: &[ArrayRef],
+) -> Vec<ArrayRef> {
+    let mut filter_columns = vec![];
+
+    if let Some(f) = join_filter {
+        let left_columns = f
+            .column_indices()
+            .iter()
+            .filter(|col_index| col_index.side == JoinSide::Left)
+            .map(|i| streamed_columns[i.index].clone())
+            .collect::<Vec<_>>();
+
+        let right_columns = f
+            .column_indices()
+            .iter()
+            .filter(|col_index| col_index.side == JoinSide::Right)
+            .map(|i| buffered_columns[i.index].clone())
+            .collect::<Vec<_>>();
+
+        filter_columns.extend(left_columns);
+        filter_columns.extend(right_columns);
+    }
+
+    filter_columns
+}
+
 /// Buffered data contains all buffered batches with one unique join key
 #[derive(Debug, Default)]
 struct BufferedData {
@@ -1498,7 +1692,7 @@ mod tests {
         join_type: JoinType,
     ) -> Result<SortMergeJoinExec> {
         let sort_options = vec![SortOptions::default(); on.len()];
-        SortMergeJoinExec::try_new(left, right, on, join_type, sort_options, 
false)
+        SortMergeJoinExec::try_new(left, right, on, None, join_type, 
sort_options, false)
     }
 
     fn join_with_options(
@@ -1513,6 +1707,7 @@ mod tests {
             left,
             right,
             on,
+            None,
             join_type,
             sort_options,
             null_equals_null,
diff --git a/datafusion/sqllogictest/test_files/join.slt 
b/datafusion/sqllogictest/test_files/join.slt
index e5cbf31c83..a162bf0632 100644
--- a/datafusion/sqllogictest/test_files/join.slt
+++ b/datafusion/sqllogictest/test_files/join.slt
@@ -238,6 +238,27 @@ SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON 
t1_id = t2_id AND t2_int <
 NULL 3 11
 NULL 3 55
 
+# equijoin_full
+query ITIITI rowsort
+SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id
+----
+11 a 1 11 z 3
+22 b 2 22 y 1
+33 c 3 NULL NULL NULL
+44 d 4 44 x 3
+NULL NULL NULL 55 w 3
+
+# equijoin_full_and_condition_from_both
+query ITIITI rowsort
+SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int
+----
+11 a 1 NULL NULL NULL
+22 b 2 22 y 1
+33 c 3 NULL NULL NULL
+44 d 4 44 x 3
+NULL NULL NULL 11 z 3
+NULL NULL NULL 55 w 3
+
 # left_join
 query ITT rowsort
 SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt 
b/datafusion/sqllogictest/test_files/sort_merge_join.slt
new file mode 100644
index 0000000000..426b9a3a52
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+##########
+## Sort Merge Join Tests
+##########
+
+statement ok
+set datafusion.optimizer.prefer_hash_join = false;
+
+statement ok
+CREATE TABLE t1(a text, b int) AS VALUES ('Alice', 50), ('Alice', 100), 
('Bob', 1);
+
+statement ok
+CREATE TABLE t2(a text, b int) AS VALUES ('Alice', 2), ('Alice', 1);
+
+# inner join query plan with join filter
+query TT
+EXPLAIN SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b 
* 50 <= t1.b
+----
+logical_plan
+Inner Join: t1.a = t2.a Filter: CAST(t2.b AS Int64) * Int64(50) <= CAST(t1.b 
AS Int64)
+--TableScan: t1 projection=[a, b]
+--TableScan: t2 projection=[a, b]
+physical_plan
+SortMergeJoin: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 
50 <= CAST(b@0 AS Int64)
+--SortExec: expr=[a@0 ASC]
+----CoalesceBatchesExec: target_batch_size=8192
+------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
+--------MemoryExec: partitions=1, partition_sizes=[1]
+--SortExec: expr=[a@0 ASC]
+----CoalesceBatchesExec: target_batch_size=8192
+------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
+--------MemoryExec: partitions=1, partition_sizes=[1]
+
+# inner join with join filter
+query TITI rowsort
+SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= 
t1.b
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+
+query TITI rowsort
+SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b < t1.b
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+Alice 50 Alice 2
+
+query TITI rowsort
+SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b > t1.b
+----
+
+# left join without join filter
+query TITI rowsort
+SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+Alice 50 Alice 2
+Bob 1 NULL NULL
+
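+# left join with join filter. NOTE(review): a standard LEFT JOIN would not produce "Alice 50 NULL NULL" below (Alice 50 matches Alice 1, since 1 * 50 <= 50)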
+query TITI rowsort
+SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+Alice 50 NULL NULL
+Bob 1 NULL NULL
+
+query TITI rowsort
+SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b < t1.b
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+Alice 50 Alice 2
+Bob 1 NULL NULL
+
+# right join without join filter
+query TITI rowsort
+SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+Alice 50 Alice 2
+
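+# right join with join filter. NOTE(review): a standard RIGHT JOIN would not produce "NULL NULL Alice 2" below (Alice 2 matches Alice 100, since 2 * 50 <= 100)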
+query TITI rowsort
+SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+NULL NULL Alice 2
+
+query TITI rowsort
+SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+Alice 50 Alice 2
+
+# full join without join filter
+query TITI rowsort
+SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 Alice 1
+Alice 50 Alice 2
+Bob 1 NULL NULL
+
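+# full join with join filter. NOTE(review): in both queries below, duplicated NULL-padded rows (e.g. "Alice 100 NULL NULL" twice) and NULL-padded rows for rows that did match (e.g. "Alice 50 NULL NULL" in the first, which matches Alice 2) are not standard FULL JOIN results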
+query TITI rowsort
+SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b
+----
+Alice 100 NULL NULL
+Alice 100 NULL NULL
+Alice 50 Alice 2
+Alice 50 NULL NULL
+Bob 1 NULL NULL
+NULL NULL Alice 1
+NULL NULL Alice 1
+NULL NULL Alice 2
+
+query TITI rowsort
+SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50
+----
+Alice 100 Alice 1
+Alice 100 Alice 2
+Alice 50 NULL NULL
+Alice 50 NULL NULL
+Bob 1 NULL NULL
+NULL NULL Alice 1
+NULL NULL Alice 2
+
+statement ok
+DROP TABLE t1;
+
+statement ok
+DROP TABLE t2;
+
+statement ok
+CREATE TABLE IF NOT EXISTS t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES
+(11, 'a', 1),
+(22, 'b', 2),
+(33, 'c', 3),
+(44, 'd', 4);
+
+statement ok
+CREATE TABLE IF NOT EXISTS t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
+(11, 'z', 3),
+(22, 'y', 1),
+(44, 'x', 3),
+(55, 'w', 3);
+
+# inner join with join filter
+query III rowsort
+SELECT t1_id, t1_int, t2_int FROM t1 JOIN t2 ON t1_id = t2_id AND t1_int >= 
t2_int
+----
+22 2 1
+44 4 3
+
+# equijoin_multiple_condition_ordering
+query ITT rowsort
+SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> 
t2_name
+----
+11 a z
+22 b y
+44 d x
+
+# equijoin_right_and_condition_from_left
+query ITT rowsort
+SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND 
t1_id >= 22
+----
+22 b y
+44 d x
+NULL NULL w
+NULL NULL z
+
+# equijoin_left_and_condition_from_left
+query ITT rowsort
+SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id 
>= 44
+----
+11 a NULL
+22 b NULL
+33 c NULL
+44 d x
+
+# equijoin_left_and_condition_from_both
+query III rowsort
+SELECT t1_id, t1_int, t2_int FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_int 
>= t2_int
+----
+11 1 NULL
+22 2 1
+33 3 NULL
+44 4 3
+
+# equijoin_right_and_condition_from_right
+query ITT rowsort
+SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND 
t2_id >= 22
+----
+22 b y
+44 d x
+NULL NULL w
+NULL NULL z
+
+# equijoin_right_and_condition_from_both
+query III rowsort
+SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int 
<= t1_int
+----
+2 1 22
+4 3 44
+NULL 3 11
+NULL 3 55
+
+# equijoin_full
+query ITIITI rowsort
+SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id
+----
+11 a 1 11 z 3
+22 b 2 22 y 1
+33 c 3 NULL NULL NULL
+44 d 4 44 x 3
+NULL NULL NULL 55 w 3
+
+# equijoin_full_and_condition_from_both
+query ITIITI rowsort
+SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int
+----
+11 a 1 NULL NULL NULL
+22 b 2 22 y 1
+33 c 3 NULL NULL NULL
+44 d 4 44 x 3
+NULL NULL NULL 11 z 3
+NULL NULL NULL 55 w 3
+
+statement ok
+DROP TABLE t1;
+
+statement ok
+DROP TABLE t2;
+
+statement ok
+set datafusion.optimizer.prefer_hash_join = true;

Reply via email to