mbutrovich commented on code in PR #18875:
URL: https://github.com/apache/datafusion/pull/18875#discussion_r2589507334
##########
datafusion/physical-plan/src/joins/sort_merge_join/stream.rs:
##########
@@ -347,26 +352,199 @@ pub(super) struct SortMergeJoinStream {
pub streamed_batch_counter: AtomicUsize,
}
-/// Joined batches with attached join filter information
+/// Staging area for joined data before output
+///
+/// Accumulates joined rows until either:
+/// - Target batch size reached (for efficiency)
+/// - Stream exhausted (flush remaining data)
pub(super) struct JoinedRecordBatches {
/// Joined batches. Each batch is already joined columns from left and
right sources
- pub batches: Vec<RecordBatch>,
- /// Filter match mask for each row(matched/non-matched)
- pub filter_mask: BooleanBuilder,
- /// Left row indices to glue together rows in `batches` and `filter_mask`
- pub row_indices: UInt64Builder,
- /// Which unique batch id the row belongs to
- /// It is necessary to differentiate rows that are distributed the way
when they point to the same
- /// row index but in not the same batches
- pub batch_ids: Vec<usize>,
+ pub(super) joined_batches: BatchCoalescer,
+ /// Did each output row pass the join filter? (detect if input row found
any match)
+ pub(super) filter_mask: BooleanBuilder,
+ /// Which input row (within batch) produced each output row? (for grouping
by input row)
+ pub(super) row_indices: UInt64Builder,
+ /// Which input batch did each output row come from? (disambiguate
row_indices)
+ pub(super) batch_ids: Vec<usize>,
}
impl JoinedRecordBatches {
- fn clear(&mut self) {
- self.batches.clear();
+ /// Concatenates all accumulated batches into a single RecordBatch
+ ///
+ /// Must drain ALL batches from BatchCoalescer for filtered joins to ensure
+ /// metadata alignment when applying get_corrected_filter_mask().
+ pub(super) fn concat_batches(&mut self, schema: &SchemaRef) ->
Result<RecordBatch> {
+ self.joined_batches.finish_buffered_batch()?;
+
+ let mut all_batches = vec![];
+ while let Some(batch) = self.joined_batches.next_completed_batch() {
+ all_batches.push(batch);
+ }
+
+ match all_batches.len() {
Review Comment:
Done, thanks!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]