This is an automated email from the ASF dual-hosted git repository. xudong963 pushed a commit to branch fix/sort-merge-reservation-starvation in repository https://gitbox.apache.org/repos/asf/datafusion.git
commit 8327648ec8c073dd084d0a3baa269d4c778da975 Author: xudong.w <[email protected]> AuthorDate: Sun Mar 1 21:13:01 2026 +0100 try to fix the sort merge oom --- datafusion/physical-plan/src/sorts/builder.rs | 41 +++++++- .../physical-plan/src/sorts/multi_level_merge.rs | 83 +++++++++------ datafusion/physical-plan/src/sorts/sort.rs | 114 ++++++++++++++++++++- 3 files changed, 202 insertions(+), 36 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/builder.rs b/datafusion/physical-plan/src/sorts/builder.rs index 9b2fa96822..881b995cf9 100644 --- a/datafusion/physical-plan/src/sorts/builder.rs +++ b/datafusion/physical-plan/src/sorts/builder.rs @@ -40,9 +40,24 @@ pub struct BatchBuilder { /// Maintain a list of [`RecordBatch`] and their corresponding stream batches: Vec<(usize, RecordBatch)>, - /// Accounts for memory used by buffered batches + /// Accounts for memory used by buffered batches. + /// + /// May include pre-reserved bytes (from `sort_spill_reservation_bytes`) + /// that were transferred via [`MemoryReservation::take()`] to prevent + /// starvation when concurrent sort partitions compete for pool memory. reservation: MemoryReservation, + /// Tracks the actual memory used by buffered batches (not including + /// pre-reserved bytes). This allows [`Self::push_batch`] to skip pool + /// allocation requests when the pre-reserved bytes cover the batch. + batches_mem_used: usize, + + /// The initial reservation size at construction time. When the reservation + /// is pre-loaded with `sort_spill_reservation_bytes` (via `take()`), this + /// records that amount so we never shrink below it, maintaining the + /// anti-starvation guarantee throughout the merge. + initial_reservation: usize, + /// The current [`BatchCursor`] for each stream cursors: Vec<BatchCursor>, @@ -59,19 +74,29 @@ impl BatchBuilder { batch_size: usize, reservation: MemoryReservation, ) -> Self { + let initial_reservation = reservation.size(); Self { schema, batches: Vec::with_capacity(stream_count * 2), cursors: vec![BatchCursor::default(); stream_count], indices: Vec::with_capacity(batch_size), reservation, + batches_mem_used: 0, + initial_reservation, } } /// Append a new batch in `stream_idx` pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> { - self.reservation - .try_grow(get_record_batch_memory_size(&batch))?; + let size = get_record_batch_memory_size(&batch); + self.batches_mem_used += size; + // Only request additional memory from the pool when actual batch + // usage exceeds the current reservation (which may include + // pre-reserved bytes from sort_spill_reservation_bytes). + if self.batches_mem_used > self.reservation.size() { + self.reservation + .try_grow(self.batches_mem_used - self.reservation.size())?; + } let batch_idx = self.batches.len(); self.batches.push((stream_idx, batch)); self.cursors[stream_idx] = BatchCursor { @@ -143,11 +168,19 @@ impl BatchBuilder { stream_cursor.batch_idx = retained; retained += 1; } else { - self.reservation.shrink(get_record_batch_memory_size(batch)); + self.batches_mem_used -= get_record_batch_memory_size(batch); } retain }); + // Release excess memory back to the pool, but never shrink below + // initial_reservation to maintain the anti-starvation guarantee + // for the merge phase. + let target = self.batches_mem_used.max(self.initial_reservation); + if self.reservation.size() > target { + self.reservation.shrink(self.reservation.size() - target); + } + Ok(Some(RecordBatch::try_new( Arc::clone(&self.schema), columns, diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 6e7a5e7a72..25cc102681 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -253,7 +253,12 @@ impl MultiLevelMergeBuilder { // Need to merge multiple streams (_, _) => { - let mut memory_reservation = self.reservation.new_empty(); + // Transfer any pre-reserved bytes (from sort_spill_reservation_bytes) + // to the merge memory reservation. This prevents starvation when + // concurrent sort partitions compete for pool memory: the pre-reserved + // bytes cover spill file buffer reservations without additional pool + // allocation. + let mut memory_reservation = self.reservation.take(); // Don't account for existing streams memory // as we are not holding the memory for them @@ -333,8 +338,10 @@ impl MultiLevelMergeBuilder { builder = builder.with_bypass_mempool(); } else { // If we are only merging in-memory streams, we need to use the memory reservation - // because we don't know the maximum size of the batches in the streams - builder = builder.with_reservation(self.reservation.new_empty()); + // because we don't know the maximum size of the batches in the streams. + // Use take() to transfer any pre-reserved bytes so the merge can use them + // as its initial budget without additional pool allocation. + builder = builder.with_reservation(self.reservation.take()); } builder.build() @@ -352,41 +359,55 @@ impl MultiLevelMergeBuilder { ) -> Result<(Vec<SortedSpillFile>, usize)> { assert_ne!(buffer_len, 0, "Buffer length must be greater than 0"); let mut number_of_spills_to_read_for_current_phase = 0; + // Track total memory needed for spill file buffers. When the + // reservation has pre-reserved bytes (from sort_spill_reservation_bytes), + // those bytes cover the first N spill files without additional pool + // allocation, preventing starvation under memory pressure. + let mut total_needed: usize = 0; for spill in &self.sorted_spill_files { - // For memory pools that are not shared this is good, for other this is not - // and there should be some upper limit to memory reservation so we won't starve the system - match reservation.try_grow(get_reserved_byte_for_record_batch_size( + let per_spill = get_reserved_byte_for_record_batch_size( spill.max_record_batch_memory * buffer_len, - )) { - Ok(_) => { - number_of_spills_to_read_for_current_phase += 1; - } - // If we can't grow the reservation, we need to stop - Err(err) => { - // We must have at least 2 streams to merge, so if we don't have enough memory - // fail - if minimum_number_of_required_streams - > number_of_spills_to_read_for_current_phase - { - // Free the memory we reserved for this merge as we either try again or fail - reservation.free(); - if buffer_len > 1 { - // Try again with smaller buffer size, it will be slower but at least we can merge - return self.get_sorted_spill_files_to_merge( - buffer_len - 1, - minimum_number_of_required_streams, - reservation, - ); + ); + total_needed += per_spill; + + // Only request additional memory from the pool when total needed + // exceeds what's already reserved (which may include pre-reserved + // bytes from sort_spill_reservation_bytes). + if total_needed > reservation.size() { + match reservation.try_grow(total_needed - reservation.size()) { + Ok(_) => { + number_of_spills_to_read_for_current_phase += 1; + } + // If we can't grow the reservation, we need to stop + Err(err) => { + // We must have at least 2 streams to merge, so if we don't have enough memory + // fail + if minimum_number_of_required_streams + > number_of_spills_to_read_for_current_phase + { + // Free the memory we reserved for this merge as we either try again or fail + reservation.free(); + if buffer_len > 1 { + // Try again with smaller buffer size, it will be slower but at least we can merge + return self.get_sorted_spill_files_to_merge( + buffer_len - 1, + minimum_number_of_required_streams, + reservation, + ); + } + + return Err(err); } - return Err(err); + // We reached the maximum amount of memory we can use + // for this merge + break; } - - // We reached the maximum amount of memory we can use - // for this merge - break; } + } else { + // Pre-reserved bytes cover this spill file's buffer + number_of_spills_to_read_for_current_phase += 1; } } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 19239f60cd..3e200609cf 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -354,6 +354,13 @@ impl ExternalSorter { self.sort_and_spill_in_mem_batches().await?; } + // Transfer the pre-reserved merge memory to the streaming merge + // using `take()` instead of `new_empty()`. This ensures the merge + // stream starts with `sort_spill_reservation_bytes` already + // allocated, preventing starvation when concurrent sort partitions + // compete for pool memory. `take()` moves the bytes atomically + // without releasing them back to the pool, so other partitions + // cannot race to consume the freed memory. StreamingMergeBuilder::new() .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files)) .with_spill_manager(self.spill_manager.clone()) @@ -362,7 +369,7 @@ impl ExternalSorter { .with_metrics(self.metrics.baseline.clone()) .with_batch_size(self.batch_size) .with_fetch(None) - .with_reservation(self.merge_reservation.new_empty()) + .with_reservation(self.merge_reservation.take()) .build() } else { self.in_mem_sort_stream(self.metrics.baseline.clone()) @@ -2407,4 +2414,109 @@ mod tests { Ok((sorted_batches, metrics)) } + + /// Test that concurrent sort partitions sharing a tight memory pool + /// don't starve during the merge phase. + /// + /// This reproduces the starvation scenario where: + /// 1. Multiple ExternalSorter instances share a single GreedyMemoryPool + /// 2. Each reserves `sort_spill_reservation_bytes` for its merge phase + /// 3. After spilling, the merge must proceed using the pre-reserved bytes + /// without additional pool allocation + /// + /// Without the fix (using `take()` + smart tracking), the merge's + /// `new_empty()` reservation starts at 0 bytes and the pre-reserved bytes + /// sit unused in ExternalSorter's merge_reservation. When other partitions + /// consume the freed memory, the merge starves. + /// + /// With the fix, the pre-reserved bytes are atomically transferred to the + /// merge stream and used for spill file buffer reservations, preventing + /// starvation. + #[tokio::test] + async fn test_sort_merge_no_starvation_with_concurrent_partitions() -> Result<()> { + use futures::TryStreamExt; + + let sort_spill_reservation_bytes: usize = 10 * 1024; // 10 KB per partition + let num_partitions: usize = 4; + + // Pool: each partition needs sort_spill_reservation_bytes for its merge, + // plus a small amount for data accumulation before spilling. + // Total: 4 * 10KB + 8KB = 48KB -- very tight. + let memory_limit = + sort_spill_reservation_bytes * num_partitions + 8 * 1024; + + let session_config = SessionConfig::new() + .with_batch_size(128) + .with_sort_spill_reservation_bytes(sort_spill_reservation_bytes); + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .build_arc()?; + + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + // Create multiple batches per partition to force spilling. + // Each batch: 100 rows of Int32 ≈ 400 bytes. + // 20 batches per partition ≈ 8KB per partition. + // With only ~2KB of pool headroom per partition, this forces spilling. + let batches_per_partition = 20; + let rows_per_batch: i32 = 100; + + let all_partitions: Vec<Vec<RecordBatch>> = (0..num_partitions) + .map(|_| { + (0..batches_per_partition) + .map(|_| make_partition(rows_per_batch)) + .collect() + }) + .collect(); + + let schema = all_partitions[0][0].schema(); + let input = TestMemoryExec::try_new_exec(&all_partitions, schema.clone(), None)?; + + let sort_exec = Arc::new( + SortExec::new( + [PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }] + .into(), + input, + ) + .with_preserve_partitioning(true), + ); + + // Execute all partitions concurrently -- they share the same pool. + let mut tasks = Vec::new(); + for partition in 0..num_partitions { + let sort = Arc::clone(&sort_exec); + let ctx = Arc::clone(&task_ctx); + tasks.push(tokio::spawn(async move { + let stream = sort.execute(partition, ctx)?; + let batches: Vec<RecordBatch> = stream.try_collect().await?; + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + Ok::<usize, DataFusionError>(total_rows) + })); + } + + let mut total_rows = 0; + for task in tasks { + total_rows += task.await.unwrap()?; + } + + let expected_rows = + num_partitions * batches_per_partition * (rows_per_batch as usize); + assert_eq!(total_rows, expected_rows); + + assert_eq!( + task_ctx.runtime_env().memory_pool.reserved(), + 0, + "All memory should be returned to the pool after sort completes" + ); + + Ok(()) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
