alamb commented on code in PR #20494:
URL: https://github.com/apache/datafusion/pull/20494#discussion_r2977914939


##########
datafusion/physical-plan/src/topk/mod.rs:
##########
@@ -749,47 +736,84 @@ impl TopKHeap {
     }
 
     /// Returns the values stored in this heap, from values low to
-    /// high, as a single [`RecordBatch`], resetting the inner heap
-    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
+    /// high, as [`RecordBatch`]es, resetting the inner heap
+    pub fn emit(&mut self) -> Result<Vec<RecordBatch>> {
         Ok(self.emit_with_state()?.0)
     }
 
     /// Returns the values stored in this heap, from values low to
-    /// high, as a single [`RecordBatch`], and a sorted vec of the
+    /// high, as [`RecordBatch`]es, and a sorted vec of the
     /// current heap's contents
-    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, 
Vec<TopKRow>)> {
+    pub fn emit_with_state(&mut self) -> Result<(Vec<RecordBatch>, 
Vec<TopKRow>)> {
         // generate sorted rows
         let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
 
         if self.store.is_empty() {
-            return Ok((None, topk_rows));
+            return Ok((Vec::new(), topk_rows));
+        }
+
+        let batches = self.interleave_topk_rows(&topk_rows, self.batch_size)?;
+
+        Ok((batches, topk_rows))
+    }
+
+    fn interleave_topk_rows(

Review Comment:
   Can you please add some comments that explain what this method does
and what the rationale behind it is?
   
   You can probably repurpose the (very nice) description in this PR



##########
datafusion/physical-plan/src/topk/mod.rs:
##########
@@ -884,6 +916,300 @@ impl TopKHeap {
         Ok(Some(scalar_values))
     }
 }
+const I32_OFFSET_LIMIT: i64 = i32::MAX as i64;
+
+fn split_indices_by_i32_offsets(
+    record_batches: &[&RecordBatch],
+    all_indices: &[(usize, usize)],
+    max_rows_per_batch: usize,
+    max_offset: i64,
+) -> Result<Vec<Range<usize>>> {
+    if all_indices.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let var_width_columns =
+        collect_var_width_columns(record_batches.first().ok_or_else(|| {
+            internal_datafusion_err!("Missing record batches for TopK 
interleave")
+        })?);
+
+    if var_width_columns.is_empty() {
+        return Ok(split_indices_by_row_count(
+            all_indices.len(),
+            max_rows_per_batch,
+        ));
+    }
+
+    // Fast path: if the combined data across *all* batches is well under the
+    // limit, no single interleave chunk can overflow regardless of which rows
+    // are selected, so we can skip the per-row accounting loop entirely.
+    let total_across_batches: i64 = record_batches
+        .iter()
+        .flat_map(|batch| {
+            var_width_columns.iter().map(|col| {
+                col.get_array(batch)
+                    .map(|a| col.total_data_size(a))
+                    .unwrap_or(0)
+            })
+        })
+        .fold(0_i64, |acc, v| acc.saturating_add(v));
+
+    if total_across_batches <= max_offset / 2 {
+        return Ok(split_indices_by_row_count(
+            all_indices.len(),
+            max_rows_per_batch,
+        ));
+    }
+
+    let mut ranges = Vec::new();
+    let mut start = 0;
+    let mut totals = vec![0_i64; var_width_columns.len()];
+
+    for (pos, (batch_pos, row_index)) in all_indices.iter().enumerate() {
+        if pos - start >= max_rows_per_batch {
+            ranges.push(start..pos);
+            start = pos;
+            totals.fill(0);
+        }
+
+        let batch = record_batches.get(*batch_pos).ok_or_else(|| {
+            internal_datafusion_err!("Invalid batch position in TopK indices")
+        })?;
+
+        let mut row_sizes = Vec::with_capacity(var_width_columns.len());
+        for column in &var_width_columns {
+            let array = column.get_array(batch)?;
+            let size = column.row_size(array, *row_index)?;
+            if size > max_offset {
+                return internal_err!(
+                    "TopK row requires {size} offsets which exceeds i32::MAX"
+                );
+            }
+            row_sizes.push(size);
+        }
+
+        if totals
+            .iter()
+            .zip(row_sizes.iter())
+            .any(|(total, size)| total + size > max_offset)
+        {
+            ranges.push(start..pos);
+            start = pos;
+            totals.fill(0);
+        }
+
+        for (total, size) in totals.iter_mut().zip(row_sizes.iter()) {
+            *total += *size;
+        }
+    }
+
+    if start < all_indices.len() {
+        ranges.push(start..all_indices.len());
+    }
+
+    Ok(ranges)
+}
+
+fn split_indices_by_row_count(
+    total_rows: usize,
+    max_rows_per_batch: usize,
+) -> Vec<Range<usize>> {
+    let mut ranges = Vec::new();
+    let mut start = 0;
+    let max_rows_per_batch = max_rows_per_batch.max(1);
+    while start < total_rows {
+        let end = (start + max_rows_per_batch).min(total_rows);
+        ranges.push(start..end);
+        start = end;
+    }
+    ranges
+}
+
+/// Recursively collect all variable-width leaf columns from `batch`, walking
+/// into `Struct` fields. Each returned `VarWidthColumn` stores the full index
+/// path needed to reach its array (top-level index, then zero or more struct
+/// child indices).
+fn collect_var_width_columns(batch: &RecordBatch) -> Vec<VarWidthColumn> {

Review Comment:
   Is this code covered by tests? I didn't see any tests that have Structs, and
yet this code seems to try to walk into Structs 🤔



##########
datafusion/physical-plan/src/topk/mod.rs:
##########
@@ -749,47 +736,84 @@ impl TopKHeap {
     }
 
     /// Returns the values stored in this heap, from values low to
-    /// high, as a single [`RecordBatch`], resetting the inner heap
-    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
+    /// high, as [`RecordBatch`]es, resetting the inner heap
+    pub fn emit(&mut self) -> Result<Vec<RecordBatch>> {
         Ok(self.emit_with_state()?.0)
     }
 
     /// Returns the values stored in this heap, from values low to
-    /// high, as a single [`RecordBatch`], and a sorted vec of the
+    /// high, as [`RecordBatch`]es, and a sorted vec of the
     /// current heap's contents
-    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, 
Vec<TopKRow>)> {
+    pub fn emit_with_state(&mut self) -> Result<(Vec<RecordBatch>, 
Vec<TopKRow>)> {
         // generate sorted rows
         let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
 
         if self.store.is_empty() {
-            return Ok((None, topk_rows));
+            return Ok((Vec::new(), topk_rows));
+        }
+
+        let batches = self.interleave_topk_rows(&topk_rows, self.batch_size)?;
+
+        Ok((batches, topk_rows))
+    }
+
+    fn interleave_topk_rows(
+        &self,
+        topk_rows: &[TopKRow],
+        max_rows_per_batch: usize,
+    ) -> Result<Vec<RecordBatch>> {
+        let (record_batches, batch_id_array_pos) = 
self.collect_record_batches();
+        let all_indices = self.collect_indices(topk_rows, 
&batch_id_array_pos)?;
+        let index_ranges = split_indices_by_i32_offsets(
+            &record_batches,
+            &all_indices,
+            max_rows_per_batch.max(1),
+            I32_OFFSET_LIMIT,
+        )?;
+
+        let mut batches = Vec::with_capacity(index_ranges.len());
+        for range in index_ranges {
+            let indices = &all_indices[range];
+            let batch = interleave_record_batch(&record_batches, indices)?;
+            batches.push(batch);
         }
 
-        // Collect the batches into a vec and store the "batch_id -> 
array_pos" mapping, to then
-        // build the `indices` vec below. This is needed since the batch ids 
are not continuous.
+        Ok(batches)
+    }
+
+    fn collect_record_batches(&self) -> (Vec<&RecordBatch>, HashMap<u32, 
usize>) {
+        // Collect the batches into a vec and store the "batch_id -> 
array_pos" mapping.
+        // This is needed since the batch ids are not continuous.
         let mut record_batches = Vec::new();
         let mut batch_id_array_pos = HashMap::new();
         for (array_pos, (batch_id, batch)) in 
self.store.batches.iter().enumerate() {
             record_batches.push(&batch.batch);
             batch_id_array_pos.insert(*batch_id, array_pos);
         }
+        (record_batches, batch_id_array_pos)
+    }
 
-        let indices: Vec<_> = topk_rows
+    fn collect_indices(

Review Comment:
   Please add comments about how to interpret the return type (e.g. what do
the two `usize` values in each `(usize, usize)` tuple represent?)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to