This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new ef9df2960 Sort preserving `SortMergeJoin` (#2699)
ef9df2960 is described below

commit ef9df296013b1aef3a0116d174d6b89491173cdc
Author: Eduard Karacharov <[email protected]>
AuthorDate: Tue Jun 14 21:18:06 2022 +0300

    Sort preserving `SortMergeJoin` (#2699)
    
    * sort preserving merge join
    
    * Apply suggestions from code review
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../core/src/physical_plan/sort_merge_join.rs      | 502 +++++++++++++++------
 1 file changed, 362 insertions(+), 140 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs 
b/datafusion/core/src/physical_plan/sort_merge_join.rs
index e2248a99b..ffbd27df9 100644
--- a/datafusion/core/src/physical_plan/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -126,7 +126,13 @@ impl ExecutionPlan for SortMergeJoinExec {
     }
 
     fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
-        self.right.output_ordering()
+        match self.join_type {
+            JoinType::Inner | JoinType::Left | JoinType::Semi | JoinType::Anti 
=> {
+                self.left.output_ordering()
+            }
+            JoinType::Right => self.right.output_ordering(),
+            JoinType::Full => None,
+        }
     }
 
     fn relies_on_input_order(&self) -> bool {
@@ -300,11 +306,24 @@ enum BufferedState {
     Exhausted,
 }
 
+struct StreamedJoinedChunk {
+    /// Index of batch in buffered_data
+    buffered_batch_idx: Option<usize>,
+    /// Array builder for streamed indices
+    streamed_indices: UInt64Builder,
+    /// Array builder for buffered indices
+    buffered_indices: UInt64Builder,
+}
+
 struct StreamedBatch {
     pub batch: RecordBatch,
     pub idx: usize,
     pub join_arrays: Vec<ArrayRef>,
-    pub null_joined: Vec<usize>,
+
+    // Chunks of indices from buffered side (may be nulls) joined to streamed
+    pub output_indices: Vec<StreamedJoinedChunk>,
+    // Index of currently scanned batch from buffered data
+    pub buffered_batch_idx: Option<usize>,
 }
 impl StreamedBatch {
     fn new(batch: RecordBatch, on_column: &[Column]) -> Self {
@@ -313,7 +332,8 @@ impl StreamedBatch {
             batch,
             idx: 0,
             join_arrays,
-            null_joined: vec![],
+            output_indices: vec![],
+            buffered_batch_idx: None,
         }
     }
 
@@ -322,8 +342,39 @@ impl StreamedBatch {
             batch: RecordBatch::new_empty(schema),
             idx: 0,
             join_arrays: vec![],
-            null_joined: vec![],
+            output_indices: vec![],
+            buffered_batch_idx: None,
+        }
+    }
+
+    /// Appends a new output pair consisting of the current streamed index
+    /// and `buffered_idx` — an index into the buffered batch identified by
+    /// `buffered_batch_idx`.
+    fn append_output_pair(
+        &mut self,
+        buffered_batch_idx: Option<usize>,
+        buffered_idx: Option<usize>,
+    ) -> ArrowResult<()> {
+        if self.output_indices.is_empty() || self.buffered_batch_idx != 
buffered_batch_idx
+        {
+            self.output_indices.push(StreamedJoinedChunk {
+                buffered_batch_idx,
+                streamed_indices: UInt64Builder::new(1),
+                buffered_indices: UInt64Builder::new(1),
+            });
+            self.buffered_batch_idx = buffered_batch_idx;
+        };
+        let current_chunk = self.output_indices.last_mut().unwrap();
+
+        current_chunk
+            .streamed_indices
+            .append_value(self.idx as u64)?;
+        if let Some(idx) = buffered_idx {
+            current_chunk.buffered_indices.append_value(idx as u64)?;
+        } else {
+            current_chunk.buffered_indices.append_null()?;
         }
+
+        Ok(())
     }
 }
 
@@ -338,8 +389,6 @@ struct BufferedBatch {
     pub join_arrays: Vec<ArrayRef>,
     /// Buffered joined index (null joining buffered)
     pub null_joined: Vec<usize>,
-    /// Buffered joined index (streamed joining buffered)
-    pub pair_joined: (Vec<usize>, Vec<usize>),
 }
 impl BufferedBatch {
     fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) -> 
Self {
@@ -349,7 +398,6 @@ impl BufferedBatch {
             range,
             join_arrays,
             null_joined: vec![],
-            pair_joined: (vec![], vec![]),
         }
     }
 }
@@ -571,7 +619,7 @@ impl SMJStream {
                     }
                     Poll::Ready(Some(batch)) => {
                         if batch.num_rows() > 0 {
-                            self.freeze_dequeuing_streamed()?;
+                            self.freeze_streamed()?;
                             self.join_metrics.input_batches.add(1);
                             self.join_metrics.input_rows.add(batch.num_rows());
                             self.streamed_batch =
@@ -757,16 +805,10 @@ impl SMJStream {
             {
                 let scanning_idx = self.buffered_data.scanning_idx();
                 if join_streamed {
-                    self.buffered_data
-                        .scanning_batch_mut()
-                        .pair_joined
-                        .0
-                        .push(self.streamed_batch.idx);
-                    self.buffered_data
-                        .scanning_batch_mut()
-                        .pair_joined
-                        .1
-                        .push(scanning_idx);
+                    self.streamed_batch.append_output_pair(
+                        Some(self.buffered_data.scanning_batch_idx),
+                        Some(scanning_idx),
+                    )?;
                 } else {
                     self.buffered_data
                         .scanning_batch_mut()
@@ -783,9 +825,14 @@ impl SMJStream {
             }
         } else {
             // joining streamed and nulls
+            let scanning_batch_idx = if self.buffered_data.scanning_finished() 
{
+                None
+            } else {
+                Some(self.buffered_data.scanning_batch_idx)
+            };
+
             self.streamed_batch
-                .null_joined
-                .push(self.streamed_batch.idx);
+                .append_output_pair(scanning_batch_idx, None)?;
             self.output_size += 1;
             self.buffered_data.scanning_finish();
             self.streamed_joined = true;
@@ -794,82 +841,26 @@ impl SMJStream {
     }
 
     fn freeze_all(&mut self) -> ArrowResult<()> {
-        self.freeze_streamed_join_null()?;
-        self.freeze_buffered_join_null(self.buffered_data.batches.len())?;
-        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
-        Ok(())
-    }
-
-    // freeze when a dequeueing streamed batch
-    fn freeze_dequeuing_streamed(&mut self) -> ArrowResult<()> {
-        self.freeze_streamed_join_null()?;
-        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
+        self.freeze_streamed()?;
+        self.freeze_buffered(self.buffered_data.batches.len())?;
         Ok(())
     }
 
-    // freeze when a dequeueing streamed batch
+    // Produces and stages record batches to ensure the dequeued buffered batch
+    // is no longer needed:
+    //   1. freezes all indices joined to streamed side
+    //   2. freezes NULLs joined to dequeued buffered batch to "release" it
     fn freeze_dequeuing_buffered(&mut self) -> ArrowResult<()> {
-        self.freeze_buffered_join_streamed(1)?;
-        self.freeze_buffered_join_null(1)?;
-        Ok(())
-    }
-
-    // join_type must be one of: `Left`/`Right`/`Full`/`Semi`/`Anti`
-    fn freeze_streamed_join_null(&mut self) -> ArrowResult<()> {
-        if !matches!(
-            self.join_type,
-            JoinType::Left
-                | JoinType::Right
-                | JoinType::Full
-                | JoinType::Semi
-                | JoinType::Anti
-        ) {
-            return Ok(());
-        }
-        let streamed_indices = UInt64Array::from_iter_values(
-            self.streamed_batch
-                .null_joined
-                .iter()
-                .map(|&index| index as u64),
-        );
-        if streamed_indices.is_empty() {
-            return Ok(());
-        }
-        self.streamed_batch.null_joined.clear();
-
-        let mut streamed_columns = self
-            .streamed_batch
-            .batch
-            .columns()
-            .iter()
-            .map(|column| take(column, &streamed_indices, None))
-            .collect::<ArrowResult<Vec<_>>>()?;
-
-        let columns = if matches!(self.join_type, JoinType::Semi | 
JoinType::Anti) {
-            streamed_columns
-        } else {
-            let mut buffered_columns = self
-                .buffered_schema
-                .fields()
-                .iter()
-                .map(|f| new_null_array(f.data_type(), streamed_indices.len()))
-                .collect::<Vec<_>>();
-
-            if matches!(self.join_type, JoinType::Right) {
-                buffered_columns.extend(streamed_columns);
-                buffered_columns
-            } else {
-                streamed_columns.extend(buffered_columns);
-                streamed_columns
-            }
-        };
-        self.output_record_batches
-            .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+        self.freeze_streamed()?;
+        self.freeze_buffered(1)?;
         Ok(())
     }
 
-    // join_type must be `Full`
-    fn freeze_buffered_join_null(&mut self, batch_count: usize) -> 
ArrowResult<()> {
+    // Produces and stages record batch from buffered indices with 
corresponding
+    // NULLs on streamed side.
+    //
+    // Applicable only in case of Full join.
+    fn freeze_buffered(&mut self, batch_count: usize) -> ArrowResult<()> {
         if !matches!(self.join_type, JoinType::Full) {
             return Ok(());
         }
@@ -905,41 +896,15 @@ impl SMJStream {
         Ok(())
     }
 
-    // join_type must be `Inner`/`Left`/`Right`/`Full`
-    fn freeze_buffered_join_streamed(&mut self, batch_count: usize) -> 
ArrowResult<()> {
-        if !matches!(
-            self.join_type,
-            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
-        ) {
-            return Ok(());
-        }
-        for buffered_batch in 
self.buffered_data.batches.range_mut(..batch_count) {
-            let buffered_indices = UInt64Array::from_iter_values(
-                buffered_batch
-                    .pair_joined
-                    .1
-                    .iter()
-                    .map(|&index| index as u64),
-            );
-            let streamed_indices = UInt64Array::from_iter_values(
-                buffered_batch
-                    .pair_joined
-                    .0
-                    .iter()
-                    .map(|&index| index as u64),
-            );
-            if buffered_indices.is_empty() {
+    // Produces and stages record batch for all output indices found
+    // for current streamed batch and clears staged output indices.
+    fn freeze_streamed(&mut self) -> ArrowResult<()> {
+        for chunk in self.streamed_batch.output_indices.iter_mut() {
+            let streamed_indices = chunk.streamed_indices.finish();
+
+            if streamed_indices.is_empty() {
                 continue;
             }
-            buffered_batch.pair_joined.0.clear();
-            buffered_batch.pair_joined.1.clear();
-
-            let mut buffered_columns = buffered_batch
-                .batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &buffered_indices, None))
-                .collect::<ArrowResult<Vec<_>>>()?;
 
             let mut streamed_columns = self
                 .streamed_batch
@@ -949,6 +914,26 @@ impl SMJStream {
                 .map(|column| take(column, &streamed_indices, None))
                 .collect::<ArrowResult<Vec<_>>>()?;
 
+            let buffered_indices: UInt64Array = 
chunk.buffered_indices.finish();
+
+            let mut buffered_columns =
+                if matches!(self.join_type, JoinType::Semi | JoinType::Anti) {
+                    vec![]
+                } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
+                    self.buffered_data.batches[buffered_idx]
+                        .batch
+                        .columns()
+                        .iter()
+                        .map(|column| take(column, &buffered_indices, None))
+                        .collect::<ArrowResult<Vec<_>>>()?
+                } else {
+                    self.buffered_schema
+                        .fields()
+                        .iter()
+                        .map(|f| new_null_array(f.data_type(), 
buffered_indices.len()))
+                        .collect::<Vec<_>>()
+                };
+
             let columns = if matches!(self.join_type, JoinType::Right) {
                 buffered_columns.extend(streamed_columns);
                 buffered_columns
@@ -960,6 +945,9 @@ impl SMJStream {
             self.output_record_batches
                 .push(RecordBatch::try_new(self.schema.clone(), columns)?);
         }
+
+        self.streamed_batch.output_indices.clear();
+
         Ok(())
     }
 
@@ -1211,7 +1199,6 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
 
-    use crate::assert_batches_sorted_eq;
     use crate::error::Result;
     use crate::logical_plan::JoinType;
     use crate::physical_plan::expressions::Column;
@@ -1221,6 +1208,7 @@ mod tests {
     use crate::physical_plan::{common, ExecutionPlan};
     use crate::prelude::{SessionConfig, SessionContext};
     use crate::test::{build_table_i32, columns};
+    use crate::{assert_batches_eq, assert_batches_sorted_eq};
 
     fn build_table(
         a: (&str, &Vec<i32>),
@@ -1232,6 +1220,11 @@ mod tests {
         Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
     }
 
+    fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn 
ExecutionPlan> {
+        let schema = batches.first().unwrap().schema();
+        Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap())
+    }
+
     fn build_date_table(
         a: (&str, &Vec<i32>),
         b: (&str, &Vec<i32>),
@@ -1414,7 +1407,8 @@ mod tests {
             "| 3  | 5  | 9  | 20 | 5  | 80 |",
             "+----+----+----+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1451,7 +1445,8 @@ mod tests {
             "| 2  | 2  | 9  | 2  | 2  | 80 |",
             "+----+----+----+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1489,7 +1484,8 @@ mod tests {
             "| 1  | 1  | 8  | 1  | 1  | 80 |",
             "+----+----+----+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1526,7 +1522,8 @@ mod tests {
             "| 2  | 2  | 9  | 2  | 2  | 80 |",
             "+----+----+----+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1571,14 +1568,14 @@ mod tests {
             "+----+----+----+----+----+----+",
             "| a1 | b2 | c1 | a1 | b2 | c2 |",
             "+----+----+----+----+----+----+",
-            "| 1  |    | 1  | 1  |    | 10 |",
-            "| 1  | 1  |    | 1  | 1  | 70 |",
-            "| 2  | 2  | 8  | 2  | 2  | 80 |",
             "| 2  | 2  | 9  | 2  | 2  | 80 |",
+            "| 2  | 2  | 8  | 2  | 2  | 80 |",
+            "| 1  | 1  |    | 1  | 1  | 70 |",
+            "| 1  |    | 1  | 1  |    | 10 |",
             "+----+----+----+----+----+----+",
         ];
-        //assert_eq!(batches.len(), 1);
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1619,7 +1616,8 @@ mod tests {
         assert_eq!(batches.len(), 2);
         assert_eq!(batches[0].num_rows(), 2);
         assert_eq!(batches[1].num_rows(), 1);
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1650,7 +1648,8 @@ mod tests {
             "| 3  | 7  | 9  |    |    |    |",
             "+----+----+----+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1676,12 +1675,13 @@ mod tests {
             "+----+----+----+----+----+----+",
             "| a1 | b1 | c1 | a2 | b1 | c2 |",
             "+----+----+----+----+----+----+",
-            "|    |    |    | 30 | 6  | 90 |",
             "| 1  | 4  | 7  | 10 | 4  | 70 |",
             "| 2  | 5  | 8  | 20 | 5  | 80 |",
+            "|    |    |    | 30 | 6  | 90 |",
             "+----+----+----+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1743,7 +1743,8 @@ mod tests {
             "| 5  | 7  | 11 |",
             "+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1774,7 +1775,8 @@ mod tests {
             "| 2  | 5  | 8  |",
             "+----+----+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1805,7 +1807,8 @@ mod tests {
             "| 2 | 5 | 8 | 20 | 2 | 80 |",
             "+---+---+---+----+---+----+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1838,7 +1841,8 @@ mod tests {
             "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 
| 1970-03-22 |",
             
"+------------+------------+------------+------------+------------+------------+",
         ];
-        assert_batches_sorted_eq!(expected, &batches);
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
         Ok(())
     }
 
@@ -1871,6 +1875,224 @@ mod tests {
             "| 1970-01-01 | 2022-04-25 | 1970-01-01 | 1970-01-01 | 2022-04-25 
| 1970-01-01 |",
             
"+------------+------------+------------+------------+------------+------------+",
         ];
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_left_sort_order() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![0, 1, 2, 3, 4, 5]),
+            ("b1", &vec![3, 4, 5, 6, 6, 7]),
+            ("c1", &vec![4, 5, 6, 7, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![0, 10, 20, 30, 40]),
+            ("b2", &vec![2, 4, 6, 6, 8]),
+            ("c2", &vec![50, 60, 70, 80, 90]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b2", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, 
JoinType::Left).await?;
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b2 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 0  | 3  | 4  |    |    |    |",
+            "| 1  | 4  | 5  | 10 | 4  | 60 |",
+            "| 2  | 5  | 6  |    |    |    |",
+            "| 3  | 6  | 7  | 20 | 6  | 70 |",
+            "| 3  | 6  | 7  | 30 | 6  | 80 |",
+            "| 4  | 6  | 8  | 20 | 6  | 70 |",
+            "| 4  | 6  | 8  | 30 | 6  | 80 |",
+            "| 5  | 7  | 9  |    |    |    |",
+            "+----+----+----+----+----+----+",
+        ];
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_right_sort_order() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![0, 1, 2, 3]),
+            ("b1", &vec![3, 4, 5, 7]),
+            ("c1", &vec![6, 7, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![0, 10, 20, 30]),
+            ("b2", &vec![2, 4, 5, 6]),
+            ("c2", &vec![60, 70, 80, 90]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b2", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, 
JoinType::Right).await?;
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b2 | c2 |",
+            "+----+----+----+----+----+----+",
+            "|    |    |    | 0  | 2  | 60 |",
+            "| 1  | 4  | 7  | 10 | 4  | 70 |",
+            "| 2  | 5  | 8  | 20 | 5  | 80 |",
+            "|    |    |    | 30 | 6  | 90 |",
+            "+----+----+----+----+----+----+",
+        ];
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_left_multiple_batches() -> Result<()> {
+        let left_batch_1 = build_table_i32(
+            ("a1", &vec![0, 1, 2]),
+            ("b1", &vec![3, 4, 5]),
+            ("c1", &vec![4, 5, 6]),
+        );
+        let left_batch_2 = build_table_i32(
+            ("a1", &vec![3, 4, 5, 6]),
+            ("b1", &vec![6, 6, 7, 9]),
+            ("c1", &vec![7, 8, 9, 9]),
+        );
+        let right_batch_1 = build_table_i32(
+            ("a2", &vec![0, 10, 20]),
+            ("b2", &vec![2, 4, 6]),
+            ("c2", &vec![50, 60, 70]),
+        );
+        let right_batch_2 = build_table_i32(
+            ("a2", &vec![30, 40]),
+            ("b2", &vec![6, 8]),
+            ("c2", &vec![80, 90]),
+        );
+        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
+        let right = build_table_from_batches(vec![right_batch_1, 
right_batch_2]);
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b2", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, 
JoinType::Left).await?;
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b2 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 0  | 3  | 4  |    |    |    |",
+            "| 1  | 4  | 5  | 10 | 4  | 60 |",
+            "| 2  | 5  | 6  |    |    |    |",
+            "| 3  | 6  | 7  | 20 | 6  | 70 |",
+            "| 3  | 6  | 7  | 30 | 6  | 80 |",
+            "| 4  | 6  | 8  | 20 | 6  | 70 |",
+            "| 4  | 6  | 8  | 30 | 6  | 80 |",
+            "| 5  | 7  | 9  |    |    |    |",
+            "| 6  | 9  | 9  |    |    |    |",
+            "+----+----+----+----+----+----+",
+        ];
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_right_multiple_batches() -> Result<()> {
+        let right_batch_1 = build_table_i32(
+            ("a2", &vec![0, 1, 2]),
+            ("b2", &vec![3, 4, 5]),
+            ("c2", &vec![4, 5, 6]),
+        );
+        let right_batch_2 = build_table_i32(
+            ("a2", &vec![3, 4, 5, 6]),
+            ("b2", &vec![6, 6, 7, 9]),
+            ("c2", &vec![7, 8, 9, 9]),
+        );
+        let left_batch_1 = build_table_i32(
+            ("a1", &vec![0, 10, 20]),
+            ("b1", &vec![2, 4, 6]),
+            ("c1", &vec![50, 60, 70]),
+        );
+        let left_batch_2 = build_table_i32(
+            ("a1", &vec![30, 40]),
+            ("b1", &vec![6, 8]),
+            ("c1", &vec![80, 90]),
+        );
+        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
+        let right = build_table_from_batches(vec![right_batch_1, 
right_batch_2]);
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b2", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, 
JoinType::Right).await?;
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b2 | c2 |",
+            "+----+----+----+----+----+----+",
+            "|    |    |    | 0  | 3  | 4  |",
+            "| 10 | 4  | 60 | 1  | 4  | 5  |",
+            "|    |    |    | 2  | 5  | 6  |",
+            "| 20 | 6  | 70 | 3  | 6  | 7  |",
+            "| 30 | 6  | 80 | 3  | 6  | 7  |",
+            "| 20 | 6  | 70 | 4  | 6  | 8  |",
+            "| 30 | 6  | 80 | 4  | 6  | 8  |",
+            "|    |    |    | 5  | 7  | 9  |",
+            "|    |    |    | 6  | 9  | 9  |",
+            "+----+----+----+----+----+----+",
+        ];
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_full_multiple_batches() -> Result<()> {
+        let left_batch_1 = build_table_i32(
+            ("a1", &vec![0, 1, 2]),
+            ("b1", &vec![3, 4, 5]),
+            ("c1", &vec![4, 5, 6]),
+        );
+        let left_batch_2 = build_table_i32(
+            ("a1", &vec![3, 4, 5, 6]),
+            ("b1", &vec![6, 6, 7, 9]),
+            ("c1", &vec![7, 8, 9, 9]),
+        );
+        let right_batch_1 = build_table_i32(
+            ("a2", &vec![0, 10, 20]),
+            ("b2", &vec![2, 4, 6]),
+            ("c2", &vec![50, 60, 70]),
+        );
+        let right_batch_2 = build_table_i32(
+            ("a2", &vec![30, 40]),
+            ("b2", &vec![6, 8]),
+            ("c2", &vec![80, 90]),
+        );
+        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
+        let right = build_table_from_batches(vec![right_batch_1, 
right_batch_2]);
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b2", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, 
JoinType::Full).await?;
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b2 | c2 |",
+            "+----+----+----+----+----+----+",
+            "|    |    |    | 0  | 2  | 50 |",
+            "|    |    |    | 40 | 8  | 90 |",
+            "| 0  | 3  | 4  |    |    |    |",
+            "| 1  | 4  | 5  | 10 | 4  | 60 |",
+            "| 2  | 5  | 6  |    |    |    |",
+            "| 3  | 6  | 7  | 20 | 6  | 70 |",
+            "| 3  | 6  | 7  | 30 | 6  | 80 |",
+            "| 4  | 6  | 8  | 20 | 6  | 70 |",
+            "| 4  | 6  | 8  | 30 | 6  | 80 |",
+            "| 5  | 7  | 9  |    |    |    |",
+            "| 6  | 9  | 9  |    |    |    |",
+            "+----+----+----+----+----+----+",
+        ];
         assert_batches_sorted_eq!(expected, &batches);
         Ok(())
     }

Reply via email to