kosiew commented on code in PR #21962:
URL: https://github.com/apache/datafusion/pull/21962#discussion_r3198707608


##########
datafusion/physical-plan/src/joins/sort_merge_join/tests.rs:
##########
@@ -2487,6 +2487,223 @@ async fn overallocation_multi_batch_spill() -> 
Result<()> {
     Ok(())
 }
 
+/// Verifies that `peak_mem_used` reflects join_arrays memory on the spill 
path.
+///
+/// Uses a memory limit smaller than a single batch's `size_estimation` so that
+/// every batch spills — the `Ok` arm of `allocate_reservation` is never hit.
+/// Before the fix, `peak_mem_used` would stay 0 because `set_max` was only
+/// called in the `Ok` arm. After the fix, the spill path calls
+/// `grow(join_arrays_mem)` + `set_max`, so `peak_mem_used > 0`.
+#[tokio::test]
+async fn spill_join_arrays_memory_accounting() -> Result<()> {
+    use arrow::array::Array;
+
+    let left_batch = build_table_i32(
+        ("a1", &vec![0, 1]),
+        ("b1", &vec![1, 1]),
+        ("c1", &vec![4, 5]),
+    );
+    let size_estimation = left_batch.get_array_memory_size()
+        + Int32Array::from(vec![1, 1]).get_array_memory_size()
+        + 2usize.next_power_of_two() * size_of::<usize>()
+        + size_of::<std::ops::Range<usize>>()
+        + size_of::<usize>();
+    let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size();
+
+    // Memory limit: too small for a full batch, large enough for join_arrays.
+    // Every batch hits the Err arm → spills → grow(join_arrays_mem).
+    let memory_limit = (size_estimation + join_arrays_mem) / 2;
+    assert!(
+        memory_limit < size_estimation && memory_limit > join_arrays_mem,
+        "limit {memory_limit} must be between join_arrays_mem 
{join_arrays_mem} \
+         and size_estimation {size_estimation}"
+    );
+
+    let left_batches: Vec<RecordBatch> = (0..4)
+        .map(|i| {
+            build_table_i32(
+                ("a1", &vec![i * 2, i * 2 + 1]),
+                ("b1", &vec![1, 1]),
+                ("c1", &vec![100 + i, 101 + i]),
+            )
+        })
+        .collect();
+    let left = build_table_from_batches(left_batches);
+
+    let right_batches: Vec<RecordBatch> = (0..2)
+        .map(|i| {
+            build_table_i32(
+                ("a2", &vec![i * 2, i * 2 + 1]),
+                ("b2", &vec![1, 1]),
+                ("c2", &vec![200 + i, 201 + i]),
+            )
+        })
+        .collect();
+    let right = build_table_from_batches(right_batches);
+
+    let on = vec![(
+        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+        Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
+    )];
+    let sort_options = vec![SortOptions::default(); on.len()];
+
+    let runtime = RuntimeEnvBuilder::new()
+        .with_memory_limit(memory_limit, 1.0)
+        .with_disk_manager_builder(
+            
DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
+        )
+        .build_arc()?;
+
+    let session_config = SessionConfig::default().with_batch_size(50);
+    let task_ctx = Arc::new(
+        TaskContext::default()
+            .with_session_config(session_config)
+            .with_runtime(Arc::clone(&runtime)),
+    );
+
+    let join = join_with_options(
+        Arc::clone(&left),
+        Arc::clone(&right),
+        on.clone(),
+        Inner,
+        sort_options,
+        NullEquality::NullEqualsNothing,
+    )?;
+
+    let stream = join.execute(0, task_ctx)?;
+    let result = common::collect(stream).await.unwrap();
+
+    assert!(!result.is_empty(), "Expected non-empty join result");
+
+    let metrics = join.metrics().unwrap();
+    assert!(
+        metrics.spill_count().unwrap() > 0,
+        "Expected spilling to occur"
+    );
+
+    // Before the fix, peak_mem_used was 0 here because set_max was only
+    // called in the Ok arm of allocate_reservation, which is never reached
+    // when every batch spills. After the fix, the spill path calls
+    // grow(join_arrays_mem) + set_max unconditionally.
+    let peak_mem = metrics
+        .sum_by_name("peak_mem_used")
+        .map(|m| m.as_usize())
+        .unwrap_or(0);

Review Comment:
   The new tests look good, and `peak_mem > 0` definitely catches the old missing-accounting path.
   
   As a possible follow-up, it might be worth tightening this to something like `peak_mem >= join_arrays_mem`, since `join_arrays_mem` is already computed earlier in this test. That would also catch partial-accounting regressions. Non-blocking, though.



##########
datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs:
##########
@@ -235,6 +235,14 @@ pub(super) struct BufferedBatch {
     pub null_joined: Vec<usize>,
     /// Size estimation used for reserving / releasing memory
     pub size_estimation: usize,
+    /// Actual amount tracked in the memory reservation for this batch.
+    ///
+    /// - `InMemory`: equals `size_estimation` (full batch + join_arrays + 
metadata)
+    /// - `Spilled`: equals join_arrays memory if `try_grow` succeeded after 
spill, else 0

Review Comment:
   Nice catch fixing the accounting path here 👍
   
   One small thing: the `reserved_amount` field doc still says spilled batches track join-array memory only if `try_grow` succeeds, and `0` otherwise. Since the implementation now uses an unconditional `grow(join_arrays_mem)`, I think the doc comment should be updated to match the new behavior.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to