mbutrovich commented on code in PR #21184:
URL: https://github.com/apache/datafusion/pull/21184#discussion_r3002682773


##########
datafusion/physical-plan/src/joins/sort_merge_join/stream.rs:
##########
@@ -1267,209 +1231,294 @@ impl SortMergeJoinStream {
 
     // Produces and stages record batch for all output indices found
     // for current streamed batch and clears staged output indices.
+    //
+    // Null-joined chunks (no buffered match) are pushed immediately.
+    // Matched chunks are collected and processed together in
+    // freeze_streamed_matched() to amortize filter evaluation overhead.
     fn freeze_streamed(&mut self) -> Result<()> {
+        let mut matched_chunks: Vec<(usize, UInt64Array, UInt64Array)> = 
Vec::new();
+        let mut total_matched_rows: usize = 0;
+
         for chunk in self.streamed_batch.output_indices.iter_mut() {
-            // The row indices of joined streamed batch
             let left_indices = chunk.streamed_indices.finish();
-
             if left_indices.is_empty() {
                 continue;
             }
+            let right_indices: UInt64Array = chunk.buffered_indices.finish();
 
-            let mut left_columns = if let Some(range) = 
is_contiguous_range(&left_indices)
-            {
-                // When indices form a contiguous range (common for the 
streamed
-                // side which advances sequentially), use zero-copy slice 
instead
-                // of the O(n) take kernel.
-                self.streamed_batch
-                    .batch
-                    .slice(range.start, range.len())
-                    .columns()
-                    .to_vec()
-            } else {
-                take_arrays(self.streamed_batch.batch.columns(), 
&left_indices, None)?
-            };
+            if chunk.buffered_batch_idx.is_none() {
+                let left_columns =
+                    materialize_left_columns(&self.streamed_batch.batch, 
&left_indices)?;
+                let right_columns =
+                    create_unmatched_columns(&self.buffered_schema, 
left_indices.len());
 
-            // The row indices of joined buffered batch
-            let right_indices: UInt64Array = chunk.buffered_indices.finish();
-            let mut right_columns =
-                if matches!(self.join_type, JoinType::LeftMark | 
JoinType::RightMark) {
-                    vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
-                } else if matches!(
-                    self.join_type,
-                    JoinType::LeftSemi
-                        | JoinType::LeftAnti
-                        | JoinType::RightAnti
-                        | JoinType::RightSemi
-                ) {
-                    vec![]
-                } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
-                    fetch_right_columns_by_idxs(
-                        &self.buffered_data,
-                        buffered_idx,
-                        &right_indices,
-                    )?
+                let columns = if self.join_type != JoinType::Right {
+                    [left_columns, right_columns].concat()
                 } else {
-                    // If buffered batch none, meaning it is null joined batch.
-                    // We need to create null arrays for buffered columns to 
join with streamed rows.
-                    create_unmatched_columns(
-                        self.join_type,
-                        &self.buffered_schema,
-                        right_indices.len(),
-                    )
+                    [right_columns, left_columns].concat()
                 };
+                let batch = RecordBatch::try_new(Arc::clone(&self.schema), 
columns)?;
 
-            // Prepare the columns we apply join filter on later.
-            // Only for joined rows between streamed and buffered.
-            let filter_columns = if let Some(buffered_batch_idx) =
-                chunk.buffered_batch_idx
-            {
-                if self.join_type != JoinType::Right {
-                    if matches!(
-                        self.join_type,
-                        JoinType::LeftSemi | JoinType::LeftAnti | 
JoinType::LeftMark
-                    ) {
-                        let right_cols = fetch_right_columns_by_idxs(
-                            &self.buffered_data,
-                            buffered_batch_idx,
-                            &right_indices,
-                        )?;
-
-                        get_filter_columns(&self.filter, &left_columns, 
&right_cols)
-                    } else if matches!(
-                        self.join_type,
-                        JoinType::RightAnti | JoinType::RightSemi | 
JoinType::RightMark
-                    ) {
-                        let right_cols = fetch_right_columns_by_idxs(
-                            &self.buffered_data,
-                            buffered_batch_idx,
-                            &right_indices,
-                        )?;
-
-                        get_filter_columns(&self.filter, &right_cols, 
&left_columns)
-                    } else {
-                        get_filter_columns(&self.filter, &left_columns, 
&right_columns)
-                    }
+                // Null-joined rows (no buffered match) need no filter 
correction,
+                // but must flow through the same pipeline as matched rows to
+                // preserve output ordering. Use null metadata as a sentinel so
+                // get_corrected_filter_mask() passes them through unchanged.
+                if needs_deferred_filtering(&self.filter, self.join_type) {
+                    self.joined_record_batches
+                        .push_batch_with_null_metadata(batch, self.join_type);
                 } else {
-                    get_filter_columns(&self.filter, &right_columns, 
&left_columns)
+                    self.joined_record_batches
+                        .push_batch_without_metadata(batch);
                 }
-            } else {
-                // This chunk is totally for null joined rows (outer join), we 
don't need to apply join filter.
-                // Any join filter applied only on either streamed or buffered 
side will be pushed already.
-                vec![]
-            };
+                continue;
+            }
 
-            let columns = if self.join_type != JoinType::Right {
-                left_columns.extend(right_columns);
-                left_columns
-            } else {
-                right_columns.extend(left_columns);
-                right_columns
-            };
+            total_matched_rows += left_indices.len();
+            matched_chunks.push((
+                chunk.buffered_batch_idx.unwrap(),
+                left_indices,
+                right_indices,
+            ));
+        }
+
+        if !matched_chunks.is_empty() {
+            self.freeze_streamed_matched(&matched_chunks, total_matched_rows)?;
+        }
+
+        self.streamed_batch.output_indices.clear();
+        self.streamed_batch.num_output_rows = 0;
+        Ok(())
+    }
+
+    /// Materializes columns, evaluates the join filter, and pushes output
+    /// for all matched chunks in a single batch. This avoids per-chunk
+    /// RecordBatch construction and filter evaluation, which dominates
+    /// cost when keys are near-unique (1 row per chunk).
+    fn freeze_streamed_matched(
+        &mut self,
+        matched_chunks: &[(usize, UInt64Array, UInt64Array)],
+        total_matched_rows: usize,
+    ) -> Result<()> {
+        debug_assert!(
+            !matched_chunks.is_empty(),
+            "caller guards this with an is_empty check before calling"
+        );
+        debug_assert!(
+            matched_chunks.iter().all(|(idx, left, right)| {
+                left.len() == right.len() && *idx < 
self.buffered_data.batches.len()
+            }),
+            "left/right indices are built in pairs from the same 
streamed×buffered cross, \
+             and batch_idx comes from iterating buffered_data.batches"
+        );
+        debug_assert_eq!(
+            matched_chunks
+                .iter()
+                .map(|(_, l, _)| l.len())
+                .sum::<usize>(),
+            total_matched_rows,
+            "total_matched_rows is accumulated from the same chunks in 
freeze_streamed"
+        );
+
+        let combined_left_indices = if matched_chunks.len() == 1 {
+            matched_chunks[0].1.clone()
+        } else {
+            let refs: Vec<&dyn Array> =
+                matched_chunks.iter().map(|c| &c.1 as &dyn Array).collect();
+            as_uint64_array(&compute::concat(&refs)?)?.clone()
+        };
+
+        let left_columns =
+            materialize_left_columns(&self.streamed_batch.batch, 
&combined_left_indices)?;
 
-            let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), 
columns)?;
-            // Apply join filter if any
-            if !filter_columns.is_empty() {
-                if let Some(f) = &self.filter {
-                    // Construct batch with only filter columns
-                    let filter_batch =
-                        RecordBatch::try_new(Arc::clone(f.schema()), 
filter_columns)?;
-
-                    let filter_result = f
-                        .expression()
-                        .evaluate(&filter_batch)?
-                        .into_array(filter_batch.num_rows())?;
-
-                    // The boolean selection mask of the join filter result
-                    let pre_mask =
-                        
datafusion_common::cast::as_boolean_array(&filter_result)?;
-
-                    // If there are nulls in join filter result, exclude them 
from selecting
-                    // the rows to output.
-                    let mask = if pre_mask.null_count() > 0 {
-                        compute::prep_null_mask_filter(
-                            
datafusion_common::cast::as_boolean_array(&filter_result)?,
-                        )
+        let right_columns =
+            self.materialize_right_columns(matched_chunks, 
total_matched_rows)?;
+
+        let filter_columns = if self.join_type == JoinType::Right {
+            get_filter_columns(&self.filter, &right_columns, &left_columns)
+        } else {
+            get_filter_columns(&self.filter, &left_columns, &right_columns)
+        };
+
+        let columns = if self.join_type != JoinType::Right {
+            [left_columns, right_columns].concat()
+        } else {
+            [right_columns, left_columns].concat()
+        };
+        let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), 
columns)?;
+
+        if !filter_columns.is_empty() {
+            if let Some(f) = &self.filter {
+                let filter_batch =
+                    RecordBatch::try_new(Arc::clone(f.schema()), 
filter_columns)?;
+                let filter_result = f
+                    .expression()
+                    .evaluate(&filter_batch)?
+                    .into_array(filter_batch.num_rows())?;
+
+                let pre_mask = 
datafusion_common::cast::as_boolean_array(&filter_result)?;
+
+                let mask = if pre_mask.null_count() > 0 {
+                    compute::prep_null_mask_filter(pre_mask)
+                } else {
+                    pre_mask.clone()
+                };
+
+                if needs_deferred_filtering(&self.filter, self.join_type) {
+                    // Full join uses pre_mask (preserving nulls) for
+                    // get_corrected_filter_mask; other outer joins use mask.
+                    let mask_to_use = if self.join_type != JoinType::Full {
+                        &mask
                     } else {
-                        pre_mask.clone()
+                        pre_mask
                     };
 
-                    // Push the filtered batch which contains rows passing 
join filter to the output
-                    // For outer/semi/anti/mark joins with deferred filtering, 
push the unfiltered batch with metadata
-                    // For INNER joins, filter immediately and push without 
metadata
-                    let needs_deferred_filtering = matches!(
+                    self.joined_record_batches.push_batch_with_filter_metadata(
+                        output_batch,
+                        &combined_left_indices,
+                        mask_to_use,
+                        self.streamed_batch_counter.load(Relaxed),
                         self.join_type,
-                        JoinType::Left
-                            | JoinType::LeftSemi
-                            | JoinType::Right
-                            | JoinType::RightSemi
-                            | JoinType::LeftAnti
-                            | JoinType::RightAnti
-                            | JoinType::LeftMark
-                            | JoinType::RightMark
-                            | JoinType::Full
                     );
+                } else {
+                    let filtered_batch = filter_record_batch(&output_batch, 
&mask)?;
+                    self.joined_record_batches
+                        .push_batch_without_metadata(filtered_batch);
+                }
 
-                    if needs_deferred_filtering {
-                        // Outer/semi/anti/mark joins: push unfiltered batch 
with metadata for deferred filtering
-                        let mask_to_use = if self.join_type != JoinType::Full {
-                            &mask
-                        } else {
-                            pre_mask
-                        };
-
-                        
self.joined_record_batches.push_batch_with_filter_metadata(
-                            output_batch,
-                            &left_indices,
-                            mask_to_use,
-                            self.streamed_batch_counter.load(Relaxed),
-                            self.join_type,
-                        );
-                    } else {
-                        // INNER joins: filter immediately and push without 
metadata
-                        let filtered_batch = 
filter_record_batch(&output_batch, &mask)?;
-                        self.joined_record_batches
-                            .push_batch_without_metadata(filtered_batch, 
self.join_type);
-                    }
+                // Track which buffered rows had all filter matches fail,
+                // so full join can emit them as null-joined later.
+                if self.join_type == JoinType::Full {
+                    let mut offset = 0usize;
+                    for (batch_idx, _left, right) in matched_chunks {
+                        let chunk_len = right.len();
+                        let buffered_batch = &mut 
self.buffered_data.batches[*batch_idx];
 
-                    // For outer joins, we need to push the null joined rows 
to the output if
-                    // all joined rows are failed on the join filter.
-                    // I.e., if all rows joined from a streamed row are failed 
with the join filter,
-                    // we need to join it with nulls as buffered side.
-                    if self.join_type == JoinType::Full {
-                        let buffered_batch = &mut self.buffered_data.batches
-                            [chunk.buffered_batch_idx.unwrap()];
-
-                        for i in 0..pre_mask.len() {
-                            // If the buffered row is not joined with streamed 
side,
-                            // skip it.
-                            if right_indices.is_null(i) {
+                        for i in 0..chunk_len {
+                            if right.is_null(i) {
                                 continue;
                             }
-
-                            let buffered_index = right_indices.value(i);
-
+                            let buffered_index = right.value(i);
                             buffered_batch.join_filter_not_matched_map.insert(
                                 buffered_index,
                                 *buffered_batch
                                     .join_filter_not_matched_map
                                     .get(&buffered_index)
                                     .unwrap_or(&true)
-                                    && !pre_mask.value(i),
+                                    && !pre_mask.value(offset + i),
                             );
                         }
+                        offset += chunk_len;
                     }
+                    debug_assert_eq!(
+                        offset, total_matched_rows,
+                        "offset must advance through every chunk exactly once"
+                    );
                 }
+            }
+        } else {
+            self.joined_record_batches
+                .push_batch_without_metadata(output_batch);
+        }
+
+        Ok(())
+    }
+
+    /// Materializes right-side columns across all matched chunks.
+    ///
+    /// When chunks reference a single buffered batch, indices are concatenated
+    /// for a single fetch. When multiple batches are involved, `interleave`
+    /// gathers columns across sources. A null-row sentinel at source index 0
+    /// handles null right indices (unmatched streamed rows).
+    fn materialize_right_columns(
+        &self,
+        matched_chunks: &[(usize, UInt64Array, UInt64Array)],
+        total_matched_rows: usize,
+    ) -> Result<Vec<ArrayRef>> {
+        let first_batch_idx = matched_chunks[0].0;
+        let single_source = matched_chunks.iter().all(|c| c.0 == 
first_batch_idx);
+
+        if single_source {
+            let combined_right_indices = if matched_chunks.len() == 1 {
+                matched_chunks[0].2.clone()
             } else {
-                self.joined_record_batches
-                    .push_batch_without_metadata(output_batch, self.join_type);
+                let refs: Vec<&dyn Array> =
+                    matched_chunks.iter().map(|c| &c.2 as &dyn 
Array).collect();
+                as_uint64_array(&compute::concat(&refs)?)?.clone()
+            };
+            return fetch_right_columns_by_idxs(
+                &self.buffered_data,
+                first_batch_idx,
+                &combined_right_indices,
+            );
+        }
+
+        // Multiple source batches: map each buffered_batch_idx to a
+        // contiguous source index, reserving source 0 for a null sentinel.
+        let mut batch_idx_to_source: HashMap<usize, usize> = HashMap::new();
+        let mut source_batches: Vec<usize> = Vec::new();
+        for (batch_idx, _, _) in matched_chunks {
+            batch_idx_to_source.entry(*batch_idx).or_insert_with(|| {
+                let idx = source_batches.len() + 1;
+                source_batches.push(*batch_idx);
+                idx
+            });
+        }
+
+        let mut interleave_indices: Vec<(usize, usize)> =
+            Vec::with_capacity(total_matched_rows);
+        for (batch_idx, _, right) in matched_chunks {
+            let source = batch_idx_to_source[batch_idx];
+            for i in 0..right.len() {
+                if right.is_null(i) {
+                    interleave_indices.push((0, 0));
+                } else {
+                    interleave_indices.push((source, right.value(i) as usize));
+                }
             }
         }
 
-        self.streamed_batch.output_indices.clear();
-        self.streamed_batch.num_output_rows = 0;
+        let num_right_cols = self.buffered_schema.fields().len();
+        let mut right_columns = Vec::with_capacity(num_right_cols);
 
-        Ok(())
+        // Read each source batch once (spilled batches require disk I/O).
+        let source_data: Vec<Option<RecordBatch>> = source_batches
+            .iter()
+            .map(|&idx| {
+                let bb = &self.buffered_data.batches[idx];
+                match &bb.batch {
+                    BufferedBatchState::InMemory(batch) => Some(batch.clone()),
+                    BufferedBatchState::Spilled(spill_file) => {
+                        let file = 
BufReader::new(File::open(spill_file.path()).ok()?);
+                        let reader = StreamReader::try_new(file, None).ok()?;
+                        reader.into_iter().next()?.ok()
+                    }
+                }
+            })
+            .collect();
+
+        for col_idx in 0..num_right_cols {

Review Comment:
   This is similar in concept to `create_unmatched_batch` in `piecewise_merge_join`:
both produce null-padded rows for the unmatched side. Consolidating the two into a
shared utility is left as future work.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to