This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 63ce486589 Chore: Do not return empty record batches from streams (#13794)
63ce486589 is described below

commit 63ce4865896b906ca34fcbf85fdc55bff3080c30
Author: mertak-synnada <[email protected]>
AuthorDate: Wed Dec 18 08:30:03 2024 +0000

    Chore: Do not return empty record batches from streams (#13794)
    
    * do not emit empty record batches in plans
    
    * change function signatures to Option<RecordBatch> if empty batches are possible
    
    * format code
    
    * shorten code
    
    * change list_unnest_at_level for returning Option value
    
    * add documentation
    take concat_batches into compute_aggregates function again
    
    * create unit test for row_hash.rs
    
    * add test for unnest
    
    * add test for unnest
    
    * add test for partial sort
    
    * add test for bounded window agg
    
    * add test for window agg
    
    * apply simplifications and fix typo
    
    * apply simplifications and fix typo
---
 datafusion/core/src/dataframe/mod.rs               | 24 ++++++
 datafusion/core/tests/dataframe/mod.rs             | 43 ++++++++++
 datafusion/core/tests/fuzz_cases/window_fuzz.rs    |  8 +-
 .../physical-plan/src/aggregates/row_hash.rs       | 46 ++++++----
 datafusion/physical-plan/src/sorts/partial_sort.rs | 97 +++++++++++++++-------
 datafusion/physical-plan/src/topk/mod.rs           | 36 ++++----
 datafusion/physical-plan/src/unnest.rs             | 62 +++++++-------
 .../src/windows/bounded_window_agg_exec.rs         | 23 +++--
 .../physical-plan/src/windows/window_agg_exec.rs   | 28 ++++---
 9 files changed, 256 insertions(+), 111 deletions(-)

diff --git a/datafusion/core/src/dataframe/mod.rs 
b/datafusion/core/src/dataframe/mod.rs
index 82ee52d7b2..414d6da7bc 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2380,6 +2380,30 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn aggregate_assert_no_empty_batches() -> Result<()> {
+        // build plan using DataFrame API
+        let df = test_table().await?;
+        let group_expr = vec![col("c1")];
+        let aggr_expr = vec![
+            min(col("c12")),
+            max(col("c12")),
+            avg(col("c12")),
+            sum(col("c12")),
+            count(col("c12")),
+            count_distinct(col("c12")),
+            median(col("c12")),
+        ];
+
+        let df: Vec<RecordBatch> = df.aggregate(group_expr, 
aggr_expr)?.collect().await?;
+        // Empty batches should not be produced
+        for batch in df {
+            assert!(batch.num_rows() > 0);
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_aggregate_with_pk() -> Result<()> {
         // create the dataframe
diff --git a/datafusion/core/tests/dataframe/mod.rs 
b/datafusion/core/tests/dataframe/mod.rs
index 7c3c96025b..f4f754b11c 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1246,6 +1246,43 @@ async fn unnest_aggregate_columns() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn unnest_no_empty_batches() -> Result<()> {
+    let mut shape_id_builder = UInt32Builder::new();
+    let mut tag_id_builder = UInt32Builder::new();
+
+    for shape_id in 1..=10 {
+        for tag_id in 1..=10 {
+            shape_id_builder.append_value(shape_id as u32);
+            tag_id_builder.append_value((shape_id * 10 + tag_id) as u32);
+        }
+    }
+
+    let batch = RecordBatch::try_from_iter(vec![
+        ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef),
+        ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef),
+    ])?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("shapes", batch)?;
+    let df = ctx.table("shapes").await?;
+
+    let results = df
+        .clone()
+        .aggregate(
+            vec![col("shape_id")],
+            vec![array_agg(col("tag_id")).alias("tag_id")],
+        )?
+        .collect()
+        .await?;
+
+    // Assert that there are no empty batches in result
+    for rb in results {
+        assert!(rb.num_rows() > 0);
+    }
+    Ok(())
+}
+
 #[tokio::test]
 async fn unnest_array_agg() -> Result<()> {
     let mut shape_id_builder = UInt32Builder::new();
@@ -1268,6 +1305,12 @@ async fn unnest_array_agg() -> Result<()> {
     let df = ctx.table("shapes").await?;
 
     let results = df.clone().collect().await?;
+
+    // Assert that there are no empty batches in result
+    for rb in results.clone() {
+        assert!(rb.num_rows() > 0);
+    }
+
     let expected = vec![
         "+----------+--------+",
         "| shape_id | tag_id |",
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index eaa84988a8..67666f5d7a 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -299,9 +299,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
                     Linear,
                 )?);
                 let task_ctx = ctx.task_ctx();
-                let mut collected_results =
-                    collect(running_window_exec, task_ctx).await?;
-                collected_results.retain(|batch| batch.num_rows() > 0);
+                let collected_results = collect(running_window_exec, 
task_ctx).await?;
                 let input_batch_sizes = batches
                     .iter()
                     .map(|batch| batch.num_rows())
@@ -310,6 +308,8 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
                     .iter()
                     .map(|batch| batch.num_rows())
                     .collect::<Vec<_>>();
+                // There should be no empty batches at results
+                assert!(result_batch_sizes.iter().all(|e| *e > 0));
                 if causal {
                     // For causal window frames, we can generate results 
immediately
                     // for each input batch. Hence, batch sizes should match.
@@ -688,8 +688,8 @@ async fn run_window_test(
     let collected_running = collect(running_window_exec, task_ctx)
         .await?
         .into_iter()
-        .filter(|b| b.num_rows() > 0)
         .collect::<Vec<_>>();
+    assert!(collected_running.iter().all(|rb| rb.num_rows() > 0));
 
     // BoundedWindowAggExec should produce more chunk than the usual 
WindowAggExec.
     // Otherwise it means that we cannot generate result in running mode.
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs 
b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 965adbb8c7..c261310f56 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -654,9 +654,13 @@ impl Stream for GroupedHashAggregateStream {
                             }
 
                             if let Some(to_emit) = 
self.group_ordering.emit_to() {
-                                let batch = extract_ok!(self.emit(to_emit, 
false));
-                                self.exec_state = 
ExecutionState::ProducingOutput(batch);
                                 timer.done();
+                                if let Some(batch) =
+                                    extract_ok!(self.emit(to_emit, false))
+                                {
+                                    self.exec_state =
+                                        ExecutionState::ProducingOutput(batch);
+                                };
                                 // make sure the exec_state just set is not 
overwritten below
                                 break 'reading_input;
                             }
@@ -693,9 +697,13 @@ impl Stream for GroupedHashAggregateStream {
                             }
 
                             if let Some(to_emit) = 
self.group_ordering.emit_to() {
-                                let batch = extract_ok!(self.emit(to_emit, 
false));
-                                self.exec_state = 
ExecutionState::ProducingOutput(batch);
                                 timer.done();
+                                if let Some(batch) =
+                                    extract_ok!(self.emit(to_emit, false))
+                                {
+                                    self.exec_state =
+                                        ExecutionState::ProducingOutput(batch);
+                                };
                                 // make sure the exec_state just set is not 
overwritten below
                                 break 'reading_input;
                             }
@@ -768,6 +776,9 @@ impl Stream for GroupedHashAggregateStream {
                         let output = batch.slice(0, size);
                         (ExecutionState::ProducingOutput(remaining), output)
                     };
+                    // Empty record batches should not be emitted.
+                    // They need to be treated as  [`Option<RecordBatch>`]es 
and handled separately
+                    debug_assert!(output_batch.num_rows() > 0);
                     return Poll::Ready(Some(Ok(
                         output_batch.record_output(&self.baseline_metrics)
                     )));
@@ -902,14 +913,14 @@ impl GroupedHashAggregateStream {
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> 
{
+    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> 
Result<Option<RecordBatch>> {
         let schema = if spilling {
             Arc::clone(&self.spill_state.spill_schema)
         } else {
             self.schema()
         };
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(schema));
+            return Ok(None);
         }
 
         let mut output = self.group_values.emit(emit_to)?;
@@ -937,7 +948,8 @@ impl GroupedHashAggregateStream {
         // over the target memory size after emission, we can emit again 
rather than returning Err.
         let _ = self.update_memory_reservation();
         let batch = RecordBatch::try_new(schema, output)?;
-        Ok(batch)
+        debug_assert!(batch.num_rows() > 0);
+        Ok(Some(batch))
     }
 
     /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the 
memory target slightly
@@ -963,7 +975,9 @@ impl GroupedHashAggregateStream {
 
     /// Emit all rows, sort them, and store them on disk.
     fn spill(&mut self) -> Result<()> {
-        let emit = self.emit(EmitTo::All, true)?;
+        let Some(emit) = self.emit(EmitTo::All, true)? else {
+            return Ok(());
+        };
         let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), 
None)?;
         let spillfile = 
self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
         // TODO: slice large `sorted` and write to multiple files in parallel
@@ -1008,8 +1022,9 @@ impl GroupedHashAggregateStream {
         {
             assert_eq!(self.mode, AggregateMode::Partial);
             let n = self.group_values.len() / self.batch_size * 
self.batch_size;
-            let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            if let Some(batch) = self.emit(EmitTo::First(n), false)? {
+                self.exec_state = ExecutionState::ProducingOutput(batch);
+            };
         }
         Ok(())
     }
@@ -1019,7 +1034,9 @@ impl GroupedHashAggregateStream {
     /// Conduct a streaming merge sort between the batch and spilled data. 
Since the stream is fully
     /// sorted, set `self.group_ordering` to Full, then later we can read with 
[`EmitTo::First`].
     fn update_merged_stream(&mut self) -> Result<()> {
-        let batch = self.emit(EmitTo::All, true)?;
+        let Some(batch) = self.emit(EmitTo::All, true)? else {
+            return Ok(());
+        };
         // clear up memory for streaming_merge
         self.clear_all();
         self.update_memory_reservation()?;
@@ -1067,7 +1084,7 @@ impl GroupedHashAggregateStream {
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
             let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -1096,8 +1113,9 @@ impl GroupedHashAggregateStream {
     fn switch_to_skip_aggregation(&mut self) -> Result<()> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
-                let batch = self.emit(EmitTo::All, false)?;
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                if let Some(batch) = self.emit(EmitTo::All, false)? {
+                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                };
             }
         }
 
diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs 
b/datafusion/physical-plan/src/sorts/partial_sort.rs
index dde19f46cd..77636f9e49 100644
--- a/datafusion/physical-plan/src/sorts/partial_sort.rs
+++ b/datafusion/physical-plan/src/sorts/partial_sort.rs
@@ -363,31 +363,31 @@ impl PartialSortStream {
         if self.is_closed {
             return Poll::Ready(None);
         }
-        let result = match ready!(self.input.poll_next_unpin(cx)) {
-            Some(Ok(batch)) => {
-                if let Some(slice_point) =
-                    self.get_slice_point(self.common_prefix_length, &batch)?
-                {
-                    self.in_mem_batches.push(batch.slice(0, slice_point));
-                    let remaining_batch =
-                        batch.slice(slice_point, batch.num_rows() - 
slice_point);
-                    let sorted_batch = self.sort_in_mem_batches();
-                    self.in_mem_batches.push(remaining_batch);
-                    sorted_batch
-                } else {
-                    self.in_mem_batches.push(batch);
-                    Ok(RecordBatch::new_empty(self.schema()))
+        loop {
+            return Poll::Ready(Some(match 
ready!(self.input.poll_next_unpin(cx)) {
+                Some(Ok(batch)) => {
+                    if let Some(slice_point) =
+                        self.get_slice_point(self.common_prefix_length, 
&batch)?
+                    {
+                        self.in_mem_batches.push(batch.slice(0, slice_point));
+                        let remaining_batch =
+                            batch.slice(slice_point, batch.num_rows() - 
slice_point);
+                        let sorted_batch = self.sort_in_mem_batches();
+                        self.in_mem_batches.push(remaining_batch);
+                        sorted_batch
+                    } else {
+                        self.in_mem_batches.push(batch);
+                        continue;
+                    }
                 }
-            }
-            Some(Err(e)) => Err(e),
-            None => {
-                self.is_closed = true;
-                // once input is consumed, sort the rest of the inserted 
batches
-                self.sort_in_mem_batches()
-            }
-        };
-
-        Poll::Ready(Some(result))
+                Some(Err(e)) => Err(e),
+                None => {
+                    self.is_closed = true;
+                    // once input is consumed, sort the rest of the inserted 
batches
+                    self.sort_in_mem_batches()
+                }
+            }));
+        }
     }
 
     /// Returns a sorted RecordBatch from in_mem_batches and clears 
in_mem_batches
@@ -407,6 +407,9 @@ impl PartialSortStream {
                 self.is_closed = true;
             }
         }
+        // Empty record batches should not be emitted.
+        // They need to be treated as [`Option<RecordBatch>`]es and handle 
separately
+        debug_assert!(result.num_rows() > 0);
         Ok(result)
     }
 
@@ -731,7 +734,7 @@ mod tests {
         let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
         assert_eq!(
             result.iter().map(|r| r.num_rows()).collect_vec(),
-            [0, 125, 125, 0, 150]
+            [125, 125, 150]
         );
 
         assert_eq!(
@@ -760,10 +763,10 @@ mod tests {
             nulls_first: false,
         };
         for (fetch_size, expected_batch_num_rows) in [
-            (Some(50), vec![0, 50]),
-            (Some(120), vec![0, 120]),
-            (Some(150), vec![0, 125, 25]),
-            (Some(250), vec![0, 125, 125]),
+            (Some(50), vec![50]),
+            (Some(120), vec![120]),
+            (Some(150), vec![125, 25]),
+            (Some(250), vec![125, 125]),
         ] {
             let partial_sort_executor = PartialSortExec::new(
                 LexOrdering::new(vec![
@@ -810,6 +813,42 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_partial_sort_no_empty_batches() -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let mem_exec = prepare_partitioned_input();
+        let schema = mem_exec.schema();
+        let option_asc = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+        let fetch_size = Some(250);
+        let partial_sort_executor = PartialSortExec::new(
+            LexOrdering::new(vec![
+                PhysicalSortExpr {
+                    expr: col("a", &schema)?,
+                    options: option_asc,
+                },
+                PhysicalSortExpr {
+                    expr: col("c", &schema)?,
+                    options: option_asc,
+                },
+            ]),
+            Arc::clone(&mem_exec),
+            1,
+        )
+        .with_fetch(fetch_size);
+
+        let partial_sort_exec =
+            Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
+        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
+        for rb in result {
+            assert!(rb.num_rows() > 0);
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_sort_metadata() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
diff --git a/datafusion/physical-plan/src/topk/mod.rs 
b/datafusion/physical-plan/src/topk/mod.rs
index 0f722ec143..c8842965ac 100644
--- a/datafusion/physical-plan/src/topk/mod.rs
+++ b/datafusion/physical-plan/src/topk/mod.rs
@@ -200,21 +200,22 @@ impl TopK {
         } = self;
         let _timer = metrics.baseline.elapsed_compute().timer(); // time 
updated on drop
 
-        let mut batch = heap.emit()?;
-        metrics.baseline.output_rows().add(batch.num_rows());
-
         // break into record batches as needed
         let mut batches = vec![];
-        loop {
-            if batch.num_rows() <= batch_size {
-                batches.push(Ok(batch));
-                break;
-            } else {
-                batches.push(Ok(batch.slice(0, batch_size)));
-                let remaining_length = batch.num_rows() - batch_size;
-                batch = batch.slice(batch_size, remaining_length);
+        if let Some(mut batch) = heap.emit()? {
+            metrics.baseline.output_rows().add(batch.num_rows());
+
+            loop {
+                if batch.num_rows() <= batch_size {
+                    batches.push(Ok(batch));
+                    break;
+                } else {
+                    batches.push(Ok(batch.slice(0, batch_size)));
+                    let remaining_length = batch.num_rows() - batch_size;
+                    batch = batch.slice(batch_size, remaining_length);
+                }
             }
-        }
+        };
         Ok(Box::pin(RecordBatchStreamAdapter::new(
             schema,
             futures::stream::iter(batches),
@@ -345,21 +346,21 @@ impl TopKHeap {
 
     /// Returns the values stored in this heap, from values low to
     /// high, as a single [`RecordBatch`], resetting the inner heap
-    pub fn emit(&mut self) -> Result<RecordBatch> {
+    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
         Ok(self.emit_with_state()?.0)
     }
 
     /// Returns the values stored in this heap, from values low to
     /// high, as a single [`RecordBatch`], and a sorted vec of the
     /// current heap's contents
-    pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec<TopKRow>)> {
+    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, 
Vec<TopKRow>)> {
         let schema = Arc::clone(self.store.schema());
 
         // generate sorted rows
         let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
 
         if self.store.is_empty() {
-            return Ok((RecordBatch::new_empty(schema), topk_rows));
+            return Ok((None, topk_rows));
         }
 
         // Indices for each row within its respective RecordBatch
@@ -393,7 +394,7 @@ impl TopKHeap {
             .collect::<Result<_>>()?;
 
         let new_batch = RecordBatch::try_new(schema, output_columns)?;
-        Ok((new_batch, topk_rows))
+        Ok((Some(new_batch), topk_rows))
     }
 
     /// Compact this heap, rewriting all stored batches into a single
@@ -418,6 +419,9 @@ impl TopKHeap {
         // Note: new batch is in the same order as inner
         let num_rows = self.inner.len();
         let (new_batch, mut topk_rows) = self.emit_with_state()?;
+        let Some(new_batch) = new_batch else {
+            return Ok(());
+        };
 
         // clear all old entries in store (this invalidates all
         // store_ids in `inner`)
diff --git a/datafusion/physical-plan/src/unnest.rs 
b/datafusion/physical-plan/src/unnest.rs
index 9f03385f09..19b1b46953 100644
--- a/datafusion/physical-plan/src/unnest.rs
+++ b/datafusion/physical-plan/src/unnest.rs
@@ -18,6 +18,7 @@
 //! Define a plan for unnesting values in columns that contain a list type.
 
 use std::cmp::{self, Ordering};
+use std::task::{ready, Poll};
 use std::{any::Any, sync::Arc};
 
 use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
@@ -267,7 +268,7 @@ impl Stream for UnnestStream {
     fn poll_next(
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
+    ) -> Poll<Option<Self::Item>> {
         self.poll_next_impl(cx)
     }
 }
@@ -278,28 +279,31 @@ impl UnnestStream {
     fn poll_next_impl(
         &mut self,
         cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Result<RecordBatch>>> {
-        self.input
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return Poll::Ready(match ready!(self.input.poll_next_unpin(cx)) {
                 Some(Ok(batch)) => {
                     let timer = self.metrics.elapsed_compute.timer();
+                    self.metrics.input_batches.add(1);
+                    self.metrics.input_rows.add(batch.num_rows());
                     let result = build_batch(
                         &batch,
                         &self.schema,
                         &self.list_type_columns,
                         &self.struct_column_indices,
                         &self.options,
-                    );
-                    self.metrics.input_batches.add(1);
-                    self.metrics.input_rows.add(batch.num_rows());
-                    if let Ok(ref batch) = result {
-                        timer.done();
-                        self.metrics.output_batches.add(1);
-                        self.metrics.output_rows.add(batch.num_rows());
-                    }
-
-                    Some(result)
+                    )?;
+                    timer.done();
+                    let Some(result_batch) = result else {
+                        continue;
+                    };
+                    self.metrics.output_batches.add(1);
+                    self.metrics.output_rows.add(result_batch.num_rows());
+
+                    // Empty record batches should not be emitted.
+                    // They need to be treated as  [`Option<RecordBatch>`]es 
and handled separately
+                    debug_assert!(result_batch.num_rows() > 0);
+                    Some(Ok(result_batch))
                 }
                 other => {
                     trace!(
@@ -313,7 +317,8 @@ impl UnnestStream {
                     );
                     other
                 }
-            })
+            });
+        }
     }
 }
 
@@ -408,7 +413,7 @@ fn list_unnest_at_level(
     temp_unnested_arrs: &mut HashMap<ListUnnest, ArrayRef>,
     level_to_unnest: usize,
     options: &UnnestOptions,
-) -> Result<(Vec<ArrayRef>, usize)> {
+) -> Result<Option<Vec<ArrayRef>>> {
     // Extract unnestable columns at this level
     let (arrs_to_unnest, list_unnest_specs): (Vec<Arc<dyn Array>>, Vec<_>) =
         list_type_unnests
@@ -444,7 +449,7 @@ fn list_unnest_at_level(
         })? as usize
     };
     if total_length == 0 {
-        return Ok((vec![], 0));
+        return Ok(None);
     }
 
     // Unnest all the list arrays
@@ -483,7 +488,7 @@ fn list_unnest_at_level(
     // as the side effect of unnesting
     let ret = repeat_arrs_from_indices(batch, &take_indices, &repeat_mask)?;
 
-    Ok((ret, total_length))
+    Ok(Some(ret))
 }
 struct UnnestingResult {
     arr: ArrayRef,
@@ -552,7 +557,7 @@ fn build_batch(
     list_type_columns: &[ListUnnest],
     struct_column_indices: &HashSet<usize>,
     options: &UnnestOptions,
-) -> Result<RecordBatch> {
+) -> Result<Option<RecordBatch>> {
     let transformed = match list_type_columns.len() {
         0 => flatten_struct_cols(batch.columns(), schema, 
struct_column_indices),
         _ => {
@@ -573,16 +578,16 @@ fn build_batch(
                     true => batch.columns(),
                     false => &flatten_arrs,
                 };
-                let (temp_result, num_rows) = list_unnest_at_level(
+                let Some(temp_result) = list_unnest_at_level(
                     input,
                     list_type_columns,
                     &mut temp_unnested_result,
                     depth,
                     options,
-                )?;
-                if num_rows == 0 {
-                    return Ok(RecordBatch::new_empty(Arc::clone(schema)));
-                }
+                )?
+                else {
+                    return Ok(None);
+                };
                 flatten_arrs = temp_result;
             }
             let unnested_array_map: HashMap<usize, Vec<UnnestingResult>> =
@@ -666,8 +671,8 @@ fn build_batch(
 
             flatten_struct_cols(&ret, schema, struct_column_indices)
         }
-    };
-    transformed
+    }?;
+    Ok(Some(transformed))
 }
 
 /// Find the longest list length among the given list arrays for each row.
@@ -1134,7 +1139,8 @@ mod tests {
                 preserve_nulls: true,
                 recursions: vec![],
             },
-        )?;
+        )?
+        .unwrap();
 
         let expected = &[
 
"+---------------------------------+---------------------------------+---------------------------------+",
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs 
b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index c1cfd91be0..c6003fe0a8 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -937,7 +937,7 @@ impl BoundedWindowAggStream {
         })
     }
 
-    fn compute_aggregates(&mut self) -> Result<RecordBatch> {
+    fn compute_aggregates(&mut self) -> Result<Option<RecordBatch>> {
         // calculate window cols
         for (cur_window_expr, state) in
             self.window_expr.iter().zip(&mut self.window_agg_states)
@@ -964,9 +964,9 @@ impl BoundedWindowAggStream {
                 .collect::<Vec<_>>();
             let n_generated = columns_to_show[0].len();
             self.prune_state(n_generated)?;
-            Ok(RecordBatch::try_new(schema, columns_to_show)?)
+            Ok(Some(RecordBatch::try_new(schema, columns_to_show)?))
         } else {
-            Ok(RecordBatch::new_empty(schema))
+            Ok(None)
         }
     }
 
@@ -979,7 +979,7 @@ impl BoundedWindowAggStream {
             return Poll::Ready(None);
         }
 
-        let result = match ready!(self.input.poll_next_unpin(cx)) {
+        match ready!(self.input.poll_next_unpin(cx)) {
             Some(Ok(batch)) => {
                 self.search_mode.update_partition_batch(
                     &mut self.input_buffer,
@@ -987,18 +987,23 @@ impl BoundedWindowAggStream {
                     &self.window_expr,
                     &mut self.partition_buffers,
                 )?;
-                self.compute_aggregates()
+                if let Some(batch) = self.compute_aggregates()? {
+                    return Poll::Ready(Some(Ok(batch)));
+                }
+                self.poll_next_inner(cx)
             }
-            Some(Err(e)) => Err(e),
+            Some(Err(e)) => Poll::Ready(Some(Err(e))),
             None => {
                 self.finished = true;
                 for (_, partition_batch_state) in 
self.partition_buffers.iter_mut() {
                     partition_batch_state.is_end = true;
                 }
-                self.compute_aggregates()
+                if let Some(batch) = self.compute_aggregates()? {
+                    return Poll::Ready(Some(Ok(batch)));
+                }
+                Poll::Ready(None)
             }
-        };
-        Poll::Ready(Some(result))
+        }
     }
 
     /// Prunes the sections of the record batch (for each partition)
diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs 
b/datafusion/physical-plan/src/windows/window_agg_exec.rs
index f71a0b9fd0..c0ac96d22e 100644
--- a/datafusion/physical-plan/src/windows/window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs
@@ -312,12 +312,13 @@ impl WindowAggStream {
         })
     }
 
-    fn compute_aggregates(&self) -> Result<RecordBatch> {
+    fn compute_aggregates(&self) -> Result<Option<RecordBatch>> {
         // record compute time on drop
         let _timer = self.baseline_metrics.elapsed_compute().timer();
+
         let batch = concat_batches(&self.input.schema(), &self.batches)?;
         if batch.num_rows() == 0 {
-            return Ok(RecordBatch::new_empty(Arc::clone(&self.schema)));
+            return Ok(None);
         }
 
         let partition_by_sort_keys = self
@@ -350,10 +351,10 @@ impl WindowAggStream {
         let mut batch_columns = batch.columns().to_vec();
         // calculate window cols
         batch_columns.extend_from_slice(&columns);
-        Ok(RecordBatch::try_new(
+        Ok(Some(RecordBatch::try_new(
             Arc::clone(&self.schema),
             batch_columns,
-        )?)
+        )?))
     }
 }
 
@@ -380,18 +381,23 @@ impl WindowAggStream {
         }
 
         loop {
-            let result = match ready!(self.input.poll_next_unpin(cx)) {
+            return Poll::Ready(Some(match 
ready!(self.input.poll_next_unpin(cx)) {
                 Some(Ok(batch)) => {
                     self.batches.push(batch);
                     continue;
                 }
                 Some(Err(e)) => Err(e),
-                None => self.compute_aggregates(),
-            };
-
-            self.finished = true;
-
-            return Poll::Ready(Some(result));
+                None => {
+                    let Some(result) = self.compute_aggregates()? else {
+                        return Poll::Ready(None);
+                    };
+                    self.finished = true;
+                    // Empty record batches should not be emitted.
+                    // They need to be treated as  [`Option<RecordBatch>`]es 
and handled separately
+                    debug_assert!(result.num_rows() > 0);
+                    Ok(result)
+                }
+            }));
         }
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to