This is an automated email from the ASF dual-hosted git repository.

xudong963 pushed a commit to branch fix/sort-merge-reservation-starvation
in repository https://gitbox.apache.org/repos/asf/datafusion.git

commit 8327648ec8c073dd084d0a3baa269d4c778da975
Author: xudong.w <[email protected]>
AuthorDate: Sun Mar 1 21:13:01 2026 +0100

    Try to fix the sort-merge OOM (memory-pool starvation during spill merge)
---
 datafusion/physical-plan/src/sorts/builder.rs      |  41 +++++++-
 .../physical-plan/src/sorts/multi_level_merge.rs   |  83 +++++++++------
 datafusion/physical-plan/src/sorts/sort.rs         | 114 ++++++++++++++++++++-
 3 files changed, 202 insertions(+), 36 deletions(-)

diff --git a/datafusion/physical-plan/src/sorts/builder.rs 
b/datafusion/physical-plan/src/sorts/builder.rs
index 9b2fa96822..881b995cf9 100644
--- a/datafusion/physical-plan/src/sorts/builder.rs
+++ b/datafusion/physical-plan/src/sorts/builder.rs
@@ -40,9 +40,24 @@ pub struct BatchBuilder {
     /// Maintain a list of [`RecordBatch`] and their corresponding stream
     batches: Vec<(usize, RecordBatch)>,
 
-    /// Accounts for memory used by buffered batches
+    /// Accounts for memory used by buffered batches.
+    ///
+    /// May include pre-reserved bytes (from `sort_spill_reservation_bytes`)
+    /// that were transferred via [`MemoryReservation::take()`] to prevent
+    /// starvation when concurrent sort partitions compete for pool memory.
     reservation: MemoryReservation,
 
+    /// Tracks the actual memory used by buffered batches (not including
+    /// pre-reserved bytes). This allows [`Self::push_batch`] to skip pool
+    /// allocation requests when the pre-reserved bytes cover the batch.
+    batches_mem_used: usize,
+
+    /// The initial reservation size at construction time. When the reservation
+    /// is pre-loaded with `sort_spill_reservation_bytes` (via `take()`), this
+    /// records that amount so we never shrink below it, maintaining the
+    /// anti-starvation guarantee throughout the merge.
+    initial_reservation: usize,
+
     /// The current [`BatchCursor`] for each stream
     cursors: Vec<BatchCursor>,
 
@@ -59,19 +74,29 @@ impl BatchBuilder {
         batch_size: usize,
         reservation: MemoryReservation,
     ) -> Self {
+        let initial_reservation = reservation.size();
         Self {
             schema,
             batches: Vec::with_capacity(stream_count * 2),
             cursors: vec![BatchCursor::default(); stream_count],
             indices: Vec::with_capacity(batch_size),
             reservation,
+            batches_mem_used: 0,
+            initial_reservation,
         }
     }
 
     /// Append a new batch in `stream_idx`
     pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> 
Result<()> {
-        self.reservation
-            .try_grow(get_record_batch_memory_size(&batch))?;
+        let size = get_record_batch_memory_size(&batch);
+        self.batches_mem_used += size;
+        // Only request additional memory from the pool when actual batch
+        // usage exceeds the current reservation (which may include
+        // pre-reserved bytes from sort_spill_reservation_bytes).
+        if self.batches_mem_used > self.reservation.size() {
+            self.reservation
+                .try_grow(self.batches_mem_used - self.reservation.size())?;
+        }
         let batch_idx = self.batches.len();
         self.batches.push((stream_idx, batch));
         self.cursors[stream_idx] = BatchCursor {
@@ -143,11 +168,19 @@ impl BatchBuilder {
                 stream_cursor.batch_idx = retained;
                 retained += 1;
             } else {
-                self.reservation.shrink(get_record_batch_memory_size(batch));
+                self.batches_mem_used -= get_record_batch_memory_size(batch);
             }
             retain
         });
 
+        // Release excess memory back to the pool, but never shrink below
+        // initial_reservation to maintain the anti-starvation guarantee
+        // for the merge phase.
+        let target = self.batches_mem_used.max(self.initial_reservation);
+        if self.reservation.size() > target {
+            self.reservation.shrink(self.reservation.size() - target);
+        }
+
         Ok(Some(RecordBatch::try_new(
             Arc::clone(&self.schema),
             columns,
diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs 
b/datafusion/physical-plan/src/sorts/multi_level_merge.rs
index 6e7a5e7a72..25cc102681 100644
--- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs
+++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs
@@ -253,7 +253,12 @@ impl MultiLevelMergeBuilder {
 
             // Need to merge multiple streams
             (_, _) => {
-                let mut memory_reservation = self.reservation.new_empty();
+                // Transfer any pre-reserved bytes (from 
sort_spill_reservation_bytes)
+                // to the merge memory reservation. This prevents starvation 
when
+                // concurrent sort partitions compete for pool memory: the 
pre-reserved
+                // bytes cover spill file buffer reservations without 
additional pool
+                // allocation.
+                let mut memory_reservation = self.reservation.take();
 
                 // Don't account for existing streams memory
                 // as we are not holding the memory for them
@@ -333,8 +338,10 @@ impl MultiLevelMergeBuilder {
             builder = builder.with_bypass_mempool();
         } else {
             // If we are only merging in-memory streams, we need to use the 
memory reservation
-            // because we don't know the maximum size of the batches in the 
streams
-            builder = builder.with_reservation(self.reservation.new_empty());
+            // because we don't know the maximum size of the batches in the 
streams.
+            // Use take() to transfer any pre-reserved bytes so the merge can 
use them
+            // as its initial budget without additional pool allocation.
+            builder = builder.with_reservation(self.reservation.take());
         }
 
         builder.build()
@@ -352,41 +359,55 @@ impl MultiLevelMergeBuilder {
     ) -> Result<(Vec<SortedSpillFile>, usize)> {
         assert_ne!(buffer_len, 0, "Buffer length must be greater than 0");
         let mut number_of_spills_to_read_for_current_phase = 0;
+        // Track total memory needed for spill file buffers. When the
+        // reservation has pre-reserved bytes (from 
sort_spill_reservation_bytes),
+        // those bytes cover the first N spill files without additional pool
+        // allocation, preventing starvation under memory pressure.
+        let mut total_needed: usize = 0;
 
         for spill in &self.sorted_spill_files {
-            // For memory pools that are not shared this is good, for other 
this is not
-            // and there should be some upper limit to memory reservation so 
we won't starve the system
-            match reservation.try_grow(get_reserved_byte_for_record_batch_size(
+            let per_spill = get_reserved_byte_for_record_batch_size(
                 spill.max_record_batch_memory * buffer_len,
-            )) {
-                Ok(_) => {
-                    number_of_spills_to_read_for_current_phase += 1;
-                }
-                // If we can't grow the reservation, we need to stop
-                Err(err) => {
-                    // We must have at least 2 streams to merge, so if we 
don't have enough memory
-                    // fail
-                    if minimum_number_of_required_streams
-                        > number_of_spills_to_read_for_current_phase
-                    {
-                        // Free the memory we reserved for this merge as we 
either try again or fail
-                        reservation.free();
-                        if buffer_len > 1 {
-                            // Try again with smaller buffer size, it will be 
slower but at least we can merge
-                            return self.get_sorted_spill_files_to_merge(
-                                buffer_len - 1,
-                                minimum_number_of_required_streams,
-                                reservation,
-                            );
+            );
+            total_needed += per_spill;
+
+            // Only request additional memory from the pool when total needed
+            // exceeds what's already reserved (which may include pre-reserved
+            // bytes from sort_spill_reservation_bytes).
+            if total_needed > reservation.size() {
+                match reservation.try_grow(total_needed - reservation.size()) {
+                    Ok(_) => {
+                        number_of_spills_to_read_for_current_phase += 1;
+                    }
+                    // If we can't grow the reservation, we need to stop
+                    Err(err) => {
+                        // We must have at least 2 streams to merge, so if we 
don't have enough memory
+                        // fail
+                        if minimum_number_of_required_streams
+                            > number_of_spills_to_read_for_current_phase
+                        {
+                            // Free the memory we reserved for this merge as 
we either try again or fail
+                            reservation.free();
+                            if buffer_len > 1 {
+                                // Try again with smaller buffer size, it will 
be slower but at least we can merge
+                                return self.get_sorted_spill_files_to_merge(
+                                    buffer_len - 1,
+                                    minimum_number_of_required_streams,
+                                    reservation,
+                                );
+                            }
+
+                            return Err(err);
                         }
 
-                        return Err(err);
+                        // We reached the maximum amount of memory we can use
+                        // for this merge
+                        break;
                     }
-
-                    // We reached the maximum amount of memory we can use
-                    // for this merge
-                    break;
                 }
+            } else {
+                // Pre-reserved bytes cover this spill file's buffer
+                number_of_spills_to_read_for_current_phase += 1;
             }
         }
 
diff --git a/datafusion/physical-plan/src/sorts/sort.rs 
b/datafusion/physical-plan/src/sorts/sort.rs
index 19239f60cd..3e200609cf 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -354,6 +354,13 @@ impl ExternalSorter {
                 self.sort_and_spill_in_mem_batches().await?;
             }
 
+            // Transfer the pre-reserved merge memory to the streaming merge
+            // using `take()` instead of `new_empty()`. This ensures the merge
+            // stream starts with `sort_spill_reservation_bytes` already
+            // allocated, preventing starvation when concurrent sort partitions
+            // compete for pool memory. `take()` moves the bytes atomically
+            // without releasing them back to the pool, so other partitions
+            // cannot race to consume the freed memory.
             StreamingMergeBuilder::new()
                 .with_sorted_spill_files(std::mem::take(&mut 
self.finished_spill_files))
                 .with_spill_manager(self.spill_manager.clone())
@@ -362,7 +369,7 @@ impl ExternalSorter {
                 .with_metrics(self.metrics.baseline.clone())
                 .with_batch_size(self.batch_size)
                 .with_fetch(None)
-                .with_reservation(self.merge_reservation.new_empty())
+                .with_reservation(self.merge_reservation.take())
                 .build()
         } else {
             self.in_mem_sort_stream(self.metrics.baseline.clone())
@@ -2407,4 +2414,109 @@ mod tests {
 
         Ok((sorted_batches, metrics))
     }
+
+    /// Test that concurrent sort partitions sharing a tight memory pool
+    /// don't starve during the merge phase.
+    ///
+    /// This reproduces the starvation scenario where:
+    /// 1. Multiple ExternalSorter instances share a single GreedyMemoryPool
+    /// 2. Each reserves `sort_spill_reservation_bytes` for its merge phase
+    /// 3. After spilling, the merge must proceed using the pre-reserved bytes
+    ///    without additional pool allocation
+    ///
+    /// Without the fix (using `take()` + smart tracking), the merge's
+    /// `new_empty()` reservation starts at 0 bytes and the pre-reserved bytes
+    /// sit unused in ExternalSorter's merge_reservation. When other partitions
+    /// consume the freed memory, the merge starves.
+    ///
+    /// With the fix, the pre-reserved bytes are atomically transferred to the
+    /// merge stream and used for spill file buffer reservations, preventing
+    /// starvation.
+    #[tokio::test]
+    async fn test_sort_merge_no_starvation_with_concurrent_partitions() -> 
Result<()> {
+        use futures::TryStreamExt;
+
+        let sort_spill_reservation_bytes: usize = 10 * 1024; // 10 KB per 
partition
+        let num_partitions: usize = 4;
+
+        // Pool: each partition needs sort_spill_reservation_bytes for its 
merge,
+        // plus a small amount for data accumulation before spilling.
+        // Total: 4 * 10KB + 8KB = 48KB -- very tight.
+        let memory_limit =
+            sort_spill_reservation_bytes * num_partitions + 8 * 1024;
+
+        let session_config = SessionConfig::new()
+            .with_batch_size(128)
+            .with_sort_spill_reservation_bytes(sort_spill_reservation_bytes);
+
+        let runtime = RuntimeEnvBuilder::new()
+            .with_memory_limit(memory_limit, 1.0)
+            .build_arc()?;
+
+        let task_ctx = Arc::new(
+            TaskContext::default()
+                .with_session_config(session_config)
+                .with_runtime(runtime),
+        );
+
+        // Create multiple batches per partition to force spilling.
+        // Each batch: 100 rows of Int32 ≈ 400 bytes.
+        // 20 batches per partition ≈ 8KB per partition.
+        // With only ~2KB of pool headroom per partition, this forces spilling.
+        let batches_per_partition = 20;
+        let rows_per_batch: i32 = 100;
+
+        let all_partitions: Vec<Vec<RecordBatch>> = (0..num_partitions)
+            .map(|_| {
+                (0..batches_per_partition)
+                    .map(|_| make_partition(rows_per_batch))
+                    .collect()
+            })
+            .collect();
+
+        let schema = all_partitions[0][0].schema();
+        let input = TestMemoryExec::try_new_exec(&all_partitions, 
schema.clone(), None)?;
+
+        let sort_exec = Arc::new(
+            SortExec::new(
+                [PhysicalSortExpr {
+                    expr: col("i", &schema)?,
+                    options: SortOptions::default(),
+                }]
+                .into(),
+                input,
+            )
+            .with_preserve_partitioning(true),
+        );
+
+        // Execute all partitions concurrently -- they share the same pool.
+        let mut tasks = Vec::new();
+        for partition in 0..num_partitions {
+            let sort = Arc::clone(&sort_exec);
+            let ctx = Arc::clone(&task_ctx);
+            tasks.push(tokio::spawn(async move {
+                let stream = sort.execute(partition, ctx)?;
+                let batches: Vec<RecordBatch> = stream.try_collect().await?;
+                let total_rows: usize = batches.iter().map(|b| 
b.num_rows()).sum();
+                Ok::<usize, DataFusionError>(total_rows)
+            }));
+        }
+
+        let mut total_rows = 0;
+        for task in tasks {
+            total_rows += task.await.unwrap()?;
+        }
+
+        let expected_rows =
+            num_partitions * batches_per_partition * (rows_per_batch as usize);
+        assert_eq!(total_rows, expected_rows);
+
+        assert_eq!(
+            task_ctx.runtime_env().memory_pool.reserved(),
+            0,
+            "All memory should be returned to the pool after sort completes"
+        );
+
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to