This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 2266474 Fix bug while merging `RecordBatch`, add
`SortPreservingMerge` fuzz tester (#1678)
2266474 is described below
commit 2266474ff5312e822153fefd7d7eddd1ca66c136
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jan 27 06:10:33 2022 -0500
Fix bug while merging `RecordBatch`, add `SortPreservingMerge` fuzz tester
(#1678)
* skip empty batch while inserting
* `SortPreservingMerge` fuzz testing
Co-authored-by: Yijie Shen <[email protected]>
---
.../physical_plan/sorts/sort_preserving_merge.rs | 69 ++++---
datafusion/tests/merge_fuzz.rs | 223 +++++++++++++++++++++
2 files changed, 264 insertions(+), 28 deletions(-)
diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
index d6a5787..f950526 100644
--- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -410,40 +410,53 @@ impl SortPreservingMergeStream {
// Cursor is not finished - don't need a new RecordBatch yet
return Poll::Ready(Ok(()));
}
- let mut streams = self.streams.streams.lock().unwrap();
+ let mut empty_batch = false;
+ {
+ let mut streams = self.streams.streams.lock().unwrap();
- let stream = &mut streams[idx];
- if stream.is_terminated() {
- return Poll::Ready(Ok(()));
- }
-
- // Fetch a new input record and create a cursor from it
- match futures::ready!(stream.poll_next_unpin(cx)) {
- None => return Poll::Ready(Ok(())),
- Some(Err(e)) => {
- return Poll::Ready(Err(e));
+ let stream = &mut streams[idx];
+ if stream.is_terminated() {
+ return Poll::Ready(Ok(()));
}
- Some(Ok(batch)) => {
- let cursor = match SortKeyCursor::new(
- idx,
- self.next_batch_id, // assign this batch an ID
- &batch,
- &self.column_expressions,
- self.sort_options.clone(),
- ) {
- Ok(cursor) => cursor,
- Err(e) => {
- return
Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
+
+ // Fetch a new input record and create a cursor from it
+ match futures::ready!(stream.poll_next_unpin(cx)) {
+ None => return Poll::Ready(Ok(())),
+ Some(Err(e)) => {
+ return Poll::Ready(Err(e));
+ }
+ Some(Ok(batch)) => {
+ if batch.num_rows() > 0 {
+ let cursor = match SortKeyCursor::new(
+ idx,
+ self.next_batch_id, // assign this batch an ID
+ &batch,
+ &self.column_expressions,
+ self.sort_options.clone(),
+ ) {
+ Ok(cursor) => cursor,
+ Err(e) => {
+ return
Poll::Ready(Err(ArrowError::ExternalError(
+ Box::new(e),
+ )));
+ }
+ };
+ self.next_batch_id += 1;
+ self.min_heap.push(cursor);
+ self.cursor_finished[idx] = false;
+ self.batches[idx].push_back(batch)
+ } else {
+ empty_batch = true;
}
- };
- self.next_batch_id += 1;
- self.min_heap.push(cursor);
- self.cursor_finished[idx] = false;
- self.batches[idx].push_back(batch)
+ }
}
}
- Poll::Ready(Ok(()))
+ if empty_batch {
+ self.maybe_poll_stream(cx, idx)
+ } else {
+ Poll::Ready(Ok(()))
+ }
}
/// Drains the in_progress row indexes, and builds a new RecordBatch from
them
diff --git a/datafusion/tests/merge_fuzz.rs b/datafusion/tests/merge_fuzz.rs
new file mode 100644
index 0000000..8192054
--- /dev/null
+++ b/datafusion/tests/merge_fuzz.rs
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fuzz Test for various corner cases merging streams of RecordBatches
+use std::sync::Arc;
+
+use arrow::{
+ array::{ArrayRef, Int32Array},
+ compute::SortOptions,
+ record_batch::RecordBatch,
+};
+use datafusion::{
+ execution::runtime_env::{RuntimeConfig, RuntimeEnv},
+ physical_plan::{
+ collect,
+ expressions::{col, PhysicalSortExpr},
+ memory::MemoryExec,
+ sorts::sort_preserving_merge::SortPreservingMergeExec,
+ },
+};
+use rand::{prelude::StdRng, Rng, SeedableRng};
+
+#[tokio::test]
+async fn test_merge_2() {
+ run_merge_test(vec![
+ // (0..100)
+ // (0..100)
+ make_staggered_batches(0, 100, 2),
+ make_staggered_batches(0, 100, 3),
+ ])
+ .await
+}
+
+#[tokio::test]
+async fn test_merge_2_no_overlap() {
+ run_merge_test(vec![
+ // (0..20)
+ // (20..40)
+ make_staggered_batches(0, 20, 2),
+ make_staggered_batches(20, 40, 3),
+ ])
+ .await
+}
+
+#[tokio::test]
+async fn test_merge_3() {
+ run_merge_test(vec![
+ // (0 .. 100)
+ // (0 .. 100)
+ // (0 .. 51)
+ make_staggered_batches(0, 100, 2),
+ make_staggered_batches(0, 100, 3),
+ make_staggered_batches(0, 51, 4),
+ ])
+ .await
+}
+
+#[tokio::test]
+async fn test_merge_3_gaps() {
+ run_merge_test(vec![
+ // (0 .. 50)(50 .. 100)
+ // (0 .. 33) (50 .. 123)
+ // (0 .. 51)
+ concat(
+ make_staggered_batches(0, 50, 2),
+ make_staggered_batches(50, 100, 7),
+ ),
+ concat(
+ make_staggered_batches(0, 33, 21),
+ make_staggered_batches(50, 123, 31),
+ ),
+ make_staggered_batches(0, 51, 11),
+ ])
+ .await
+}
+
+/// Merge a set of input streams using SortPreservingMergeExec and
+/// `Vec::sort` and ensure the results are the same.
+///
+/// For each case, the `input` streams are turned into a set of
+/// streams which are then merged together by [SortPreservingMerge]
+///
+/// Each `Vec<RecordBatch>` in `input` must be sorted and have a
+/// single Int32 field named 'x'.
+async fn run_merge_test(input: Vec<Vec<RecordBatch>>) {
+ // Produce output with the specified output batch sizes
+ let batch_sizes = [1, 2, 7, 49, 50, 51, 100];
+
+ for batch_size in batch_sizes {
+ let first_batch = input
+ .iter()
+ .map(|p| p.iter())
+ .flatten()
+ .next()
+ .expect("at least one batch");
+ let schema = first_batch.schema();
+
+ let sort = vec![PhysicalSortExpr {
+ expr: col("x", &schema).unwrap(),
+ options: SortOptions {
+ descending: false,
+ nulls_first: true,
+ },
+ }];
+
+ let exec = MemoryExec::try_new(&input, schema, None).unwrap();
+ let merge = Arc::new(SortPreservingMergeExec::new(sort,
Arc::new(exec)));
+
+ let runtime_config = RuntimeConfig::new().with_batch_size(batch_size);
+
+ let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap());
+ let collected = collect(merge, runtime).await.unwrap();
+
+ // verify the output batch size: all batches except the last
+ // should contain `batch_size` rows
+ for (i, batch) in collected.iter().enumerate() {
+ if i < collected.len() - 1 {
+ assert_eq!(
+ batch.num_rows(),
+ batch_size,
+ "Expected batch {} to have {} rows, got {}",
+ i,
+ batch_size,
+ batch.num_rows()
+ );
+ }
+ }
+
+ let expected = partitions_to_sorted_vec(&input);
+ let actual = batches_to_vec(&collected);
+
+ assert_eq!(expected, actual, "failure in @ batch_size {}", batch_size);
+ }
+}
+
+/// Extracts the i32 values from the set of batches and returns them as a
single Vec
+fn batches_to_vec(batches: &[RecordBatch]) -> Vec<Option<i32>> {
+ batches
+ .iter()
+ .map(|batch| {
+ assert_eq!(batch.num_columns(), 1);
+ batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .unwrap()
+ .iter()
+ })
+ .flatten()
+ .collect()
+}
+
+// extract values from batches and sort them
+fn partitions_to_sorted_vec(partitions: &[Vec<RecordBatch>]) ->
Vec<Option<i32>> {
+ let mut values: Vec<_> = partitions
+ .iter()
+ .map(|batches| batches_to_vec(batches).into_iter())
+ .flatten()
+ .collect();
+
+ values.sort_unstable();
+ values
+}
+
+/// Return the values `low..high` in order, in randomly sized
+/// record batches in a field named 'x' of type `Int32`
+fn make_staggered_batches(low: i32, high: i32, seed: u64) -> Vec<RecordBatch> {
+ let input: Int32Array = (low..high).map(Some).collect();
+
+ // split into several record batches
+ let mut remainder =
+ RecordBatch::try_from_iter(vec![("x", Arc::new(input) as
ArrayRef)]).unwrap();
+
+ let mut batches = vec![];
+
+ // use a random number generator to pick a random sized output
+ let mut rng = StdRng::seed_from_u64(seed);
+ while remainder.num_rows() > 0 {
+ let batch_size = rng.gen_range(0..remainder.num_rows() + 1);
+
+ batches.push(remainder.slice(0, batch_size));
+ remainder = remainder.slice(batch_size, remainder.num_rows() -
batch_size);
+ }
+
+ add_empty_batches(batches, &mut rng)
+}
+
+/// Adds a random number of empty record batches into the stream
+fn add_empty_batches(batches: Vec<RecordBatch>, rng: &mut StdRng) ->
Vec<RecordBatch> {
+ let schema = batches[0].schema();
+
+ batches
+ .into_iter()
+ .map(|batch| {
+ // insert 0, or 1 empty batches before and after the current batch
+ let empty_batch = RecordBatch::new_empty(schema.clone());
+ std::iter::repeat(empty_batch.clone())
+ .take(rng.gen_range(0..2))
+ .chain(std::iter::once(batch))
+
.chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2)))
+ })
+ .flatten()
+ .collect()
+}
+
+fn concat(mut v1: Vec<RecordBatch>, v2: Vec<RecordBatch>) -> Vec<RecordBatch> {
+ v1.extend(v2);
+ v1
+}