alamb commented on code in PR #11610:
URL: https://github.com/apache/datafusion/pull/11610#discussion_r1695138487
##########
datafusion/physical-plan/src/coalesce_batches.rs:
##########
@@ -246,243 +267,386 @@ impl CoalesceBatchesStream {
             let input_batch = self.input.poll_next_unpin(cx);
             // records time on drop
             let _timer = cloned_time.timer();
-            match input_batch {
-                Poll::Ready(x) => match x {
-                    Some(Ok(batch)) => {
-                        // Handle fetch limit:
-                        if let Some(fetch) = self.fetch {
-                            if self.total_rows + batch.num_rows() >= fetch {
-                                // We have reached the fetch limit.
-                                let remaining_rows = fetch - self.total_rows;
-                                debug_assert!(remaining_rows > 0);
-
+            match ready!(input_batch) {
+                Some(result) => {
+                    let Ok(input_batch) = result else {
+                        return Poll::Ready(Some(result)); // pass back error
+                    };
+                    // Buffer the batch and either get more input if not enough
+                    // rows yet or output
+                    match self.coalescer.push_batch(input_batch) {
+                        Ok(None) => continue,
+                        res => {
+                            if self.coalescer.limit_reached() {
                                 self.is_closed = true;
-                                self.total_rows = fetch;
-                                // Trim the batch and add to buffered batches:
-                                let batch = batch.slice(0, remaining_rows);
-                                self.buffered_rows += batch.num_rows();
-                                self.buffer.push(batch);
-                                // Combine buffered batches:
-                                let batch = concat_batches(&self.schema, &self.buffer)?;
-                                // Reset the buffer state and return final batch:
-                                self.buffer.clear();
-                                self.buffered_rows = 0;
-                                return Poll::Ready(Some(Ok(batch)));
-                            }
-                        }
-                        self.total_rows += batch.num_rows();
-
-                        if batch.num_rows() >= self.target_batch_size
-                            && self.buffer.is_empty()
-                        {
-                            return Poll::Ready(Some(Ok(batch)));
-                        } else if batch.num_rows() == 0 {
-                            // discard empty batches
-                        } else {
-                            // add to the buffered batches
-                            self.buffered_rows += batch.num_rows();
-                            self.buffer.push(batch);
-                            // check to see if we have enough batches yet
-                            if self.buffered_rows >= self.target_batch_size {
-                                // combine the batches and return
-                                let batch = concat_batches(&self.schema, &self.buffer)?;
-                                // reset buffer state
-                                self.buffer.clear();
-                                self.buffered_rows = 0;
-                                // return batch
-                                return Poll::Ready(Some(Ok(batch)));
                             }
+                            return Poll::Ready(res.transpose());
                         }
                     }
-                    None => {
-                        self.is_closed = true;
-                        // we have reached the end of the input stream but there could still
-                        // be buffered batches
-                        if self.buffer.is_empty() {
-                            return Poll::Ready(None);
-                        } else {
-                            // combine the batches and return
-                            let batch = concat_batches(&self.schema, &self.buffer)?;
-                            // reset buffer state
-                            self.buffer.clear();
-                            self.buffered_rows = 0;
-                            // return batch
-                            return Poll::Ready(Some(Ok(batch)));
-                        }
-                    }
-                    other => return Poll::Ready(other),
-                },
-                Poll::Pending => return Poll::Pending,
+                }
+                None => {
+                    self.is_closed = true;
+                    // we have reached the end of the input stream but there could still
+                    // be buffered batches
+                    return match self.coalescer.finish() {
+                        Ok(None) => Poll::Ready(None),
+                        res => Poll::Ready(res.transpose()),
+                    };
+                }
             }
         }
     }
 }
 
 impl RecordBatchStream for CoalesceBatchesStream {
     fn schema(&self) -> SchemaRef {
-        Arc::clone(&self.schema)
+        self.coalescer.schema()
     }
 }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning};
+/// Concatenate multiple record batches into larger batches
+///
+/// See [`CoalesceBatchesExec`] for more details.
+///
+/// Notes:
+///
+/// 1. The output rows are in the same order as the input rows
+///
+/// 2. The output is a sequence of batches, with all but the last being at least
+///    `target_batch_size` rows.
+///
+/// 3. Eventually this may also be able to handle other optimizations such as a
+///    combined filter/coalesce operation.
+#[derive(Debug)]
+struct BatchCoalescer {
+    /// The input schema
+    schema: SchemaRef,
+    /// Minimum number of rows for coalesced batches
+    target_batch_size: usize,
+    /// Total number of rows returned so far
+    total_rows: usize,
+    /// Buffered batches
+    buffer: Vec<RecordBatch>,
+    /// Buffered row count
+    buffered_rows: usize,
+    /// Maximum number of rows to fetch, `None` means fetching all rows
+    fetch: Option<usize>,
+}
 
-    use arrow::datatypes::{DataType, Field, Schema};
-    use arrow_array::UInt32Array;
+impl BatchCoalescer {
+    /// Create a new `BatchCoalescer`
+    ///
+    /// # Arguments
+    /// - `schema` - the schema of the output batches
+    /// - `target_batch_size` - the minimum number of rows for each
+    ///   output batch (until limit reached)
+    /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows
+    fn new(schema: SchemaRef, target_batch_size: usize, fetch: Option<usize>) -> Self {
+        Self {
+            schema,
+            target_batch_size,
+            total_rows: 0,
+            buffer: vec![],
+            buffered_rows: 0,
+            fetch,
+        }
+    }
 
-    #[tokio::test(flavor = "multi_thread")]
-    async fn test_concat_batches() -> Result<()> {
-        let schema = test_schema();
-        let partition = create_vec_batches(&schema, 10);
-        let partitions = vec![partition];
+    /// Return the schema of the output batches
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    /// Add a batch, returning a batch if the target batch size or limit is reached
+    fn push_batch(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>> {
+        // discard empty batches
+        if batch.num_rows() == 0 {
+            return Ok(None);
+        }
 
-        let output_partitions = coalesce_batches(&schema, partitions, 21, None).await?;
-        assert_eq!(1, output_partitions.len());
+        // past limit
+        if self.limit_reached() {
+            return Ok(None);
+        }
 
-        // input is 10 batches x 8 rows (80 rows)
-        // expected output is batches of at least 20 rows (except for the final batch)
-        let batches = &output_partitions[0];
-        assert_eq!(4, batches.len());
-        assert_eq!(24, batches[0].num_rows());
-        assert_eq!(24, batches[1].num_rows());
-        assert_eq!(24, batches[2].num_rows());
-        assert_eq!(8, batches[3].num_rows());
+        // Handle fetch limit:
+        if let Some(fetch) = self.fetch {
+            if self.total_rows + batch.num_rows() >= fetch {
+                // We have reached the fetch limit.
+                let remaining_rows = fetch - self.total_rows;
+                debug_assert!(remaining_rows > 0);
+                self.total_rows = fetch;
+                // Trim the batch and add to buffered batches:
+                let batch = batch.slice(0, remaining_rows);
+                self.buffered_rows += batch.num_rows();
+                self.buffer.push(batch);
+                // Combine buffered batches:
+                let batch = concat_batches(&self.schema, &self.buffer)?;
+                // Reset the buffer state and return final batch:
+                self.buffer.clear();
+                self.buffered_rows = 0;
+                return Ok(Some(batch));
+            }
+        }
+        self.total_rows += batch.num_rows();
 
-        Ok(())
+        // batch itself is already big enough and we have no buffered rows so
+        // return it directly
+        if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() {
+            return Ok(Some(batch));
+        }
+        // add to the buffered batches
+        self.buffered_rows += batch.num_rows();
+        self.buffer.push(batch);
+        // check to see if we have enough batches yet
+        let batch = if self.buffered_rows >= self.target_batch_size {
+            // combine the batches and return
+            let batch = concat_batches(&self.schema, &self.buffer)?;
+            // reset buffer state
+            self.buffer.clear();
+            self.buffered_rows = 0;
+            // return batch
+            Some(batch)
+        } else {
+            None
+        };
+        Ok(batch)
     }
 
-    #[tokio::test]
-    async fn test_concat_batches_with_fetch_larger_than_input_size() -> Result<()> {
-        let schema = test_schema();
-        let partition = create_vec_batches(&schema, 10);
-        let partitions = vec![partition];
+    /// Finish the coalescing process, returning all buffered data as a final,
+    /// single batch, if any
+    fn finish(&mut self) -> Result<Option<RecordBatch>> {
+        if self.buffer.is_empty() {
+            Ok(None)
+        } else {
+            // combine the batches and return
+            let batch = concat_batches(&self.schema, &self.buffer)?;
+            // reset buffer state
+            self.buffer.clear();
+            self.buffered_rows = 0;
+            // return batch
+            Ok(Some(batch))
+        }
+    }
 
-        let output_partitions =
-            coalesce_batches(&schema, partitions, 21, Some(100)).await?;
-        assert_eq!(1, output_partitions.len());
+    /// returns true if there is a limit and it has been reached
+    pub fn limit_reached(&self) -> bool {
+        if let Some(fetch) = self.fetch {
+            self.total_rows >= fetch
+        } else {
+            false
+        }
+    }
+}
 
-        // input is 10 batches x 8 rows (80 rows) with fetch limit of 100
-        // expected to behave the same as `test_concat_batches`
-        let batches = &output_partitions[0];
-        assert_eq!(4, batches.len());
-        assert_eq!(24, batches[0].num_rows());
-        assert_eq!(24, batches[1].num_rows());
-        assert_eq!(24, batches[2].num_rows());
-        assert_eq!(8, batches[3].num_rows());
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow_array::UInt32Array;
+    use std::ops::Range;
+
+    #[test]
+    fn test_coalesce() {
+        let batch = uint32_batch(0..8);
+        Test::new()

Review Comment:
   I updated the tests to actually check the contents of the coalesced batches, in addition to checking the sizes. I also factored out much of the test repetition and made it easier to understand, I think.
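   The `Test` builder itself is cut off by the comment anchor above. For readers following along, here is a minimal sketch of what such a builder-style harness could look like; everything except `Test::new()` and the `uint32_batch` name is an assumption for illustration, not the PR's actual code:

```rust
use std::ops::Range;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::{RecordBatch, UInt32Array};

/// Hypothetical builder-style test harness: collects input batches and
/// settings, then (in a `run` method, omitted here) pushes each batch
/// through a `BatchCoalescer` and checks both the sizes and the
/// concatenated contents of the output batches.
#[derive(Debug, Default)]
struct Test {
    input_batches: Vec<RecordBatch>,
    target_batch_size: usize,
    fetch: Option<usize>,
    expected_output_sizes: Vec<usize>,
}

impl Test {
    fn new() -> Self {
        Self::default()
    }

    /// Add a batch to feed into the coalescer
    fn with_batch(mut self, batch: RecordBatch) -> Self {
        self.input_batches.push(batch);
        self
    }

    /// Set the minimum output batch size
    fn with_target_batch_size(mut self, target_batch_size: usize) -> Self {
        self.target_batch_size = target_batch_size;
        self
    }

    /// Set the fetch (row limit)
    fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

    /// Declare the expected row counts of the output batches
    fn with_expected_output_sizes(mut self, sizes: Vec<usize>) -> Self {
        self.expected_output_sizes = sizes;
        self
    }
}

/// Single-column UInt32 batch containing the values in `range`
fn uint32_batch(range: Range<u32>) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
    RecordBatch::try_new(
        schema,
        vec![Arc::new(UInt32Array::from_iter_values(range))],
    )
    .unwrap()
}
```

   A `run()` method would then construct a `BatchCoalescer`, push each input batch, call `finish()`, and assert that the output row counts match `expected_output_sizes` and that the concatenated output equals the concatenated input.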
##########
datafusion/physical-plan/src/coalesce_batches.rs:
##########
@@ -246,243 +267,386 @@ impl CoalesceBatchesStream {
+    /// Add a batch, returning a batch if the target batch size or limit is reached
+    fn push_batch(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>> {
+        // discard empty batches
+        if batch.num_rows() == 0 {
+            return Ok(None);
+        }
+
+        // past limit
+        if self.limit_reached() {
+            return Ok(None);
+        }
+
+        // Handle fetch limit:

Review Comment:
   I just moved this code out of the main `CoalesceBatchesExec` poll loop and into a function. Other than the `is_closed()` handling this is literally a copy/paste
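   As a worked example of the copied limit handling: with `fetch = Some(100)`, `total_rows = 96`, and an incoming 8-row batch, `remaining_rows` is 4, so the batch is truncated to its first 4 rows and the stream closes. A standalone sketch of just that arithmetic, with hypothetical helper names that are not DataFusion API:

```rust
use arrow_array::RecordBatch;

/// How many rows of `batch` may still be emitted if `total_rows` rows
/// have been produced so far and at most `fetch` rows may be produced
/// in total (`None` means unlimited).
fn remaining_rows(total_rows: usize, batch: &RecordBatch, fetch: Option<usize>) -> usize {
    match fetch {
        Some(fetch) => batch.num_rows().min(fetch.saturating_sub(total_rows)),
        None => batch.num_rows(),
    }
}

/// Trim `batch` so the fetch limit is not exceeded. `RecordBatch::slice`
/// is zero-copy: it adjusts array offsets rather than copying data.
fn trim_to_fetch(batch: RecordBatch, total_rows: usize, fetch: Option<usize>) -> RecordBatch {
    let keep = remaining_rows(total_rows, &batch, fetch);
    batch.slice(0, keep)
}
```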
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org