liukun4515 commented on code in PR #4562:
URL: https://github.com/apache/arrow-datafusion/pull/4562#discussion_r1044065982


##########
datafusion/core/src/physical_plan/joins/nested_loop_join.rs:
##########
@@ -0,0 +1,872 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the nested loop join plan, which only supports non-equal joins.
+//! The nested loop join can execute in parallel by partitions; how the inputs
+//! must be partitioned is determined by the [`JoinType`].
+
+use crate::physical_plan::joins::utils::{
+    adjust_indices_by_join_type, adjust_right_output_partitioning,
+    apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
+    check_join_is_valid, combine_join_equivalence_properties, 
estimate_join_statistics,
+    get_final_indices, need_produce_result_in_final, ColumnIndex, JoinFilter, 
OnceAsync,
+    OnceFut,
+};
+use crate::physical_plan::{
+    DisplayFormatType, Distribution, ExecutionPlan, Partitioning, 
RecordBatchStream,
+    SendableRecordBatchStream,
+};
+use arrow::array::{
+    BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, 
UInt64Builder,
+};
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::record_batch::RecordBatch;
+use datafusion_common::Statistics;
+use datafusion_expr::JoinType;
+use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr};
+use futures::{ready, Stream, StreamExt, TryStreamExt};
+use log::debug;
+use std::any::Any;
+use std::fmt::Formatter;
+use std::sync::Arc;
+use std::task::Poll;
+use std::time::Instant;
+
+use crate::error::Result;
+use crate::execution::context::TaskContext;
+use crate::physical_plan::coalesce_batches::concat_batches;
+
+/// Data of the left side, held as a single [`RecordBatch`]
+type JoinLeftData = RecordBatch;
+
+/// Execution plan of the nested loop join: the left input is loaded into one
+/// batch and every left row is paired with every row of each right batch; the
+/// optional [`JoinFilter`] decides which pairs are kept.
+#[derive(Debug)]
+pub struct NestedLoopJoinExec {
+    /// left side
+    pub(crate) left: Arc<dyn ExecutionPlan>,
+    /// right side
+    pub(crate) right: Arc<dyn ExecutionPlan>,
+    /// Filters which are applied while finding matching rows
+    pub(crate) filter: Option<JoinFilter>,
+    /// How the join is performed
+    pub(crate) join_type: JoinType,
+    /// The schema once the join is applied
+    schema: SchemaRef,
+    /// Build-side (left) data; used when the left input is required to be a
+    /// single partition
+    left_fut: OnceAsync<JoinLeftData>,
+    /// Information of index and left / right placement of columns
+    column_indices: Vec<ColumnIndex>,
+}
+
+impl NestedLoopJoinExec {
+    /// Try to create a new [`NestedLoopJoinExec`].
+    ///
+    /// Checks the two input schemas with `check_join_is_valid` and derives the
+    /// output schema and the column placement for `join_type`. Returns the
+    /// error of that check when it fails.
+    pub fn try_new(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        filter: Option<JoinFilter>,
+        join_type: &JoinType,
+    ) -> Result<Self> {
+        let (lhs_schema, rhs_schema) = (left.schema(), right.schema());
+        // there are no join keys to validate, so an empty list is passed
+        check_join_is_valid(&lhs_schema, &rhs_schema, &[])?;
+        let (joined_schema, column_indices) =
+            build_join_schema(&lhs_schema, &rhs_schema, join_type);
+
+        Ok(Self {
+            left,
+            right,
+            filter,
+            join_type: *join_type,
+            schema: Arc::new(joined_schema),
+            left_fut: Default::default(),
+            column_indices,
+        })
+    }
+
+    /// Whether the left child is required to be a single partition
+    fn is_single_partition_for_left(&self) -> bool {
+        let required = self.required_input_distribution();
+        matches!(required[0], Distribution::SinglePartition)
+    }
+}
+
+impl ExecutionPlan for NestedLoopJoinExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        // the partition of output is determined by the rule of 
`required_input_distribution`
+        // TODO we can replace it by `partitioned_join_output_partitioning`
+        match self.join_type {
+            // use the left partition
+            JoinType::Inner
+            | JoinType::Left
+            | JoinType::LeftSemi
+            | JoinType::LeftAnti
+            | JoinType::Full => self.left.output_partitioning(),
+            // use the right partition
+            JoinType::Right => {
+                // if the partition of right is hash, and should adjust the 
column index for the
+                // right expr
+                adjust_right_output_partitioning(
+                    self.right.output_partitioning(),
+                    self.left.schema().fields.len(),
+                )
+            }
+            // use the right partition
+            JoinType::RightSemi | JoinType::RightAnti => 
self.right.output_partitioning(),
+        }
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        // no specified order for the output
+        None
+    }
+
+    fn required_input_distribution(&self) -> Vec<Distribution> {
+        distribution_from_join_type(&self.join_type)
+    }
+
+    fn equivalence_properties(&self) -> EquivalenceProperties {
+        let left_columns_len = self.left.schema().fields.len();
+        combine_join_equivalence_properties(
+            self.join_type,
+            self.left.equivalence_properties(),
+            self.right.equivalence_properties(),
+            left_columns_len,
+            &[], // empty join keys
+            self.schema(),
+        )
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.left.clone(), self.right.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(NestedLoopJoinExec::try_new(
+            children[0].clone(),
+            children[1].clone(),
+            self.filter.clone(),
+            &self.join_type,
+        )?))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        // if the distribution of left is `SinglePartition`, just need to 
collect the left one
+        let left_is_single_partition = self.is_single_partition_for_left();
+        // left side
+        let left_fut = if left_is_single_partition {
+            self.left_fut.once(|| {
+                // just one partition for the left side
+                load_specified_partition_input(0, self.left.clone(), 
context.clone())
+            })
+        } else {
+            // the distribution of left is not single partition, just need the 
specified partition for left
+            OnceFut::new(load_specified_partition_input(
+                partition,
+                self.left.clone(),
+                context.clone(),
+            ))
+        };
+        // right side
+        let right_side = if left_is_single_partition {
+            self.right.execute(partition, context)?
+        } else {
+            // the distribution of right is `SinglePartition`
+            self.right.execute(0, context)?
+        };
+
+        Ok(Box::pin(NestedLoopJoinStream {
+            schema: self.schema.clone(),
+            filter: self.filter.clone(),
+            join_type: self.join_type,
+            left_fut,
+            right: right_side,
+            is_exhausted: false,
+            visited_left_side: None,
+            column_indices: self.column_indices.clone(),
+        }))
+    }
+
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> 
std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                let display_filter = self.filter.as_ref().map_or_else(
+                    || "".to_string(),
+                    |f| format!(", filter={:?}", f.expression()),
+                );
+                write!(
+                    f,
+                    "NestedLoopJoinExec: join_type={:?}{}",
+                    self.join_type, display_filter
+                )
+            }
+        }
+    }
+
+    fn statistics(&self) -> Statistics {
+        estimate_join_statistics(
+            self.left.clone(),
+            self.right.clone(),
+            vec![],
+            &self.join_type,
+        )
+    }
+}
+
+// For the nested loop join, each `JoinType` needs a different distribution
+// for the left and the right node. The result holds the left requirement
+// first and the right requirement second.
+fn distribution_from_join_type(join_type: &JoinType) -> Vec<Distribution> {
+    let (left, right) = match join_type {
+        // need the left data, and the right should be one partition
+        JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => (
+            Distribution::UnspecifiedDistribution,
+            Distribution::SinglePartition,
+        ),
+        // need the right data, and the left should be one partition
+        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (
+            Distribution::SinglePartition,
+            Distribution::UnspecifiedDistribution,
+        ),
+        // need the left and right data, and both should be one partition
+        JoinType::Full => (
+            Distribution::SinglePartition,
+            Distribution::SinglePartition,
+        ),
+    };
+    vec![left, right]
+}
+
+/// Asynchronously collect the output of partition `partition` of the `left`
+/// plan and concatenate all of its batches into one batch
+async fn load_specified_partition_input(
+    partition: usize,
+    left: Arc<dyn ExecutionPlan>,
+    context: Arc<TaskContext>,
+) -> Result<JoinLeftData> {
+    let start = Instant::now();
+    let stream = left.execute(partition, context)?;
+
+    // Gather every batch of the stream together with the running row count
+    let (batches, num_rows) = stream
+        .try_fold(
+            (Vec::new(), 0usize),
+            |(mut collected, rows), batch| async move {
+                let rows = rows + batch.num_rows();
+                collected.push(batch);
+                Ok((collected, rows))
+            },
+        )
+        .await?;
+
+    let schema = left.schema();
+    let merged_batch = concat_batches(&schema, &batches, num_rows)?;
+
+    debug!(
+        "Built left-side of nested loop join containing {} rows in {} ms for partition {}",
+        num_rows,
+        start.elapsed().as_millis(),
+        partition
+    );
+
+    Ok(merged_batch)
+}
+
+/// A stream that issues [RecordBatch]es as they arrive from the right of the join.
+struct NestedLoopJoinStream {
+    /// Schema of the batches produced by this stream
+    schema: Arc<Schema>,
+    /// join filter
+    filter: Option<JoinFilter>,
+    /// type of the join
+    join_type: JoinType,
+    /// future for data from left side
+    left_fut: OnceFut<JoinLeftData>,
+    /// right
+    right: SendableRecordBatchStream,
+    /// There is nothing to process anymore and left side is processed in case of left/left semi/left anti/full join
+    is_exhausted: bool,
+    /// Keeps track of the left side rows whether they are visited
+    visited_left_side: Option<BooleanBufferBuilder>,
+    /// Information of index and left / right placement of columns
+    column_indices: Vec<ColumnIndex>,
+    // TODO: support null aware equal
+    // null_equals_null: bool
+}
+
+fn build_join_indices(
+    left_index: usize,
+    batch: &RecordBatch,
+    left_data: &JoinLeftData,
+    filter: Option<&JoinFilter>,
+) -> Result<(UInt64Array, UInt32Array)> {
+    let right_row_count = batch.num_rows();
+    // left indices: [left_index, left_index, ...., left_index]
+    // right indices: [0, 1, 2, 3, 4,....,right_row_count]
+    let left_indices = UInt64Array::from(vec![left_index as u64; 
right_row_count]);
+    let right_indices = UInt32Array::from_iter_values(0..(right_row_count as 
u32));
+    if let Some(filter) = filter {
+        // Filter the indices which is satisfies the non-equal join condition, 
like `left.b1 = 10`
+        apply_join_filter_to_indices(
+            left_data,
+            batch,
+            left_indices,
+            right_indices,
+            filter,
+        )
+    } else {
+        Ok((left_indices, right_indices))
+    }
+}
+
+impl NestedLoopJoinStream {
+    fn poll_next_impl(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<ArrowResult<RecordBatch>>> {
+        // all left row
+        let left_data = match ready!(self.left_fut.get(cx)) {
+            Ok(left_data) => left_data,
+            Err(e) => return Poll::Ready(Some(Err(e))),
+        };
+
+        let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
+            let left_num_rows = left_data.num_rows();
+            if need_produce_result_in_final(self.join_type) {
+                // these join type need the bitmap to identify which row has 
be matched or unmatched.
+                // For the `left semi` join, need to use the bitmap to produce 
the matched row in the left side
+                // For the `left` join, need to use the bitmap to produce the 
unmatched row in the left side with null
+                // For the `left anti` join, need to use the bitmap to produce 
the unmatched row in the left side
+                // For the `full` join, need to use the bitmap to produce the 
unmatched row in the left side with null
+                let mut buffer = BooleanBufferBuilder::new(left_num_rows);
+                buffer.append_n(left_num_rows, false);
+                buffer
+            } else {
+                BooleanBufferBuilder::new(0)
+            }
+        });
+
+        // iter the right batch
+        self.right
+            .poll_next_unpin(cx)
+            .map(|maybe_batch| match maybe_batch {
+                Some(Ok(right_batch)) => {
+                    // get the matched left and right indices
+                    // each left row will try to match every right row
+                    let indices_result = (0..left_data.num_rows())
+                        .map(|left_row_index| {
+                            build_join_indices(
+                                left_row_index,
+                                &right_batch,
+                                left_data,
+                                self.filter.as_ref(),
+                            )
+                        })
+                        .collect::<Result<Vec<(UInt64Array, UInt32Array)>>>();
+                    let mut left_indices_builder = UInt64Builder::new();
+                    let mut right_indices_builder = UInt32Builder::new();
+                    let left_right_indices = match indices_result {
+                        Err(_) => {
+                            // TODO why the type of result stream is 
`Result<T, ArrowError>`, and not the `DataFusionError`
+                            Err(ArrowError::ComputeError(
+                                "Build left right indices error".to_string(),
+                            ))
+                        }
+                        Ok(indices) => {
+                            for (left_side, right_side) in indices {
+                                left_indices_builder.append_values(
+                                    left_side.values(),
+                                    &vec![true; left_side.len()],
+                                );
+                                right_indices_builder.append_values(
+                                    right_side.values(),
+                                    &vec![true; right_side.len()],
+                                );
+                            }
+                            Ok((
+                                left_indices_builder.finish(),
+                                right_indices_builder.finish(),
+                            ))
+                        }
+                    };
+                    let result = match left_right_indices {
+                        Ok((left_side, right_side)) => {
+                            // set the left bitmap
+                            // and only left, full, left semi, left anti need 
the left bitmap
+                            if need_produce_result_in_final(self.join_type) {
+                                left_side.iter().flatten().for_each(|x| {
+                                    visited_left_side.set_bit(x as usize, 
true);
+                                });
+                            }
+                            // adjust the two side indices base on the join 
type
+                            let (left_side, right_side) = 
adjust_indices_by_join_type(
+                                left_side,
+                                right_side,
+                                right_batch.num_rows(),
+                                self.join_type,
+                            );
+
+                            let result = build_batch_from_indices(
+                                &self.schema,
+                                left_data,
+                                &right_batch,
+                                left_side,
+                                right_side,
+                                &self.column_indices,
+                            );
+                            Some(result)
+                        }
+                        Err(e) => Some(Err(e)),
+                    };
+                    result
+                }
+                Some(err) => Some(err),
+                None => {
+                    if need_produce_result_in_final(self.join_type) && 
!self.is_exhausted
+                    {
+                        // use the global left bitmap to produce the left 
indices and right indices
+                        let (left_side, right_side) =
+                            get_final_indices(visited_left_side, 
self.join_type);
+                        let empty_right_batch =
+                            RecordBatch::new_empty(self.right.schema());
+                        // use the left and right indices to produce the batch 
result
+                        let result = build_batch_from_indices(
+                            &self.schema,
+                            left_data,
+                            &empty_right_batch,
+                            left_side,
+                            right_side,
+                            &self.column_indices,
+                        );
+                        self.is_exhausted = true;
+                        Some(result)
+                    } else {
+                        // end of the join loop
+                        None
+                    }
+                }
+            })
+    }
+}
+
+impl Stream for NestedLoopJoinStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    /// Forward the poll to [`NestedLoopJoinStream::poll_next_impl`]
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let this = self.get_mut();
+        this.poll_next_impl(cx)
+    }
+}
+
+impl RecordBatchStream for NestedLoopJoinStream {
+    /// Schema of the batches produced by this stream
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+#[cfg(test)]
+mod tests {

Review Comment:
   There is some test code here that duplicates the tests of the `hashjoin` and `cross join`.
   
   I will refactor and clean it up in a follow-up PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to