This is an automated email from the ASF dual-hosted git repository.
ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 63ce486589 Chore: Do not return empty record batches from streams (#13794)
63ce486589 is described below
commit 63ce4865896b906ca34fcbf85fdc55bff3080c30
Author: mertak-synnada <[email protected]>
AuthorDate: Wed Dec 18 08:30:03 2024 +0000
Chore: Do not return empty record batches from streams (#13794)
* do not emit empty record batches in plans
* change function signatures to Option<RecordBatch> if empty batches are possible
* format code
* shorten code
* change list_unnest_at_level for returning Option value
* add documentation
* move concat_batches back into the compute_aggregates function
* create unit test for row_hash.rs
* add test for unnest
* add test for unnest
* add test for partial sort
* add test for bounded window agg
* add test for window agg
* apply simplifications and fix typo
* apply simplifications and fix typo
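
In essence, every stream operator touched by this patch stops returning
`RecordBatch::new_empty(schema)` and instead returns `Option<RecordBatch>`,
with callers skipping `None` rather than forwarding zero-row batches. Below
is a minimal, self-contained sketch of that pattern (the `produce_batch`
helper and the `c1` schema are illustrative stand-ins, not names from this
commit):

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use arrow::error::Result;
    use arrow::record_batch::RecordBatch;

    // Hypothetical producer showing the new signature: return `None`
    // instead of `RecordBatch::new_empty(schema)` when there is nothing
    // to emit, so callers can skip the batch entirely.
    fn produce_batch(schema: SchemaRef, values: Vec<i32>) -> Result<Option<RecordBatch>> {
        if values.is_empty() {
            return Ok(None); // previously: an empty RecordBatch
        }
        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(values))];
        let batch = RecordBatch::try_new(schema, columns)?;
        debug_assert!(batch.num_rows() > 0); // invariant the patch enforces
        Ok(Some(batch))
    }

    fn main() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)]));
        // A stream-like caller forwards only non-empty batches downstream.
        for input in [vec![], vec![1, 2, 3]] {
            if let Some(batch) = produce_batch(Arc::clone(&schema), input)? {
                println!("emitting {} rows", batch.num_rows());
            } // `None` is simply skipped instead of emitting an empty batch
        }
        Ok(())
    }

The same shape appears below in `GroupedHashAggregateStream::emit`,
`TopKHeap::emit`, and the unnest and window streams.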
---
datafusion/core/src/dataframe/mod.rs | 24 ++++++
datafusion/core/tests/dataframe/mod.rs | 43 ++++++++++
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 8 +-
.../physical-plan/src/aggregates/row_hash.rs | 46 ++++++----
datafusion/physical-plan/src/sorts/partial_sort.rs | 97 +++++++++++++++-------
datafusion/physical-plan/src/topk/mod.rs | 36 ++++----
datafusion/physical-plan/src/unnest.rs | 62 +++++++-------
.../src/windows/bounded_window_agg_exec.rs | 23 +++--
.../physical-plan/src/windows/window_agg_exec.rs | 28 ++++---
9 files changed, 256 insertions(+), 111 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 82ee52d7b2..414d6da7bc 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2380,6 +2380,30 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn aggregate_assert_no_empty_batches() -> Result<()> {
+ // build plan using DataFrame API
+ let df = test_table().await?;
+ let group_expr = vec![col("c1")];
+ let aggr_expr = vec![
+ min(col("c12")),
+ max(col("c12")),
+ avg(col("c12")),
+ sum(col("c12")),
+ count(col("c12")),
+ count_distinct(col("c12")),
+ median(col("c12")),
+ ];
+
+ let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
+ // Empty batches should not be produced
+ for batch in df {
+ assert!(batch.num_rows() > 0);
+ }
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_aggregate_with_pk() -> Result<()> {
// create the dataframe
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 7c3c96025b..f4f754b11c 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1246,6 +1246,43 @@ async fn unnest_aggregate_columns() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn unnest_no_empty_batches() -> Result<()> {
+ let mut shape_id_builder = UInt32Builder::new();
+ let mut tag_id_builder = UInt32Builder::new();
+
+ for shape_id in 1..=10 {
+ for tag_id in 1..=10 {
+ shape_id_builder.append_value(shape_id as u32);
+ tag_id_builder.append_value((shape_id * 10 + tag_id) as u32);
+ }
+ }
+
+ let batch = RecordBatch::try_from_iter(vec![
+ ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef),
+ ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef),
+ ])?;
+
+ let ctx = SessionContext::new();
+ ctx.register_batch("shapes", batch)?;
+ let df = ctx.table("shapes").await?;
+
+ let results = df
+ .clone()
+ .aggregate(
+ vec![col("shape_id")],
+ vec![array_agg(col("tag_id")).alias("tag_id")],
+ )?
+ .collect()
+ .await?;
+
+ // Assert that there are no empty batches in the results
+ for rb in results {
+ assert!(rb.num_rows() > 0);
+ }
+ Ok(())
+}
+
#[tokio::test]
async fn unnest_array_agg() -> Result<()> {
let mut shape_id_builder = UInt32Builder::new();
@@ -1268,6 +1305,12 @@ async fn unnest_array_agg() -> Result<()> {
let df = ctx.table("shapes").await?;
let results = df.clone().collect().await?;
+
+ // Assert that there are no empty batches in the results
+ for rb in results.clone() {
+ assert!(rb.num_rows() > 0);
+ }
+
let expected = vec![
"+----------+--------+",
"| shape_id | tag_id |",
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index eaa84988a8..67666f5d7a 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -299,9 +299,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
Linear,
)?);
let task_ctx = ctx.task_ctx();
- let mut collected_results =
- collect(running_window_exec, task_ctx).await?;
- collected_results.retain(|batch| batch.num_rows() > 0);
+ let collected_results = collect(running_window_exec, task_ctx).await?;
let input_batch_sizes = batches
.iter()
.map(|batch| batch.num_rows())
@@ -310,6 +308,8 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
.iter()
.map(|batch| batch.num_rows())
.collect::<Vec<_>>();
+ // There should be no empty batches in the results
+ assert!(result_batch_sizes.iter().all(|e| *e > 0));
if causal {
// For causal window frames, we can generate results immediately
// for each input batch. Hence, batch sizes should match.
@@ -688,8 +688,8 @@ async fn run_window_test(
let collected_running = collect(running_window_exec, task_ctx)
.await?
.into_iter()
- .filter(|b| b.num_rows() > 0)
.collect::<Vec<_>>();
+ assert!(collected_running.iter().all(|rb| rb.num_rows() > 0));
// BoundedWindowAggExec should produce more chunks than the usual WindowAggExec.
// Otherwise it means that we cannot generate results in running mode.
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 965adbb8c7..c261310f56 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -654,9 +654,13 @@ impl Stream for GroupedHashAggregateStream {
}
if let Some(to_emit) = self.group_ordering.emit_to() {
- let batch = extract_ok!(self.emit(to_emit, false));
- self.exec_state = ExecutionState::ProducingOutput(batch);
timer.done();
+ if let Some(batch) =
+ extract_ok!(self.emit(to_emit, false))
+ {
+ self.exec_state =
+ ExecutionState::ProducingOutput(batch);
+ };
// make sure the exec_state just set is not overwritten below
break 'reading_input;
}
@@ -693,9 +697,13 @@ impl Stream for GroupedHashAggregateStream {
}
if let Some(to_emit) = self.group_ordering.emit_to() {
- let batch = extract_ok!(self.emit(to_emit, false));
- self.exec_state = ExecutionState::ProducingOutput(batch);
timer.done();
+ if let Some(batch) =
+ extract_ok!(self.emit(to_emit, false))
+ {
+ self.exec_state =
+ ExecutionState::ProducingOutput(batch);
+ };
// make sure the exec_state just set is not overwritten below
break 'reading_input;
}
@@ -768,6 +776,9 @@ impl Stream for GroupedHashAggregateStream {
let output = batch.slice(0, size);
(ExecutionState::ProducingOutput(remaining), output)
};
+ // Empty record batches should not be emitted.
+ // They need to be treated as [`Option<RecordBatch>`]es and handled separately
+ debug_assert!(output_batch.num_rows() > 0);
return Poll::Ready(Some(Ok(
output_batch.record_output(&self.baseline_metrics)
)));
@@ -902,14 +913,14 @@ impl GroupedHashAggregateStream {
/// Create an output RecordBatch with the group keys and
/// accumulator states/values specified in emit_to
- fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
+ fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
let schema = if spilling {
Arc::clone(&self.spill_state.spill_schema)
} else {
self.schema()
};
if self.group_values.is_empty() {
- return Ok(RecordBatch::new_empty(schema));
+ return Ok(None);
}
let mut output = self.group_values.emit(emit_to)?;
@@ -937,7 +948,8 @@ impl GroupedHashAggregateStream {
// over the target memory size after emission, we can emit again rather than returning Err.
let _ = self.update_memory_reservation();
let batch = RecordBatch::try_new(schema, output)?;
- Ok(batch)
+ debug_assert!(batch.num_rows() > 0);
+ Ok(Some(batch))
}
/// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
@@ -963,7 +975,9 @@ impl GroupedHashAggregateStream {
/// Emit all rows, sort them, and store them on disk.
fn spill(&mut self) -> Result<()> {
- let emit = self.emit(EmitTo::All, true)?;
+ let Some(emit) = self.emit(EmitTo::All, true)? else {
+ return Ok(());
+ };
let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;
let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
// TODO: slice large `sorted` and write to multiple files in parallel
@@ -1008,8 +1022,9 @@ impl GroupedHashAggregateStream {
{
assert_eq!(self.mode, AggregateMode::Partial);
let n = self.group_values.len() / self.batch_size * self.batch_size;
- let batch = self.emit(EmitTo::First(n), false)?;
- self.exec_state = ExecutionState::ProducingOutput(batch);
+ if let Some(batch) = self.emit(EmitTo::First(n), false)? {
+ self.exec_state = ExecutionState::ProducingOutput(batch);
+ };
}
Ok(())
}
@@ -1019,7 +1034,9 @@ impl GroupedHashAggregateStream {
/// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
/// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
fn update_merged_stream(&mut self) -> Result<()> {
- let batch = self.emit(EmitTo::All, true)?;
+ let Some(batch) = self.emit(EmitTo::All, true)? else {
+ return Ok(());
+ };
// clear up memory for streaming_merge
self.clear_all();
self.update_memory_reservation()?;
@@ -1067,7 +1084,7 @@ impl GroupedHashAggregateStream {
let timer = elapsed_compute.timer();
self.exec_state = if self.spill_state.spills.is_empty() {
let batch = self.emit(EmitTo::All, false)?;
- ExecutionState::ProducingOutput(batch)
+ batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
} else {
// If spill files exist, stream-merge them.
self.update_merged_stream()?;
@@ -1096,8 +1113,9 @@ impl GroupedHashAggregateStream {
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
if probe.should_skip() {
- let batch = self.emit(EmitTo::All, false)?;
- self.exec_state = ExecutionState::ProducingOutput(batch);
+ if let Some(batch) = self.emit(EmitTo::All, false)? {
+ self.exec_state = ExecutionState::ProducingOutput(batch);
+ };
}
}
diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs
index dde19f46cd..77636f9e49 100644
--- a/datafusion/physical-plan/src/sorts/partial_sort.rs
+++ b/datafusion/physical-plan/src/sorts/partial_sort.rs
@@ -363,31 +363,31 @@ impl PartialSortStream {
if self.is_closed {
return Poll::Ready(None);
}
- let result = match ready!(self.input.poll_next_unpin(cx)) {
- Some(Ok(batch)) => {
- if let Some(slice_point) =
- self.get_slice_point(self.common_prefix_length, &batch)?
- {
- self.in_mem_batches.push(batch.slice(0, slice_point));
- let remaining_batch =
- batch.slice(slice_point, batch.num_rows() - slice_point);
- let sorted_batch = self.sort_in_mem_batches();
- self.in_mem_batches.push(remaining_batch);
- sorted_batch
- } else {
- self.in_mem_batches.push(batch);
- Ok(RecordBatch::new_empty(self.schema()))
+ loop {
+ return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
+ Some(Ok(batch)) => {
+ if let Some(slice_point) =
+ self.get_slice_point(self.common_prefix_length, &batch)?
+ {
+ self.in_mem_batches.push(batch.slice(0, slice_point));
+ let remaining_batch =
+ batch.slice(slice_point, batch.num_rows() - slice_point);
+ let sorted_batch = self.sort_in_mem_batches();
+ self.in_mem_batches.push(remaining_batch);
+ sorted_batch
+ } else {
+ self.in_mem_batches.push(batch);
+ continue;
+ }
}
- }
- Some(Err(e)) => Err(e),
- None => {
- self.is_closed = true;
- // once input is consumed, sort the rest of the inserted batches
- self.sort_in_mem_batches()
- }
- };
-
- Poll::Ready(Some(result))
+ Some(Err(e)) => Err(e),
+ None => {
+ self.is_closed = true;
+ // once input is consumed, sort the rest of the inserted batches
+ self.sort_in_mem_batches()
+ }
+ }));
+ }
}
/// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
@@ -407,6 +407,9 @@ impl PartialSortStream {
self.is_closed = true;
}
}
+ // Empty record batches should not be emitted.
+ // They need to be treated as [`Option<RecordBatch>`]es and handled separately
+ debug_assert!(result.num_rows() > 0);
Ok(result)
}
@@ -731,7 +734,7 @@ mod tests {
let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
assert_eq!(
result.iter().map(|r| r.num_rows()).collect_vec(),
- [0, 125, 125, 0, 150]
+ [125, 125, 150]
);
assert_eq!(
@@ -760,10 +763,10 @@ mod tests {
nulls_first: false,
};
for (fetch_size, expected_batch_num_rows) in [
- (Some(50), vec![0, 50]),
- (Some(120), vec![0, 120]),
- (Some(150), vec![0, 125, 25]),
- (Some(250), vec![0, 125, 125]),
+ (Some(50), vec![50]),
+ (Some(120), vec![120]),
+ (Some(150), vec![125, 25]),
+ (Some(250), vec![125, 125]),
] {
let partial_sort_executor = PartialSortExec::new(
LexOrdering::new(vec![
@@ -810,6 +813,42 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_partial_sort_no_empty_batches() -> Result<()> {
+ let task_ctx = Arc::new(TaskContext::default());
+ let mem_exec = prepare_partitioned_input();
+ let schema = mem_exec.schema();
+ let option_asc = SortOptions {
+ descending: false,
+ nulls_first: false,
+ };
+ let fetch_size = Some(250);
+ let partial_sort_executor = PartialSortExec::new(
+ LexOrdering::new(vec![
+ PhysicalSortExpr {
+ expr: col("a", &schema)?,
+ options: option_asc,
+ },
+ PhysicalSortExpr {
+ expr: col("c", &schema)?,
+ options: option_asc,
+ },
+ ]),
+ Arc::clone(&mem_exec),
+ 1,
+ )
+ .with_fetch(fetch_size);
+
+ let partial_sort_exec =
+ Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
+ let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
+ for rb in result {
+ assert!(rb.num_rows() > 0);
+ }
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_sort_metadata() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs
index 0f722ec143..c8842965ac 100644
--- a/datafusion/physical-plan/src/topk/mod.rs
+++ b/datafusion/physical-plan/src/topk/mod.rs
@@ -200,21 +200,22 @@ impl TopK {
} = self;
let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop
- let mut batch = heap.emit()?;
- metrics.baseline.output_rows().add(batch.num_rows());
-
// break into record batches as needed
let mut batches = vec![];
- loop {
- if batch.num_rows() <= batch_size {
- batches.push(Ok(batch));
- break;
- } else {
- batches.push(Ok(batch.slice(0, batch_size)));
- let remaining_length = batch.num_rows() - batch_size;
- batch = batch.slice(batch_size, remaining_length);
+ if let Some(mut batch) = heap.emit()? {
+ metrics.baseline.output_rows().add(batch.num_rows());
+
+ loop {
+ if batch.num_rows() <= batch_size {
+ batches.push(Ok(batch));
+ break;
+ } else {
+ batches.push(Ok(batch.slice(0, batch_size)));
+ let remaining_length = batch.num_rows() - batch_size;
+ batch = batch.slice(batch_size, remaining_length);
+ }
}
- }
+ };
Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
futures::stream::iter(batches),
@@ -345,21 +346,21 @@ impl TopKHeap {
/// Returns the values stored in this heap, from values low to
/// high, as a single [`RecordBatch`], resetting the inner heap
- pub fn emit(&mut self) -> Result<RecordBatch> {
+ pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
Ok(self.emit_with_state()?.0)
}
/// Returns the values stored in this heap, from values low to
/// high, as a single [`RecordBatch`], and a sorted vec of the
/// current heap's contents
- pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec<TopKRow>)> {
+ pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
let schema = Arc::clone(self.store.schema());
// generate sorted rows
let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
if self.store.is_empty() {
- return Ok((RecordBatch::new_empty(schema), topk_rows));
+ return Ok((None, topk_rows));
}
// Indices for each row within its respective RecordBatch
@@ -393,7 +394,7 @@ impl TopKHeap {
.collect::<Result<_>>()?;
let new_batch = RecordBatch::try_new(schema, output_columns)?;
- Ok((new_batch, topk_rows))
+ Ok((Some(new_batch), topk_rows))
}
/// Compact this heap, rewriting all stored batches into a single
@@ -418,6 +419,9 @@ impl TopKHeap {
// Note: new batch is in the same order as inner
let num_rows = self.inner.len();
let (new_batch, mut topk_rows) = self.emit_with_state()?;
+ let Some(new_batch) = new_batch else {
+ return Ok(());
+ };
// clear all old entries in store (this invalidates all
// store_ids in `inner`)
diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs
index 9f03385f09..19b1b46953 100644
--- a/datafusion/physical-plan/src/unnest.rs
+++ b/datafusion/physical-plan/src/unnest.rs
@@ -18,6 +18,7 @@
//! Define a plan for unnesting values in columns that contain a list type.
use std::cmp::{self, Ordering};
+use std::task::{ready, Poll};
use std::{any::Any, sync::Arc};
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
@@ -267,7 +268,7 @@ impl Stream for UnnestStream {
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Option<Self::Item>> {
+ ) -> Poll<Option<Self::Item>> {
self.poll_next_impl(cx)
}
}
@@ -278,28 +279,31 @@ impl UnnestStream {
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Option<Result<RecordBatch>>> {
- self.input
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
+ ) -> Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return Poll::Ready(match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
let timer = self.metrics.elapsed_compute.timer();
+ self.metrics.input_batches.add(1);
+ self.metrics.input_rows.add(batch.num_rows());
let result = build_batch(
&batch,
&self.schema,
&self.list_type_columns,
&self.struct_column_indices,
&self.options,
- );
- self.metrics.input_batches.add(1);
- self.metrics.input_rows.add(batch.num_rows());
- if let Ok(ref batch) = result {
- timer.done();
- self.metrics.output_batches.add(1);
- self.metrics.output_rows.add(batch.num_rows());
- }
-
- Some(result)
+ )?;
+ timer.done();
+ let Some(result_batch) = result else {
+ continue;
+ };
+ self.metrics.output_batches.add(1);
+ self.metrics.output_rows.add(result_batch.num_rows());
+
+ // Empty record batches should not be emitted.
+ // They need to be treated as [`Option<RecordBatch>`]es and handled separately
+ debug_assert!(result_batch.num_rows() > 0);
+ Some(Ok(result_batch))
}
other => {
trace!(
@@ -313,7 +317,8 @@ impl UnnestStream {
);
other
}
- })
+ });
+ }
}
}
@@ -408,7 +413,7 @@ fn list_unnest_at_level(
temp_unnested_arrs: &mut HashMap<ListUnnest, ArrayRef>,
level_to_unnest: usize,
options: &UnnestOptions,
-) -> Result<(Vec<ArrayRef>, usize)> {
+) -> Result<Option<Vec<ArrayRef>>> {
// Extract unnestable columns at this level
let (arrs_to_unnest, list_unnest_specs): (Vec<Arc<dyn Array>>, Vec<_>) =
list_type_unnests
@@ -444,7 +449,7 @@ fn list_unnest_at_level(
})? as usize
};
if total_length == 0 {
- return Ok((vec![], 0));
+ return Ok(None);
}
// Unnest all the list arrays
@@ -483,7 +488,7 @@ fn list_unnest_at_level(
// as the side effect of unnesting
let ret = repeat_arrs_from_indices(batch, &take_indices, &repeat_mask)?;
- Ok((ret, total_length))
+ Ok(Some(ret))
}
struct UnnestingResult {
arr: ArrayRef,
@@ -552,7 +557,7 @@ fn build_batch(
list_type_columns: &[ListUnnest],
struct_column_indices: &HashSet<usize>,
options: &UnnestOptions,
-) -> Result<RecordBatch> {
+) -> Result<Option<RecordBatch>> {
let transformed = match list_type_columns.len() {
0 => flatten_struct_cols(batch.columns(), schema, struct_column_indices),
_ => {
@@ -573,16 +578,16 @@ fn build_batch(
true => batch.columns(),
false => &flatten_arrs,
};
- let (temp_result, num_rows) = list_unnest_at_level(
+ let Some(temp_result) = list_unnest_at_level(
input,
list_type_columns,
&mut temp_unnested_result,
depth,
options,
- )?;
- if num_rows == 0 {
- return Ok(RecordBatch::new_empty(Arc::clone(schema)));
- }
+ )?
+ else {
+ return Ok(None);
+ };
flatten_arrs = temp_result;
}
let unnested_array_map: HashMap<usize, Vec<UnnestingResult>> =
@@ -666,8 +671,8 @@ fn build_batch(
flatten_struct_cols(&ret, schema, struct_column_indices)
}
- };
- transformed
+ }?;
+ Ok(Some(transformed))
}
/// Find the longest list length among the given list arrays for each row.
@@ -1134,7 +1139,8 @@ mod tests {
preserve_nulls: true,
recursions: vec![],
},
- )?;
+ )?
+ .unwrap();
let expected = &[
"+---------------------------------+---------------------------------+---------------------------------+",
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index c1cfd91be0..c6003fe0a8 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -937,7 +937,7 @@ impl BoundedWindowAggStream {
})
}
- fn compute_aggregates(&mut self) -> Result<RecordBatch> {
+ fn compute_aggregates(&mut self) -> Result<Option<RecordBatch>> {
// calculate window cols
for (cur_window_expr, state) in
self.window_expr.iter().zip(&mut self.window_agg_states)
@@ -964,9 +964,9 @@ impl BoundedWindowAggStream {
.collect::<Vec<_>>();
let n_generated = columns_to_show[0].len();
self.prune_state(n_generated)?;
- Ok(RecordBatch::try_new(schema, columns_to_show)?)
+ Ok(Some(RecordBatch::try_new(schema, columns_to_show)?))
} else {
- Ok(RecordBatch::new_empty(schema))
+ Ok(None)
}
}
@@ -979,7 +979,7 @@ impl BoundedWindowAggStream {
return Poll::Ready(None);
}
- let result = match ready!(self.input.poll_next_unpin(cx)) {
+ match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
self.search_mode.update_partition_batch(
&mut self.input_buffer,
@@ -987,18 +987,23 @@ impl BoundedWindowAggStream {
&self.window_expr,
&mut self.partition_buffers,
)?;
- self.compute_aggregates()
+ if let Some(batch) = self.compute_aggregates()? {
+ return Poll::Ready(Some(Ok(batch)));
+ }
+ self.poll_next_inner(cx)
}
- Some(Err(e)) => Err(e),
+ Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => {
self.finished = true;
for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
partition_batch_state.is_end = true;
}
- self.compute_aggregates()
+ if let Some(batch) = self.compute_aggregates()? {
+ return Poll::Ready(Some(Ok(batch)));
+ }
+ Poll::Ready(None)
}
- };
- Poll::Ready(Some(result))
+ }
}
/// Prunes the sections of the record batch (for each partition)
diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs
index f71a0b9fd0..c0ac96d22e 100644
--- a/datafusion/physical-plan/src/windows/window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs
@@ -312,12 +312,13 @@ impl WindowAggStream {
})
}
- fn compute_aggregates(&self) -> Result<RecordBatch> {
+ fn compute_aggregates(&self) -> Result<Option<RecordBatch>> {
// record compute time on drop
let _timer = self.baseline_metrics.elapsed_compute().timer();
+
let batch = concat_batches(&self.input.schema(), &self.batches)?;
if batch.num_rows() == 0 {
- return Ok(RecordBatch::new_empty(Arc::clone(&self.schema)));
+ return Ok(None);
}
let partition_by_sort_keys = self
@@ -350,10 +351,10 @@ impl WindowAggStream {
let mut batch_columns = batch.columns().to_vec();
// calculate window cols
batch_columns.extend_from_slice(&columns);
- Ok(RecordBatch::try_new(
+ Ok(Some(RecordBatch::try_new(
Arc::clone(&self.schema),
batch_columns,
- )?)
+ )?))
}
}
@@ -380,18 +381,23 @@ impl WindowAggStream {
}
loop {
- let result = match ready!(self.input.poll_next_unpin(cx)) {
+ return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
self.batches.push(batch);
continue;
}
Some(Err(e)) => Err(e),
- None => self.compute_aggregates(),
- };
-
- self.finished = true;
-
- return Poll::Ready(Some(result));
+ None => {
+ let Some(result) = self.compute_aggregates()? else {
+ return Poll::Ready(None);
+ };
+ self.finished = true;
+ // Empty record batches should not be emitted.
+ // They need to be treated as [`Option<RecordBatch>`]es and handled separately
+ debug_assert!(result.num_rows() > 0);
+ Ok(result)
+ }
+ }));
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]