This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b55adf5c1 Fix Unnest for array aggregations. (#7300)
6b55adf5c1 is described below

commit 6b55adf5c15f10ebea6d9c08053f9a5c10e5f08c
Author: vincev <[email protected]>
AuthorDate: Wed Aug 16 19:39:15 2023 +0200

    Fix Unnest for array aggregations. (#7300)
---
 datafusion/core/src/physical_plan/unnest.rs | 40 +++----------
 datafusion/core/tests/dataframe/mod.rs      | 90 +++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 33 deletions(-)

diff --git a/datafusion/core/src/physical_plan/unnest.rs b/datafusion/core/src/physical_plan/unnest.rs
index b022cf751f..ec224b286e 100644
--- a/datafusion/core/src/physical_plan/unnest.rs
+++ b/datafusion/core/src/physical_plan/unnest.rs
@@ -18,8 +18,8 @@
 //! Defines the unnest column plan for unnesting values in a column that contains a list
 //! type, conceptually is like joining each row with all the values in the list column.
 use arrow::array::{
-    new_null_array, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType,
-    FixedSizeListArray, Int32Array, LargeListArray, ListArray, PrimitiveArray,
+    new_empty_array, new_null_array, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType,
+    FixedSizeListArray, LargeListArray, ListArray, PrimitiveArray,
 };
 use arrow::compute::kernels;
 use arrow::datatypes::{
@@ -307,7 +307,7 @@ where
     //   1, null, 3, null, 2
     //
     // Depending on the list type the result may be Int32Array or Int64Array.
-    let list_lengths = list_lengths(list_array)?;
+    let list_lengths = kernels::length::length(list_array)?;
 
     // Create the indices for the take kernel and then use those indices to create
     // the unnested record batch.
@@ -459,6 +459,10 @@ where
         }
     };
 
+    if list_array.is_empty() {
+        return Ok(new_empty_array(elem_type));
+    }
+
     let null_row = new_null_array(elem_type, 1);
 
     // Create a vec of ArrayRef from the list elements.
@@ -478,33 +482,3 @@ where
 
     Ok(kernels::concat::concat(&arrays)?)
 }
-
-/// Returns an array with the lengths of each list in `list_array`. Returns null
-/// for a null value.
-fn list_lengths<T>(list_array: &T) -> Result<Arc<dyn Array + 'static>>
-where
-    T: ArrayAccessor<Item = ArrayRef>,
-{
-    match list_array.data_type() {
-        DataType::List(_) | DataType::LargeList(_) => {
-            Ok(kernels::length::length(list_array)?)
-        }
-        DataType::FixedSizeList(_, size) => {
-            // Handle FixedSizeList as it is not handled by the `length` kernel.
-            // https://github.com/apache/arrow-rs/issues/4517
-            let mut lengths = Vec::with_capacity(list_array.len());
-            for row in 0..list_array.len() {
-                if list_array.is_null(row) {
-                    lengths.push(None)
-                } else {
-                    lengths.push(Some(*size));
-                }
-            }
-
-            Ok(Arc::new(Int32Array::from(lengths)))
-        }
-        dt => Err(DataFusionError::Execution(format!(
-            "Invalid type {dt} for list_lengths"
-        ))),
-    }
-}
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index bfdb2bda1b..48ec4ef073 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1298,6 +1298,96 @@ async fn unnest_aggregate_columns() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn unnest_array_agg() -> Result<()> {
+    let mut shape_id_builder = UInt32Builder::new();
+    let mut tag_id_builder = UInt32Builder::new();
+
+    for shape_id in 1..=3 {
+        for tag_id in 1..=3 {
+            shape_id_builder.append_value(shape_id as u32);
+            tag_id_builder.append_value((shape_id * 10 + tag_id) as u32);
+        }
+    }
+
+    let batch = RecordBatch::try_from_iter(vec![
+        ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef),
+        ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef),
+    ])?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("shapes", batch)?;
+    let df = ctx.table("shapes").await?;
+
+    let results = df.clone().collect().await?;
+    let expected = vec![
+        "+----------+--------+",
+        "| shape_id | tag_id |",
+        "+----------+--------+",
+        "| 1        | 11     |",
+        "| 1        | 12     |",
+        "| 1        | 13     |",
+        "| 2        | 21     |",
+        "| 2        | 22     |",
+        "| 2        | 23     |",
+        "| 3        | 31     |",
+        "| 3        | 32     |",
+        "| 3        | 33     |",
+        "+----------+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+
+    // Doing an `array_agg` by `shape_id` produces:
+    let results = df
+        .clone()
+        .aggregate(
+            vec![col("shape_id")],
+            vec![array_agg(col("tag_id")).alias("tag_id")],
+        )?
+        .collect()
+        .await?;
+    let expected = vec![
+        "+----------+--------------+",
+        "| shape_id | tag_id       |",
+        "+----------+--------------+",
+        "| 1        | [11, 12, 13] |",
+        "| 2        | [21, 22, 23] |",
+        "| 3        | [31, 32, 33] |",
+        "+----------+--------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+
+    // Unnesting again should produce the original batch.
+    let results = ctx
+        .table("shapes")
+        .await?
+        .aggregate(
+            vec![col("shape_id")],
+            vec![array_agg(col("tag_id")).alias("tag_id")],
+        )?
+        .unnest_column("tag_id")?
+        .collect()
+        .await?;
+    let expected = vec![
+        "+----------+--------+",
+        "| shape_id | tag_id |",
+        "+----------+--------+",
+        "| 1        | 11     |",
+        "| 1        | 12     |",
+        "| 1        | 13     |",
+        "| 2        | 21     |",
+        "| 2        | 22     |",
+        "| 2        | 23     |",
+        "| 3        | 31     |",
+        "| 3        | 32     |",
+        "| 3        | 33     |",
+        "+----------+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
 async fn create_test_table(name: &str) -> Result<DataFrame> {
     let schema = Arc::new(Schema::new(vec![
         Field::new("a", DataType::Utf8, false),

Reply via email to