mbutrovich commented on code in PR #20806: URL: https://github.com/apache/datafusion/pull/20806#discussion_r2907668997
########## datafusion/physical-plan/src/joins/semi_anti_sort_merge_join/stream.rs: ########## @@ -0,0 +1,1218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stream implementation for semi/anti sort-merge joins. + +use std::cmp::Ordering; +use std::fs::File; +use std::io::BufReader; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::RecordBatchStream; +use crate::joins::utils::{JoinFilter, compare_join_arrays}; +use crate::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, +}; +use crate::spill::spill_manager::SpillManager; +use arrow::array::{Array, ArrayRef, BooleanArray, BooleanBufferBuilder, RecordBatch}; +use arrow::compute::{BatchCoalescer, SortOptions, filter_record_batch, not}; +use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::StreamReader; +use arrow::util::bit_chunk_iterator::UnalignedBitChunk; +use arrow::util::bit_util::apply_bitwise_binary_op; +use datafusion_common::{ + JoinSide, JoinType, NullEquality, Result, ScalarValue, internal_err, +}; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; 
+use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; + +use futures::{Stream, StreamExt, ready}; + +/// Evaluates join key expressions against a batch, returning one array per key. +fn evaluate_join_keys( + batch: &RecordBatch, + on: &[PhysicalExprRef], +) -> Result<Vec<ArrayRef>> { + on.iter() + .map(|expr| { + let num_rows = batch.num_rows(); + let val = expr.evaluate(batch)?; + val.into_array(num_rows) + }) + .collect() +} + +/// Find the first index in `key_arrays` starting from `from` where the key +/// differs from the key at `from`. Uses `compare_join_arrays` for zero-alloc +/// ordinal comparison. +/// +/// Optimized for join workloads: checks adjacent and boundary keys before +/// falling back to binary search, since most key groups are small (often 1). +fn find_key_group_end( + key_arrays: &[ArrayRef], + from: usize, + len: usize, + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result<usize> { + let next = from + 1; + if next >= len { + return Ok(len); + } + + // Fast path: single-row group (common with unique keys). + if compare_join_arrays( + key_arrays, + from, + key_arrays, + next, + sort_options, + null_equality, + )? != Ordering::Equal + { + return Ok(next); + } + + // Check if the entire remaining batch shares this key. + let last = len - 1; + if compare_join_arrays( + key_arrays, + from, + key_arrays, + last, + sort_options, + null_equality, + )? == Ordering::Equal + { + return Ok(len); + } + + // Binary search the interior: key at `next` matches, key at `last` doesn't. + let mut lo = next + 1; + let mut hi = last; + while lo < hi { + let mid = lo + (hi - lo) / 2; + if compare_join_arrays( + key_arrays, + from, + key_arrays, + mid, + sort_options, + null_equality, + )? == Ordering::Equal + { + lo = mid + 1; + } else { + hi = mid; + } + } + Ok(lo) +} + +/// Tracks whether we're mid-key-group when `poll_next_outer_batch` returns +/// `Poll::Pending` inside the Equal branch's boundary loop. 
+/// +/// When an outer key group spans a batch boundary, the boundary loop emits +/// the current batch, then polls for the next. If that poll returns Pending, +/// `ready!` exits `poll_join` and we re-enter from the top on the next call. +/// Without this state, the new batch would be processed fresh by the +/// merge-scan — but inner already advanced past this key, so the matching +/// outer rows would be skipped via `Ordering::Less` and never marked. +/// +/// This enum saves the context needed to resume the boundary loop on +/// re-entry: compare the new batch's first key with the saved key to +/// decide whether to continue marking or move on. +#[derive(Debug)] +enum BoundaryState { + /// Normal processing — not inside a boundary poll. + Normal, + /// The no-filter boundary loop's `poll_next_outer_batch` returned + /// Pending. Carries the key arrays and index from the last emitted + /// batch so we can compare with the next batch's first key. + NoFilterPending { + saved_keys: Vec<ArrayRef>, + saved_idx: usize, + }, + /// The filtered boundary loop's `poll_next_outer_batch` returned + /// Pending. Carries the key arrays and index from the last emitted + /// outer batch so we can compare with the next batch's first key + /// without reading from the inner key buffer (which may have been + /// spilled to disk). + FilteredPending { + saved_keys: Vec<ArrayRef>, + saved_idx: usize, + }, +} + +pub(super) struct SemiAntiSortMergeJoinStream { + /// true for semi (emit matched), false for anti (emit unmatched) + is_semi: bool, + + // Input streams — "outer" is the streamed side whose rows we output, + // "inner" is the buffered side we match against. 
+ outer: SendableRecordBatchStream, + inner: SendableRecordBatchStream, + + // Current batches and cursor positions + outer_batch: Option<RecordBatch>, + outer_offset: usize, + outer_key_arrays: Vec<ArrayRef>, + inner_batch: Option<RecordBatch>, + inner_offset: usize, + inner_key_arrays: Vec<ArrayRef>, + + // Per-outer-batch match tracking, reused across batches. + // Bit-packed (not Vec<bool>) so that: + // - emit: finish() yields a BooleanBuffer directly (no packing iteration) + // - OR: apply_bitwise_binary_op ORs filter results in u64 chunks + // - count: UnalignedBitChunk::count_ones uses popcnt + matched: BooleanBufferBuilder, + + // Inner key group buffer: all inner rows sharing the current join key. + // Only populated when a filter is present. Unbounded — a single key + // with many inner rows will buffer them all. See "Degenerate cases" + // in exec.rs. Spilled to disk when memory reservation fails. + inner_key_buffer: Vec<RecordBatch>, + inner_key_spill: Option<RefCountedTempFile>, + + // True when buffer_inner_key_group returned Pending after partially + // filling inner_key_buffer. On re-entry, buffer_inner_key_group + // must skip clear() and resume from poll_next_inner_batch (the + // current inner_batch was already sliced and pushed before Pending). + buffering_inner_pending: bool, + + // Boundary re-entry state — see BoundaryState doc comment. + boundary_state: BoundaryState, + + // Join condition + on_outer: Vec<PhysicalExprRef>, + on_inner: Vec<PhysicalExprRef>, + filter: Option<JoinFilter>, + sort_options: Vec<SortOptions>, + null_equality: NullEquality, + // When join_type is RightSemi/RightAnti, outer=right, inner=left, + // so we need to swap sides when building the filter batch. 
+ outer_is_left: bool, + + // Output + coalescer: BatchCoalescer, + schema: SchemaRef, + + // Metrics + join_time: crate::metrics::Time, + input_batches: Count, + input_rows: Count, + baseline_metrics: BaselineMetrics, + peak_mem_used: Gauge, + + // Memory / spill — only the inner key buffer is tracked via reservation, + // matching existing SMJ (which tracks only the buffered side). The outer + // batch is a single batch at a time and cannot be spilled. + reservation: MemoryReservation, + spill_manager: SpillManager, + runtime_env: Arc<datafusion_execution::runtime_env::RuntimeEnv>, + inner_buffer_size: usize, + + // True once the current outer batch has been emitted. The Equal + // branch's inner loops call emit then `ready!(poll_next_outer_batch)`. + // If that poll returns Pending, poll_join re-enters from the top + // on the next poll — with outer_batch still Some and outer_offset + // past the end. The main loop's step 3 would re-emit without this + // guard. Cleared when poll_next_outer_batch loads a new batch. 
+ batch_emitted: bool, +} + +impl SemiAntiSortMergeJoinStream { + #[expect(clippy::too_many_arguments)] + pub fn try_new( + schema: SchemaRef, + sort_options: Vec<SortOptions>, + null_equality: NullEquality, + outer: SendableRecordBatchStream, + inner: SendableRecordBatchStream, + on_outer: Vec<PhysicalExprRef>, + on_inner: Vec<PhysicalExprRef>, + filter: Option<JoinFilter>, + join_type: JoinType, + batch_size: usize, + partition: usize, + metrics: &ExecutionPlanMetricsSet, + reservation: MemoryReservation, + peak_mem_used: Gauge, + spill_manager: SpillManager, + runtime_env: Arc<datafusion_execution::runtime_env::RuntimeEnv>, + ) -> Result<Self> { + let is_semi = matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi); + let outer_is_left = matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti); + + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let baseline_metrics = BaselineMetrics::new(metrics, partition); + + Ok(Self { + is_semi, + outer, + inner, + outer_batch: None, + outer_offset: 0, + outer_key_arrays: vec![], + inner_batch: None, + inner_offset: 0, + inner_key_arrays: vec![], + matched: BooleanBufferBuilder::new(0), + inner_key_buffer: vec![], + inner_key_spill: None, + buffering_inner_pending: false, + boundary_state: BoundaryState::Normal, + on_outer, + on_inner, + filter, + sort_options, + null_equality, + outer_is_left, + coalescer: BatchCoalescer::new(Arc::clone(&schema), batch_size) + .with_biggest_coalesce_batch_size(Some(batch_size / 2)), + schema, + join_time, + input_batches, + input_rows, + baseline_metrics, + peak_mem_used, + reservation, + spill_manager, + runtime_env, + inner_buffer_size: 0, + batch_emitted: false, + }) + } + + /// Resize the memory reservation to match current tracked usage. 
+ fn try_resize_reservation(&mut self) -> Result<()> { + let needed = self.inner_buffer_size; + self.reservation.try_resize(needed)?; + self.peak_mem_used.set_max(self.reservation.size()); + Ok(()) + } + + /// Spill the in-memory inner key buffer to disk and clear it. + fn spill_inner_key_buffer(&mut self) -> Result<()> { + let spill_file = self + .spill_manager + .spill_record_batch_and_finish( + &self.inner_key_buffer, + "semi_anti_smj_inner_key_spill", + )? + .expect("inner_key_buffer is non-empty when spilling"); + self.inner_key_buffer.clear(); + self.inner_buffer_size = 0; + self.inner_key_spill = Some(spill_file); + // Should succeed now — inner buffer has been spilled. + self.try_resize_reservation() + } + + /// Clear inner key group state after processing. Does not resize the + /// reservation — the next key group will resize when buffering, or + /// the stream's Drop will free it. This avoids unnecessary memory + /// pool interactions (see apache/datafusion#20729). + fn clear_inner_key_group(&mut self) { + self.inner_key_buffer.clear(); + self.inner_key_spill = None; + self.inner_buffer_size = 0; + } + + /// Poll for the next outer batch. Returns true if a batch was loaded. + fn poll_next_outer_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool>> { + loop { + match ready!(self.outer.poll_next_unpin(cx)) { + None => return Poll::Ready(Ok(false)), + Some(Err(e)) => return Poll::Ready(Err(e)), + Some(Ok(batch)) => { + self.input_batches.add(1); + self.input_rows.add(batch.num_rows()); + if batch.num_rows() == 0 { + continue; + } + let keys = evaluate_join_keys(&batch, &self.on_outer)?; + let num_rows = batch.num_rows(); + self.outer_batch = Some(batch); + self.outer_offset = 0; + self.outer_key_arrays = keys; + self.batch_emitted = false; + self.matched = BooleanBufferBuilder::new(num_rows); + self.matched.append_n(num_rows, false); + return Poll::Ready(Ok(true)); + } + } + } + } + + /// Poll for the next inner batch. 
Returns true if a batch was loaded. + fn poll_next_inner_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool>> { + loop { + match ready!(self.inner.poll_next_unpin(cx)) { + None => return Poll::Ready(Ok(false)), + Some(Err(e)) => return Poll::Ready(Err(e)), + Some(Ok(batch)) => { + self.input_batches.add(1); + self.input_rows.add(batch.num_rows()); + if batch.num_rows() == 0 { + continue; + } + let keys = evaluate_join_keys(&batch, &self.on_inner)?; + self.inner_batch = Some(batch); + self.inner_offset = 0; + self.inner_key_arrays = keys; + return Poll::Ready(Ok(true)); + } + } + } + } + + /// Emit the current outer batch through the coalescer, applying the + /// matched bitset as a selection mask. No-op if already emitted + /// (see `batch_emitted` field). + fn emit_outer_batch(&mut self) -> Result<()> { + if self.batch_emitted { + return Ok(()); + } + self.batch_emitted = true; + + let batch = self.outer_batch.as_ref().unwrap(); + + // finish() converts the bit-packed builder directly to a + // BooleanBuffer — no iteration or repacking needed. + let selection = BooleanArray::new(self.matched.finish(), None); + + let selection = if self.is_semi { + selection + } else { + not(&selection)? + }; + + let filtered = filter_record_batch(batch, &selection)?; + if filtered.num_rows() > 0 { + self.coalescer.push_batch(filtered)?; + } + Ok(()) + } + + /// Process a key match between outer and inner sides (no filter). + /// Sets matched bits for all outer rows sharing the current key. 
+ fn process_key_match_no_filter(&mut self) -> Result<()> { + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + let outer_group_end = find_key_group_end( + &self.outer_key_arrays, + self.outer_offset, + num_outer, + &self.sort_options, + self.null_equality, + )?; + + for i in self.outer_offset..outer_group_end { + self.matched.set_bit(i, true); + } + + self.outer_offset = outer_group_end; + Ok(()) + } + + /// Advance inner past the current key group. Returns Ok(true) if inner + /// is exhausted. + fn advance_inner_past_key_group( + &mut self, + cx: &mut Context<'_>, + ) -> Poll<Result<bool>> { + loop { + let inner_batch = match &self.inner_batch { + Some(b) => b, + None => return Poll::Ready(Ok(true)), + }; + let num_inner = inner_batch.num_rows(); + + let group_end = find_key_group_end( + &self.inner_key_arrays, + self.inner_offset, + num_inner, + &self.sort_options, + self.null_equality, + )?; + + if group_end < num_inner { + self.inner_offset = group_end; + return Poll::Ready(Ok(false)); + } + + // Key group extends to end of batch — need to check next batch + let last_key_idx = num_inner - 1; + let saved_inner_keys = self.inner_key_arrays.clone(); + + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + return Poll::Ready(Ok(true)); + } + Ok(true) => { + if keys_match( + &saved_inner_keys, + last_key_idx, + &self.inner_key_arrays, + 0, + &self.sort_options, + self.null_equality, + )? { + continue; + } else { + return Poll::Ready(Ok(false)); + } + } + } + } + } + + /// Buffer inner key group for filter evaluation. Collects all inner rows + /// with the current key across batch boundaries. + /// + /// If poll_next_inner_batch returns Pending, we save progress via + /// buffering_inner_pending. 
On re-entry (from the Equal branch in + /// poll_join), we skip clear() and the slice+push for the current + /// batch (which was already buffered before Pending), and go directly + /// to polling for the next inner batch. + fn buffer_inner_key_group(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool>> { + // On re-entry after Pending: don't clear the partially-filled + // buffer. The current inner_batch was already sliced and pushed + // before Pending, so jump to polling for the next batch. + let mut resume_from_poll = false; + if self.buffering_inner_pending { + self.buffering_inner_pending = false; + resume_from_poll = true; + } else { + self.clear_inner_key_group(); + } + + loop { + let inner_batch = match &self.inner_batch { + Some(b) => b, + None => return Poll::Ready(Ok(true)), + }; + let num_inner = inner_batch.num_rows(); + let group_end = find_key_group_end( + &self.inner_key_arrays, + self.inner_offset, + num_inner, + &self.sort_options, + self.null_equality, + )?; + + if !resume_from_poll { + let slice = + inner_batch.slice(self.inner_offset, group_end - self.inner_offset); + self.inner_buffer_size += slice.get_array_memory_size(); + self.inner_key_buffer.push(slice); + + // Reserve memory for the newly buffered slice. If the pool + // is exhausted, spill the entire buffer to disk. + if self.try_resize_reservation().is_err() { + if self.runtime_env.disk_manager.tmp_files_enabled() { + self.spill_inner_key_buffer()?; + } else { + // Re-attempt to get the error message + self.try_resize_reservation().map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "{e}. Disk spilling disabled." 
+ )) + })?; + } + } + + if group_end < num_inner { + self.inner_offset = group_end; + return Poll::Ready(Ok(false)); + } + } + resume_from_poll = false; + + // Key group extends to end of batch — check next + let last_key_idx = num_inner - 1; + let saved_inner_keys = self.inner_key_arrays.clone(); + + // If poll returns Pending, the current batch is already + // in inner_key_buffer. + self.buffering_inner_pending = true; + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => { + self.buffering_inner_pending = false; + return Poll::Ready(Err(e)); + } + Ok(false) => { + self.buffering_inner_pending = false; + return Poll::Ready(Ok(true)); + } + Ok(true) => { + self.buffering_inner_pending = false; + if keys_match( + &saved_inner_keys, + last_key_idx, + &self.inner_key_arrays, + 0, + &self.sort_options, + self.null_equality, + )? { + continue; + } else { + return Poll::Ready(Ok(false)); + } + } + } + } + } + + /// Process a key match with a filter. For each inner row in the buffered + /// key group, evaluates the filter against the outer key group and ORs + /// the results into the matched bitset using u64-chunked bitwise ops. 
+ fn process_key_match_with_filter(&mut self) -> Result<()> { + let filter = self.filter.as_ref().unwrap(); + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + // buffer_inner_key_group must be called before this function + debug_assert!( + !self.inner_key_buffer.is_empty() || self.inner_key_spill.is_some(), + "process_key_match_with_filter called with no inner key data" + ); + debug_assert!( + self.outer_offset < num_outer, + "outer_offset must be within the current batch" + ); + debug_assert!( + self.matched.len() == num_outer, + "matched vector must be sized for the current outer batch" + ); + + let outer_group_end = find_key_group_end( + &self.outer_key_arrays, + self.outer_offset, + num_outer, + &self.sort_options, + self.null_equality, + )?; + let outer_group_len = outer_group_end - self.outer_offset; + let outer_slice = outer_batch.slice(self.outer_offset, outer_group_len); + + // Count already-matched bits using popcnt on u64 chunks (zero-copy). + let mut matched_count = UnalignedBitChunk::new( + self.matched.as_slice(), + self.outer_offset, + outer_group_len, + ) + .count_ones(); + + // Process spilled inner batches first (read back from disk). + if let Some(spill_file) = &self.inner_key_spill { + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + for batch_result in reader { + let inner_slice = batch_result?; + matched_count = eval_filter_for_inner_slice( + self.outer_is_left, + filter, + &outer_slice, + &inner_slice, + &mut self.matched, + self.outer_offset, + outer_group_len, + matched_count, + )?; + if matched_count == outer_group_len { + break; + } + } + } + + // Then process in-memory inner batches. + // evaluate_filter_for_inner_row is a free function (not &self method) + // so that Rust can split the struct borrow: &mut self.matched coexists + // with &self.inner_key_buffer and &self.filter inside this loop. 
+ if matched_count < outer_group_len { + 'outer: for inner_slice in &self.inner_key_buffer { + matched_count = eval_filter_for_inner_slice( + self.outer_is_left, + filter, + &outer_slice, + inner_slice, + &mut self.matched, + self.outer_offset, + outer_group_len, + matched_count, + )?; + if matched_count == outer_group_len { + break 'outer; + } + } + } + + self.outer_offset = outer_group_end; + Ok(()) + } + + /// Continue processing an outer key group that spans multiple outer + /// batches. Returns `true` if this outer batch was fully consumed + /// by the key group and the caller should load another. + fn resume_boundary(&mut self) -> Result<bool> { + debug_assert!( + self.outer_batch.is_some(), + "caller must load outer_batch first" + ); + match std::mem::replace(&mut self.boundary_state, BoundaryState::Normal) { Review Comment: Yeah, the `Normal` in `std::mem::replace` is just a placeholder. However, at lines 952 and 1017, `Normal` is also the correct final state (we consumed the pending data). At line 687, it's either correct (for the `Normal` and non-boundary-continuing cases) or gets overwritten at lines 706/738. I could add a comment on each replace call to clarify the intent, but I'd avoid adding a `TmpTransition` variant — it adds a match arm everywhere, turns a compile-time guarantee into a runtime panic, and the code paths already ensure `Normal` is correct. Alternatively, since `boundary_state` is only read via `match`, we could wrap it in `Option<BoundaryState>` and `.take()`. But that adds `Option` unwrapping everywhere else. I want to make sure this is maintainable for other folks, but I don't love the additional state. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
