This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ca4ec3b59 Extract CoalesceBatchesStream to a struct (#11610)
5ca4ec3b59 is described below

commit 5ca4ec3b59044f08a7b5487de2d146e1b9b3bd29
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Aug 2 08:05:48 2024 -0400

    Extract CoalesceBatchesStream to a struct (#11610)
---
 datafusion/physical-plan/src/coalesce_batches.rs | 602 ++++++++++++++---------
 1 file changed, 382 insertions(+), 220 deletions(-)

diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs
index 038727daa7..b822ec2daf 100644
--- a/datafusion/physical-plan/src/coalesce_batches.rs
+++ b/datafusion/physical-plan/src/coalesce_batches.rs
@@ -15,13 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! CoalesceBatchesExec combines small batches into larger batches for more efficient use of
-//! vectorized processing by upstream operators.
+//! [`CoalesceBatchesExec`] combines small batches into larger batches.
 
 use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
+use std::task::{ready, Context, Poll};
 
 use arrow::array::{AsArray, StringViewBuilder};
 use arrow::compute::concat_batches;
@@ -41,11 +40,43 @@ use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics};
 
 /// `CoalesceBatchesExec` combines small batches into larger batches for more
-/// efficient use of vectorized processing by later operators. The operator
-/// works by buffering batches until it collects `target_batch_size` rows. When
-/// only a limited number of rows are necessary (specified by the `fetch`
-/// parameter), the operator will stop buffering and return the final batch
-/// once the number of collected rows reaches the `fetch` value.
+/// efficient use of vectorized processing by later operators.
+///
+/// The operator buffers batches until it collects `target_batch_size` rows and
+/// then emits a single concatenated batch. When only a limited number of rows
+/// are necessary (specified by the `fetch` parameter), the operator will stop
+/// buffering and return the final batch once the number of collected rows
+/// reaches the `fetch` value.
+///
+/// # Background
+///
+/// Generally speaking, larger RecordBatches are more efficient to process than
+/// smaller record batches (until the CPU cache is exceeded) because there is
+/// fixed processing overhead per batch. This code concatenates multiple small
+/// record batches into larger ones to amortize this overhead.
+///
+/// ```text
+/// ┌────────────────────┐
+/// │    RecordBatch     │
+/// │   num_rows = 23    │
+/// └────────────────────┘                 ┌────────────────────┐
+///                                        │                    │
+/// ┌────────────────────┐     Coalesce    │                    │
+/// │                    │      Batches    │                    │
+/// │    RecordBatch     │                 │                    │
+/// │   num_rows = 50    │  ─ ─ ─ ─ ─ ─ ▶  │                    │
+/// │                    │                 │    RecordBatch     │
+/// │                    │                 │   num_rows = 106   │
+/// └────────────────────┘                 │                    │
+///                                        │                    │
+/// ┌────────────────────┐                 │                    │
+/// │                    │                 │                    │
+/// │    RecordBatch     │                 │                    │
+/// │   num_rows = 33    │                 └────────────────────┘
+/// │                    │
+/// └────────────────────┘
+/// ```
+
 #[derive(Debug)]
 pub struct CoalesceBatchesExec {
     /// The input plan
@@ -166,12 +197,11 @@ impl ExecutionPlan for CoalesceBatchesExec {
     ) -> Result<SendableRecordBatchStream> {
         Ok(Box::pin(CoalesceBatchesStream {
             input: self.input.execute(partition, context)?,
-            schema: self.input.schema(),
-            target_batch_size: self.target_batch_size,
-            fetch: self.fetch,
-            buffer: Vec::new(),
-            buffered_rows: 0,
-            total_rows: 0,
+            coalescer: BatchCoalescer::new(
+                self.input.schema(),
+                self.target_batch_size,
+                self.fetch,
+            ),
             is_closed: false,
             baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
         }))
@@ -196,21 +226,12 @@ impl ExecutionPlan for CoalesceBatchesExec {
     }
 }
 
+/// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details.
 struct CoalesceBatchesStream {
     /// The input plan
     input: SendableRecordBatchStream,
-    /// The input schema
-    schema: SchemaRef,
-    /// Minimum number of rows for coalesces batches
-    target_batch_size: usize,
-    /// Maximum number of rows to fetch, `None` means fetching all rows
-    fetch: Option<usize>,
-    /// Buffered batches
-    buffer: Vec<RecordBatch>,
-    /// Buffered row count
-    buffered_rows: usize,
-    /// Total number of rows returned
-    total_rows: usize,
+    /// Buffer for combining batches
+    coalescer: BatchCoalescer,
     /// Whether the stream has finished returning all of its data or not
     is_closed: bool,
     /// Execution metrics
@@ -249,84 +270,178 @@ impl CoalesceBatchesStream {
             let input_batch = self.input.poll_next_unpin(cx);
             // records time on drop
             let _timer = cloned_time.timer();
-            match input_batch {
-                Poll::Ready(x) => match x {
-                    Some(Ok(batch)) => {
-                        let batch = gc_string_view_batch(&batch);
-
-                        // Handle fetch limit:
-                        if let Some(fetch) = self.fetch {
-                            if self.total_rows + batch.num_rows() >= fetch {
-                                // We have reached the fetch limit.
-                                let remaining_rows = fetch - self.total_rows;
-                                debug_assert!(remaining_rows > 0);
-
+            match ready!(input_batch) {
+                Some(result) => {
+                    let Ok(input_batch) = result else {
+                        return Poll::Ready(Some(result)); // pass back error
+                    };
+                    // Buffer the batch, then either poll for more input (not
+                    // enough rows yet) or emit the coalesced batch
+                    match self.coalescer.push_batch(input_batch) {
+                        Ok(None) => continue,
+                        res => {
+                            if self.coalescer.limit_reached() {
                                 self.is_closed = true;
-                                self.total_rows = fetch;
-                                // Trim the batch and add to buffered batches:
-                                let batch = batch.slice(0, remaining_rows);
-                                self.buffered_rows += batch.num_rows();
-                                self.buffer.push(batch);
-                                // Combine buffered batches:
-                                let batch = concat_batches(&self.schema, 
&self.buffer)?;
-                                // Reset the buffer state and return final batch:
-                                self.buffer.clear();
-                                self.buffered_rows = 0;
-                                return Poll::Ready(Some(Ok(batch)));
-                            }
-                        }
-                        self.total_rows += batch.num_rows();
-
-                        if batch.num_rows() >= self.target_batch_size
-                            && self.buffer.is_empty()
-                        {
-                            return Poll::Ready(Some(Ok(batch)));
-                        } else if batch.num_rows() == 0 {
-                            // discard empty batches
-                        } else {
-                            // add to the buffered batches
-                            self.buffered_rows += batch.num_rows();
-                            self.buffer.push(batch);
-                            // check to see if we have enough batches yet
-                            if self.buffered_rows >= self.target_batch_size {
-                                // combine the batches and return
-                                let batch = concat_batches(&self.schema, 
&self.buffer)?;
-                                // reset buffer state
-                                self.buffer.clear();
-                                self.buffered_rows = 0;
-                                // return batch
-                                return Poll::Ready(Some(Ok(batch)));
                             }
+                            return Poll::Ready(res.transpose());
                         }
                     }
-                    None => {
-                        self.is_closed = true;
-                        // we have reached the end of the input stream but there could still
-                        // be buffered batches
-                        if self.buffer.is_empty() {
-                            return Poll::Ready(None);
-                        } else {
-                            // combine the batches and return
-                            let batch = concat_batches(&self.schema, 
&self.buffer)?;
-                            // reset buffer state
-                            self.buffer.clear();
-                            self.buffered_rows = 0;
-                            // return batch
-                            return Poll::Ready(Some(Ok(batch)));
-                        }
-                    }
-                    other => return Poll::Ready(other),
-                },
-                Poll::Pending => return Poll::Pending,
+                }
+                None => {
+                    self.is_closed = true;
+                    // we have reached the end of the input stream but there could still
+                    // be buffered batches
+                    return match self.coalescer.finish() {
+                        Ok(None) => Poll::Ready(None),
+                        res => Poll::Ready(res.transpose()),
+                    };
+                }
             }
         }
     }
 }
 
 impl RecordBatchStream for CoalesceBatchesStream {
+    fn schema(&self) -> SchemaRef {
+        self.coalescer.schema()
+    }
+}
+
+/// Concatenate multiple record batches into larger batches
+///
+/// See [`CoalesceBatchesExec`] for more details.
+///
+/// Notes:
+///
+/// 1. The output rows are in the same order as the input rows
+///
+/// 2. The output is a sequence of batches, with all but the last being at least
+///    `target_batch_size` rows.
+///
+/// 3. Eventually this may also be able to handle other optimizations such as a
+///    combined filter/coalesce operation.
+#[derive(Debug)]
+struct BatchCoalescer {
+    /// The input schema
+    schema: SchemaRef,
+    /// Minimum number of rows for coalesced batches
+    target_batch_size: usize,
+    /// Total number of rows returned so far
+    total_rows: usize,
+    /// Buffered batches
+    buffer: Vec<RecordBatch>,
+    /// Buffered row count
+    buffered_rows: usize,
+    /// Maximum number of rows to fetch, `None` means fetching all rows
+    fetch: Option<usize>,
+}
+
+impl BatchCoalescer {
+    /// Create a new `BatchCoalescer`
+    ///
+    /// # Arguments
+    /// - `schema` - the schema of the output batches
+    /// - `target_batch_size` - the minimum number of rows for each
+    ///    output batch (until limit reached)
+    /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows
+    fn new(schema: SchemaRef, target_batch_size: usize, fetch: Option<usize>) 
-> Self {
+        Self {
+            schema,
+            target_batch_size,
+            total_rows: 0,
+            buffer: vec![],
+            buffered_rows: 0,
+            fetch,
+        }
+    }
+
+    /// Return the schema of the output batches
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
+
+    /// Add a batch, returning a batch if the target batch size or limit is reached
+    fn push_batch(&mut self, batch: RecordBatch) -> 
Result<Option<RecordBatch>> {
+        // discard empty batches
+        if batch.num_rows() == 0 {
+            return Ok(None);
+        }
+
+        // past limit
+        if self.limit_reached() {
+            return Ok(None);
+        }
+
+        let batch = gc_string_view_batch(&batch);
+
+        // Handle fetch limit:
+        if let Some(fetch) = self.fetch {
+            if self.total_rows + batch.num_rows() >= fetch {
+                // We have reached the fetch limit.
+                let remaining_rows = fetch - self.total_rows;
+                debug_assert!(remaining_rows > 0);
+                self.total_rows = fetch;
+                // Trim the batch and add to buffered batches:
+                let batch = batch.slice(0, remaining_rows);
+                self.buffered_rows += batch.num_rows();
+                self.buffer.push(batch);
+                // Combine buffered batches:
+                let batch = concat_batches(&self.schema, &self.buffer)?;
+                // Reset the buffer state and return final batch:
+                self.buffer.clear();
+                self.buffered_rows = 0;
+                return Ok(Some(batch));
+            }
+        }
+        self.total_rows += batch.num_rows();
+
+        // batch itself is already big enough and we have no buffered rows so
+        // return it directly
+        if batch.num_rows() >= self.target_batch_size && 
self.buffer.is_empty() {
+            return Ok(Some(batch));
+        }
+        // add to the buffered batches
+        self.buffered_rows += batch.num_rows();
+        self.buffer.push(batch);
+        // check to see if we have enough batches yet
+        let batch = if self.buffered_rows >= self.target_batch_size {
+            // combine the batches and return
+            let batch = concat_batches(&self.schema, &self.buffer)?;
+            // reset buffer state
+            self.buffer.clear();
+            self.buffered_rows = 0;
+            // return batch
+            Some(batch)
+        } else {
+            None
+        };
+        Ok(batch)
+    }
+
+    /// Finish the coalescing process, returning all buffered data as a final,
+    /// single batch, if any
+    fn finish(&mut self) -> Result<Option<RecordBatch>> {
+        if self.buffer.is_empty() {
+            Ok(None)
+        } else {
+            // combine the batches and return
+            let batch = concat_batches(&self.schema, &self.buffer)?;
+            // reset buffer state
+            self.buffer.clear();
+            self.buffered_rows = 0;
+            // return batch
+            Ok(Some(batch))
+        }
+    }
+
+    /// Returns true if there is a limit and it has been reached
+    pub fn limit_reached(&self) -> bool {
+        if let Some(fetch) = self.fetch {
+            self.total_rows >= fetch
+        } else {
+            false
+        }
+    }
 }
 
 /// Heuristically compact `StringViewArray`s to reduce memory usage, if needed
@@ -400,164 +515,206 @@ fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch {
 
 #[cfg(test)]
 mod tests {
+    use super::*;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow_array::builder::ArrayBuilder;
     use arrow_array::{StringViewArray, UInt32Array};
+    use std::ops::Range;
 
-    use crate::{memory::MemoryExec, repartition::RepartitionExec, 
Partitioning};
-
-    use super::*;
-
-    #[tokio::test(flavor = "multi_thread")]
-    async fn test_concat_batches() -> Result<()> {
-        let schema = test_schema();
-        let partition = create_vec_batches(&schema, 10);
-        let partitions = vec![partition];
-
-        let output_partitions = coalesce_batches(&schema, partitions, 21, 
None).await?;
-        assert_eq!(1, output_partitions.len());
-
-        // input is 10 batches x 8 rows (80 rows)
-        // expected output is batches of at least 20 rows (except for the final batch)
-        let batches = &output_partitions[0];
-        assert_eq!(4, batches.len());
-        assert_eq!(24, batches[0].num_rows());
-        assert_eq!(24, batches[1].num_rows());
-        assert_eq!(24, batches[2].num_rows());
-        assert_eq!(8, batches[3].num_rows());
-
-        Ok(())
+    #[test]
+    fn test_coalesce() {
+        let batch = uint32_batch(0..8);
+        Test::new()
+            .with_batches(std::iter::repeat(batch).take(10))
+            // expected output is batches of at least 21 rows (except for the final batch)
+            .with_target_batch_size(21)
+            .with_expected_output_sizes(vec![24, 24, 24, 8])
+            .run()
     }
 
-    #[tokio::test]
-    async fn test_concat_batches_with_fetch_larger_than_input_size() -> 
Result<()> {
-        let schema = test_schema();
-        let partition = create_vec_batches(&schema, 10);
-        let partitions = vec![partition];
-
-        let output_partitions =
-            coalesce_batches(&schema, partitions, 21, Some(100)).await?;
-        assert_eq!(1, output_partitions.len());
+    #[test]
+    fn test_coalesce_with_fetch_larger_than_input_size() {
+        let batch = uint32_batch(0..8);
+        Test::new()
+            .with_batches(std::iter::repeat(batch).take(10))
+            // input is 10 batches x 8 rows (80 rows) with fetch limit of 100
+            // expected to behave the same as `test_coalesce`
+            .with_target_batch_size(21)
+            .with_fetch(Some(100))
+            .with_expected_output_sizes(vec![24, 24, 24, 8])
+            .run();
+    }
 
-        // input is 10 batches x 8 rows (80 rows) with fetch limit of 100
-        // expected to behave the same as `test_concat_batches`
-        let batches = &output_partitions[0];
-        assert_eq!(4, batches.len());
-        assert_eq!(24, batches[0].num_rows());
-        assert_eq!(24, batches[1].num_rows());
-        assert_eq!(24, batches[2].num_rows());
-        assert_eq!(8, batches[3].num_rows());
+    #[test]
+    fn test_coalesce_with_fetch_less_than_input_size() {
+        let batch = uint32_batch(0..8);
+        Test::new()
+            .with_batches(std::iter::repeat(batch).take(10))
+            // input is 10 batches x 8 rows (80 rows) with fetch limit of 50
+            .with_target_batch_size(21)
+            .with_fetch(Some(50))
+            .with_expected_output_sizes(vec![24, 24, 2])
+            .run();
+    }
 
-        Ok(())
+    #[test]
+    fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() {
+        let batch = uint32_batch(0..8);
+        Test::new()
+            .with_batches(std::iter::repeat(batch).take(10))
+            // input is 10 batches x 8 rows (80 rows) with fetch limit of 48
+            .with_target_batch_size(21)
+            .with_fetch(Some(48))
+            .with_expected_output_sizes(vec![24, 24])
+            .run();
     }
 
-    #[tokio::test]
-    async fn test_concat_batches_with_fetch_less_than_input_size() -> 
Result<()> {
-        let schema = test_schema();
-        let partition = create_vec_batches(&schema, 10);
-        let partitions = vec![partition];
+    #[test]
+    fn test_coalesce_with_fetch_less_target_batch_size() {
+        let batch = uint32_batch(0..8);
+        Test::new()
+            .with_batches(std::iter::repeat(batch).take(10))
+            // input is 10 batches x 8 rows (80 rows) with fetch limit of 10
+            .with_target_batch_size(21)
+            .with_fetch(Some(10))
+            .with_expected_output_sizes(vec![10])
+            .run();
+    }
 
-        let output_partitions =
-            coalesce_batches(&schema, partitions, 21, Some(50)).await?;
-        assert_eq!(1, output_partitions.len());
+    #[test]
+    fn test_coalesce_single_large_batch_over_fetch() {
+        let large_batch = uint32_batch(0..100);
+        Test::new()
+            .with_batch(large_batch)
+            .with_target_batch_size(20)
+            .with_fetch(Some(7))
+            .with_expected_output_sizes(vec![7])
+            .run()
+    }
+
+    /// Test for [`BatchCoalescer`]
+    ///
+    /// Pushes the input batches to the coalescer and verifies that the resulting
+    /// batches have the expected number of rows and contents.
+    #[derive(Debug, Clone, Default)]
+    struct Test {
+        /// Batches to feed to the coalescer. Tests must have at least one
+        /// batch (the schema is taken from the first one)
+        input_batches: Vec<RecordBatch>,
+        /// Expected output sizes of the resulting batches
+        expected_output_sizes: Vec<usize>,
+        /// target batch size
+        target_batch_size: usize,
+        /// Fetch (limit)
+        fetch: Option<usize>,
+    }
 
-        // input is 10 batches x 8 rows (80 rows) with fetch limit of 50
-        let batches = &output_partitions[0];
-        assert_eq!(3, batches.len());
-        assert_eq!(24, batches[0].num_rows());
-        assert_eq!(24, batches[1].num_rows());
-        assert_eq!(2, batches[2].num_rows());
+    impl Test {
+        fn new() -> Self {
+            Self::default()
+        }
 
-        Ok(())
-    }
+        /// Set the target batch size
+        fn with_target_batch_size(mut self, target_batch_size: usize) -> Self {
+            self.target_batch_size = target_batch_size;
+            self
+        }
 
-    #[tokio::test]
-    async fn 
test_concat_batches_with_fetch_less_than_target_and_no_remaining_rows(
-    ) -> Result<()> {
-        let schema = test_schema();
-        let partition = create_vec_batches(&schema, 10);
-        let partitions = vec![partition];
+        /// Set the fetch (limit)
+        fn with_fetch(mut self, fetch: Option<usize>) -> Self {
+            self.fetch = fetch;
+            self
+        }
 
-        let output_partitions =
-            coalesce_batches(&schema, partitions, 21, Some(48)).await?;
-        assert_eq!(1, output_partitions.len());
+        /// Extend the input batches with `batch`
+        fn with_batch(mut self, batch: RecordBatch) -> Self {
+            self.input_batches.push(batch);
+            self
+        }
 
-        // input is 10 batches x 8 rows (80 rows) with fetch limit of 48
-        let batches = &output_partitions[0];
-        assert_eq!(2, batches.len());
-        assert_eq!(24, batches[0].num_rows());
-        assert_eq!(24, batches[1].num_rows());
+        /// Extends the input batches with `batches`
+        fn with_batches(
+            mut self,
+            batches: impl IntoIterator<Item = RecordBatch>,
+        ) -> Self {
+            self.input_batches.extend(batches);
+            self
+        }
 
-        Ok(())
-    }
+        /// Extends the expected output sizes with `sizes`
+        fn with_expected_output_sizes(
+            mut self,
+            sizes: impl IntoIterator<Item = usize>,
+        ) -> Self {
+            self.expected_output_sizes.extend(sizes);
+            self
+        }
 
-    #[tokio::test]
-    async fn test_concat_batches_with_fetch_less_target_batch_size() -> 
Result<()> {
-        let schema = test_schema();
-        let partition = create_vec_batches(&schema, 10);
-        let partitions = vec![partition];
+        /// Runs the test -- see documentation on [`Test`] for details
+        fn run(self) {
+            let Self {
+                input_batches,
+                target_batch_size,
+                fetch,
+                expected_output_sizes,
+            } = self;
 
-        let output_partitions =
-            coalesce_batches(&schema, partitions, 21, Some(10)).await?;
-        assert_eq!(1, output_partitions.len());
+            let schema = input_batches[0].schema();
 
-        // input is 10 batches x 8 rows (80 rows) with fetch limit of 10
-        let batches = &output_partitions[0];
-        assert_eq!(1, batches.len());
-        assert_eq!(10, batches[0].num_rows());
+            // create a single large input batch for output comparison
+            let single_input_batch = concat_batches(&schema, 
&input_batches).unwrap();
 
-        Ok(())
-    }
+            let mut coalescer = BatchCoalescer::new(schema, target_batch_size, 
fetch);
 
-    fn test_schema() -> Arc<Schema> {
-        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
-    }
+            let mut output_batches = vec![];
+            for batch in input_batches {
+                if let Some(batch) = coalescer.push_batch(batch).unwrap() {
+                    output_batches.push(batch);
+                }
+            }
+            if let Some(batch) = coalescer.finish().unwrap() {
+                output_batches.push(batch);
+            }
 
-    async fn coalesce_batches(
-        schema: &SchemaRef,
-        input_partitions: Vec<Vec<RecordBatch>>,
-        target_batch_size: usize,
-        fetch: Option<usize>,
-    ) -> Result<Vec<Vec<RecordBatch>>> {
-        // create physical plan
-        let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), 
None)?;
-        let exec =
-            RepartitionExec::try_new(Arc::new(exec), 
Partitioning::RoundRobinBatch(1))?;
-        let exec: Arc<dyn ExecutionPlan> = Arc::new(
-            CoalesceBatchesExec::new(Arc::new(exec), 
target_batch_size).with_fetch(fetch),
-        );
-
-        // execute and collect results
-        let output_partition_count = 
exec.output_partitioning().partition_count();
-        let mut output_partitions = Vec::with_capacity(output_partition_count);
-        for i in 0..output_partition_count {
-            // execute this *output* partition and collect all batches
-            let task_ctx = Arc::new(TaskContext::default());
-            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
-            let mut batches = vec![];
-            while let Some(result) = stream.next().await {
-                batches.push(result?);
+            // make sure we got the expected number of output batches and content
+            let mut starting_idx = 0;
+            assert_eq!(expected_output_sizes.len(), output_batches.len());
+            for (i, (expected_size, batch)) in
+                expected_output_sizes.iter().zip(output_batches).enumerate()
+            {
+                assert_eq!(
+                    *expected_size,
+                    batch.num_rows(),
+                    "Unexpected number of rows in Batch {i}"
+                );
+
+                // compare the contents of the batch via their pretty-printed
+                // form (using `==` would compare the underlying memory layout too)
+                let expected_batch =
+                    single_input_batch.slice(starting_idx, *expected_size);
+                let batch_strings = batch_to_pretty_strings(&batch);
+                let expected_batch_strings = 
batch_to_pretty_strings(&expected_batch);
+                let batch_strings = batch_strings.lines().collect::<Vec<_>>();
+                let expected_batch_strings =
+                    expected_batch_strings.lines().collect::<Vec<_>>();
+                assert_eq!(
+                    expected_batch_strings, batch_strings,
+                    "Unexpected content in Batch {i}:\
+                    
\n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}"
+                );
+                starting_idx += *expected_size;
             }
-            output_partitions.push(batches);
         }
-        Ok(output_partitions)
     }
 
-    /// Create vector batches
-    fn create_vec_batches(schema: &Schema, n: usize) -> Vec<RecordBatch> {
-        let batch = create_batch(schema);
-        let mut vec = Vec::with_capacity(n);
-        for _ in 0..n {
-            vec.push(batch.clone());
-        }
-        vec
-    }
+    /// Return a batch with a single UInt32 column containing the values in `range`
+    fn uint32_batch(range: Range<u32>) -> RecordBatch {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, 
false)]));
 
-    /// Create batch
-    fn create_batch(schema: &Schema) -> RecordBatch {
         RecordBatch::try_new(
-            Arc::new(schema.clone()),
-            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
+            Arc::clone(&schema),
+            vec![Arc::new(UInt32Array::from_iter_values(range))],
         )
         .unwrap()
     }
@@ -656,4 +813,9 @@ mod tests {
             }
         }
     }
+    fn batch_to_pretty_strings(batch: &RecordBatch) -> String {
+        arrow::util::pretty::pretty_format_batches(&[batch.clone()])
+            .unwrap()
+            .to_string()
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to