zhuqi-lucas commented on code in PR #17193: URL: https://github.com/apache/datafusion/pull/17193#discussion_r2280265535
########## datafusion/physical-plan/src/coalesce/mod.rs: ########## @@ -15,290 +15,158 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ - builder::StringViewBuilder, cast::AsArray, Array, ArrayRef, RecordBatch, - RecordBatchOptions, -}; -use arrow::compute::concat_batches; +use arrow::array::RecordBatch; +use arrow::compute::BatchCoalescer; use arrow::datatypes::SchemaRef; -use std::sync::Arc; +use datafusion_common::{internal_err, Result}; -/// Concatenate multiple [`RecordBatch`]es -/// -/// `BatchCoalescer` concatenates multiple small [`RecordBatch`]es, produced by -/// operations such as `FilterExec` and `RepartitionExec`, into larger ones for -/// more efficient processing by subsequent operations. -/// -/// # Background -/// -/// Generally speaking, larger [`RecordBatch`]es are more efficient to process -/// than smaller record batches (until the CPU cache is exceeded) because there -/// is fixed processing overhead per batch. DataFusion tries to operate on -/// batches of `target_batch_size` rows to amortize this overhead -/// -/// ```text -/// ┌────────────────────┐ -/// │ RecordBatch │ -/// │ num_rows = 23 │ -/// └────────────────────┘ ┌────────────────────┐ -/// │ │ -/// ┌────────────────────┐ Coalesce │ │ -/// │ │ Batches │ │ -/// │ RecordBatch │ │ │ -/// │ num_rows = 50 │ ─ ─ ─ ─ ─ ─ ▶ │ │ -/// │ │ │ RecordBatch │ -/// │ │ │ num_rows = 106 │ -/// └────────────────────┘ │ │ -/// │ │ -/// ┌────────────────────┐ │ │ -/// │ │ │ │ -/// │ RecordBatch │ │ │ -/// │ num_rows = 33 │ └────────────────────┘ -/// │ │ -/// └────────────────────┘ -/// ``` -/// -/// # Notes: -/// -/// 1. Output rows are produced in the same order as the input rows -/// -/// 2. The output is a sequence of batches, with all but the last being at least -/// `target_batch_size` rows. -/// -/// 3. Eventually this may also be able to handle other optimizations such as a -/// combined filter/coalesce operation. 
+/// Concatenate multiple [`RecordBatch`]es and apply a limit /// +/// See [`BatchCoalescer`] for more details on how this works. #[derive(Debug)] -pub struct BatchCoalescer { - /// The input schema - schema: SchemaRef, - /// Minimum number of rows for coalesces batches - target_batch_size: usize, +pub struct LimitedBatchCoalescer { + /// The arrow structure that builds the output batches + inner: BatchCoalescer, /// Total number of rows returned so far total_rows: usize, - /// Buffered batches - buffer: Vec<RecordBatch>, - /// Buffered row count - buffered_rows: usize, /// Limit: maximum number of rows to fetch, `None` means fetch all rows fetch: Option<usize>, + /// Indicates if the coalescer is finished + finished: bool, + /// The biggest size of the coalesced batch + biggest_coalesce_size: usize, +} + +/// Status returned by [`LimitedBatchCoalescer::push_batch`] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PushBatchStatus { + /// The limit has **not** been reached, and more batches can be pushed + Continue, + /// The limit **has** been reached after processing this batch + /// The caller should call [`LimitedBatchCoalescer::finish`] + /// to flush any buffered rows and stop pushing more batches. 
+ LimitReached, } -impl BatchCoalescer { +impl LimitedBatchCoalescer { /// Create a new `BatchCoalescer` /// /// # Arguments /// - `schema` - the schema of the output batches /// - `target_batch_size` - the minimum number of rows for each /// output batch (until limit reached) /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows + /// - `biggest_coalesce_size` - the max size of the batch to coalesce, now it's fixed to `target_batch_size / 2` pub fn new( schema: SchemaRef, target_batch_size: usize, fetch: Option<usize>, ) -> Self { Self { - schema, - target_batch_size, + inner: BatchCoalescer::new(schema, target_batch_size), total_rows: 0, - buffer: vec![], - buffered_rows: 0, fetch, + finished: false, + biggest_coalesce_size: target_batch_size / 2, } } /// Return the schema of the output batches pub fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) + self.inner.schema() } - /// Push next batch, and returns [`CoalescerState`] indicating the current - /// state of the buffer. - pub fn push_batch(&mut self, batch: RecordBatch) -> CoalescerState { - let batch = gc_string_view_batch(&batch); - if self.limit_reached(&batch) { - CoalescerState::LimitReached - } else if self.target_reached(batch) { - CoalescerState::TargetReached - } else { - CoalescerState::Continue + /// Pushes the next [`RecordBatch`] into the coalescer and returns its status. + /// + /// # Arguments + /// * `batch` - The [`RecordBatch`] to append. + /// + /// # Returns + /// * [`PushBatchStatus::Continue`] - More batches can still be pushed. + /// * [`PushBatchStatus::LimitReached`] - The row limit was reached after processing + /// this batch. The caller should call [`Self::finish`] before retrieving the + /// remaining buffered batches. + /// + /// # Errors + /// Returns an error if called after [`Self::finish`] or if the internal push + /// operation fails. 
+ pub fn push_batch(&mut self, batch: RecordBatch) -> Result<PushBatchStatus> { + if self.finished { + return internal_err!( + "LimitedBatchCoalescer: cannot push batch after finish" + ); } - } - /// Return true if the there is no data buffered - pub fn is_empty(&self) -> bool { - self.buffer.is_empty() - } + // if we are at the limit, return LimitReached + if let Some(fetch) = self.fetch { + // limit previously reached + if self.total_rows >= fetch { + return Ok(PushBatchStatus::LimitReached); + } - /// Checks if the buffer will reach the specified limit after getting - /// `batch`. - /// - /// If fetch would be exceeded, slices the received batch, updates the - /// buffer with it, and returns `true`. - /// - /// Otherwise: does nothing and returns `false`. - fn limit_reached(&mut self, batch: &RecordBatch) -> bool { - match self.fetch { - Some(fetch) if self.total_rows + batch.num_rows() >= fetch => { + // limit now reached + if self.total_rows + batch.num_rows() >= fetch { // Limit is reached let remaining_rows = fetch - self.total_rows; debug_assert!(remaining_rows > 0); - let batch = batch.slice(0, remaining_rows); - self.buffered_rows += batch.num_rows(); - self.total_rows = fetch; - self.buffer.push(batch); - true + let batch_head = batch.slice(0, remaining_rows); + self.total_rows += batch_head.num_rows(); + self.inner.push_batch(batch_head)?; + return Ok(PushBatchStatus::LimitReached); } - _ => false, } - } - /// Updates the buffer with the given batch. - /// - /// If the target batch size is reached, returns `true`. Otherwise, returns - /// `false`. - fn target_reached(&mut self, batch: RecordBatch) -> bool { - if batch.num_rows() == 0 { - false + // Limit not reached, push the entire batch + self.total_rows += batch.num_rows(); + + if batch.num_rows() >= self.biggest_coalesce_size { Review Comment: Good suggestion! -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For additional commands, e-mail: github-help@datafusion.apache.org