This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6b55adf5c1 Fix Unnest for array aggregations. (#7300)
6b55adf5c1 is described below
commit 6b55adf5c15f10ebea6d9c08053f9a5c10e5f08c
Author: vincev <[email protected]>
AuthorDate: Wed Aug 16 19:39:15 2023 +0200
Fix Unnest for array aggregations. (#7300)
---
datafusion/core/src/physical_plan/unnest.rs | 40 +++----------
datafusion/core/tests/dataframe/mod.rs | 90 +++++++++++++++++++++++++++++
2 files changed, 97 insertions(+), 33 deletions(-)
diff --git a/datafusion/core/src/physical_plan/unnest.rs b/datafusion/core/src/physical_plan/unnest.rs
index b022cf751f..ec224b286e 100644
--- a/datafusion/core/src/physical_plan/unnest.rs
+++ b/datafusion/core/src/physical_plan/unnest.rs
@@ -18,8 +18,8 @@
//! Defines the unnest column plan for unnesting values in a column that contains a list
//! type, conceptually is like joining each row with all the values in the list column.
use arrow::array::{
- new_null_array, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType,
- FixedSizeListArray, Int32Array, LargeListArray, ListArray, PrimitiveArray,
+ new_empty_array, new_null_array, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType,
+ FixedSizeListArray, LargeListArray, ListArray, PrimitiveArray,
};
use arrow::compute::kernels;
use arrow::datatypes::{
@@ -307,7 +307,7 @@ where
// 1, null, 3, null, 2
//
// Depending on the list type the result may be Int32Array or Int64Array.
- let list_lengths = list_lengths(list_array)?;
+ let list_lengths = kernels::length::length(list_array)?;
// Create the indices for the take kernel and then use those indices to create
// the unnested record batch.
@@ -459,6 +459,10 @@ where
}
};
+ if list_array.is_empty() {
+ return Ok(new_empty_array(elem_type));
+ }
+
let null_row = new_null_array(elem_type, 1);
// Create a vec of ArrayRef from the list elements.
@@ -478,33 +482,3 @@ where
Ok(kernels::concat::concat(&arrays)?)
}
-
-/// Returns an array with the lengths of each list in `list_array`. Returns null
-/// for a null value.
-fn list_lengths<T>(list_array: &T) -> Result<Arc<dyn Array + 'static>>
-where
- T: ArrayAccessor<Item = ArrayRef>,
-{
- match list_array.data_type() {
- DataType::List(_) | DataType::LargeList(_) => {
- Ok(kernels::length::length(list_array)?)
- }
- DataType::FixedSizeList(_, size) => {
- // Handle FixedSizeList as it is not handled by the `length` kernel.
- // https://github.com/apache/arrow-rs/issues/4517
- let mut lengths = Vec::with_capacity(list_array.len());
- for row in 0..list_array.len() {
- if list_array.is_null(row) {
- lengths.push(None)
- } else {
- lengths.push(Some(*size));
- }
- }
-
- Ok(Arc::new(Int32Array::from(lengths)))
- }
- dt => Err(DataFusionError::Execution(format!(
- "Invalid type {dt} for list_lengths"
- ))),
- }
-}
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index bfdb2bda1b..48ec4ef073 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1298,6 +1298,96 @@ async fn unnest_aggregate_columns() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn unnest_array_agg() -> Result<()> { // Regression test for #7300: `array_agg` then `unnest_column` must round-trip.
+ let mut shape_id_builder = UInt32Builder::new();
+ let mut tag_id_builder = UInt32Builder::new();
+ // Input: 9 rows, shape_id in 1..=3, each paired with tag_id = shape_id * 10 + (1..=3).
+ for shape_id in 1..=3 {
+ for tag_id in 1..=3 {
+ shape_id_builder.append_value(shape_id as u32);
+ tag_id_builder.append_value((shape_id * 10 + tag_id) as u32);
+ }
+ }
+ // Assemble the two UInt32 columns into one RecordBatch and register it as `shapes`.
+ let batch = RecordBatch::try_from_iter(vec![
+ ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef),
+ ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef),
+ ])?;
+
+ let ctx = SessionContext::new();
+ ctx.register_batch("shapes", batch)?;
+ let df = ctx.table("shapes").await?;
+ // Sanity check: the registered table contains exactly the 9 input rows.
+ let results = df.clone().collect().await?;
+ let expected = vec![
+ "+----------+--------+",
+ "| shape_id | tag_id |",
+ "+----------+--------+",
+ "| 1 | 11 |",
+ "| 1 | 12 |",
+ "| 1 | 13 |",
+ "| 2 | 21 |",
+ "| 2 | 22 |",
+ "| 2 | 23 |",
+ "| 3 | 31 |",
+ "| 3 | 32 |",
+ "| 3 | 33 |",
+ "+----------+--------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+
+ // Grouping by `shape_id` and collecting `tag_id` with `array_agg` yields one list per shape:
+ let results = df
+ .clone()
+ .aggregate(
+ vec![col("shape_id")],
+ vec![array_agg(col("tag_id")).alias("tag_id")],
+ )?
+ .collect()
+ .await?;
+ let expected = vec![
+ "+----------+--------------+",
+ "| shape_id | tag_id |",
+ "+----------+--------------+",
+ "| 1 | [11, 12, 13] |",
+ "| 2 | [21, 22, 23] |",
+ "| 3 | [31, 32, 33] |",
+ "+----------+--------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+
+ // Unnesting the aggregated `tag_id` list column should reproduce the original 9 rows.
+ let results = ctx
+ .table("shapes")
+ .await?
+ .aggregate(
+ vec![col("shape_id")],
+ vec![array_agg(col("tag_id")).alias("tag_id")],
+ )?
+ .unnest_column("tag_id")? // the operation exercised by the #7300 fix
+ .collect()
+ .await?;
+ let expected = vec![
+ "+----------+--------+",
+ "| shape_id | tag_id |",
+ "+----------+--------+",
+ "| 1 | 11 |",
+ "| 1 | 12 |",
+ "| 1 | 13 |",
+ "| 2 | 21 |",
+ "| 2 | 22 |",
+ "| 2 | 23 |",
+ "| 3 | 31 |",
+ "| 3 | 32 |",
+ "| 3 | 33 |",
+ "+----------+--------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
async fn create_test_table(name: &str) -> Result<DataFrame> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),