This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch branch-42
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/branch-42 by this push:
     new 7fbc134c8b Patch for PR 12586 (#12976)
7fbc134c8b is described below

commit 7fbc134c8b7656bdd27bd9bf7df4e34588804fbb
Author: Matthew Turner <[email protected]>
AuthorDate: Wed Oct 16 17:00:52 2024 -0400

    Patch for PR 12586 (#12976)
---
 .../src/aggregates/group_values/row.rs             |  60 ++++++++--
 datafusion/physical-plan/src/aggregates/mod.rs     | 128 ++++++++++++++++++++-
 2 files changed, 175 insertions(+), 13 deletions(-)

diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index dc948e28bb..93a3e04a90 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -20,13 +20,14 @@ use ahash::RandomState;
 use arrow::compute::cast;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::{Array, ArrayRef};
+use arrow_array::{Array, ArrayRef, ListArray, StructArray};
 use arrow_schema::{DataType, SchemaRef};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_expr::EmitTo;
 use hashbrown::raw::RawTable;
+use std::sync::Arc;
 
 /// A [`GroupValues`] making use of [`Rows`]
 pub struct GroupValuesRows {
@@ -221,15 +222,10 @@ impl GroupValues for GroupValuesRows {
         // TODO: Materialize dictionaries in group keys (#7647)
         for (field, array) in self.schema.fields.iter().zip(&mut output) {
             let expected = field.data_type();
-            if let DataType::Dictionary(_, v) = expected {
-                let actual = array.data_type();
-                if v.as_ref() != actual {
-                    return Err(DataFusionError::Internal(format!(
-                        "Converted group rows expected dictionary of {v} got 
{actual}"
-                    )));
-                }
-                *array = cast(array.as_ref(), expected)?;
-            }
+            *array = dictionary_encode_if_necessary(
+                Arc::<dyn arrow_array::Array>::clone(array),
+                expected,
+            )?;
         }
 
         self.group_values = Some(group_values);
@@ -249,3 +245,94 @@ impl GroupValues for GroupValuesRows {
         self.hashes_buffer.shrink_to(count);
     }
 }
+
+/// Re-applies dictionary encoding to an array emitted by the row converter so
+/// that its data type matches `expected`.
+///
+/// The [`RowConverter`] returns arrays of the dictionary *value* type for
+/// dictionary-encoded group columns (dictionaries are not materialized in
+/// group keys, see #7647). This walks `expected` and casts back to the
+/// dictionary type wherever one appears, including inside `Struct`, `List`
+/// and `LargeList` children.
+///
+/// Arrays whose type already equals `expected` are returned unchanged.
+///
+/// # Errors
+///
+/// Returns an internal error if `expected` is a `Struct` but `array` is not,
+/// and propagates any error raised by `cast` or by rebuilding nested arrays.
+fn dictionary_encode_if_necessary(
+    array: ArrayRef,
+    expected: &DataType,
+) -> Result<ArrayRef> {
+    // Fast path: nothing in this subtree needs re-encoding.
+    if array.data_type() == expected {
+        return Ok(array);
+    }
+
+    match (expected, array.data_type()) {
+        (DataType::Struct(expected_fields), DataType::Struct(_)) => {
+            let struct_array = array
+                .as_any()
+                .downcast_ref::<StructArray>()
+                .expect("DataType::Struct array must be a StructArray");
+            let arrays = expected_fields
+                .iter()
+                .zip(struct_array.columns())
+                .map(|(expected_field, column)| {
+                    dictionary_encode_if_necessary(
+                        Arc::clone(column),
+                        expected_field.data_type(),
+                    )
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            Ok(Arc::new(StructArray::try_new(
+                expected_fields.clone(),
+                arrays,
+                struct_array.nulls().cloned(),
+            )?))
+        }
+        // Schema and converted rows disagree: report it instead of panicking
+        // on a failed downcast.
+        (DataType::Struct(_), actual) => {
+            Err(datafusion_common::DataFusionError::Internal(format!(
+                "Converted group rows expected {expected} got {actual}"
+            )))
+        }
+        (DataType::List(expected_field), DataType::List(_)) => {
+            let list = array
+                .as_any()
+                .downcast_ref::<ListArray>()
+                .expect("DataType::List array must be a ListArray");
+            encode_list_values(list, expected_field)
+        }
+        (DataType::LargeList(expected_field), DataType::LargeList(_)) => {
+            let list = array
+                .as_any()
+                .downcast_ref::<arrow_array::LargeListArray>()
+                .expect("DataType::LargeList array must be a LargeListArray");
+            encode_list_values(list, expected_field)
+        }
+        (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
+        _ => Ok(array),
+    }
+}
+
+/// Rebuilds a `List` / `LargeList` array with `expected_field` as its item
+/// field, re-applying dictionary encoding to the child values as needed.
+fn encode_list_values<O: arrow_array::OffsetSizeTrait>(
+    list: &arrow_array::GenericListArray<O>,
+    expected_field: &arrow_schema::FieldRef,
+) -> Result<ArrayRef> {
+    let values = dictionary_encode_if_necessary(
+        Arc::clone(list.values()),
+        expected_field.data_type(),
+    )?;
+    Ok(Arc::new(arrow_array::GenericListArray::<O>::try_new(
+        Arc::clone(expected_field),
+        list.offsets().clone(),
+        values,
+        list.nulls().cloned(),
+    )?))
+}
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index c3bc7b042e..617f1da3ab 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1200,8 +1200,10 @@ mod tests {
 
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::compute::{concat_batches, SortOptions};
-    use arrow::datatypes::DataType;
-    use arrow_array::{Float32Array, Int32Array};
+    use arrow::datatypes::{DataType, Int32Type};
+    use arrow_array::{
+        DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
+    };
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, internal_err, 
DataFusionError,
         ScalarValue,
@@ -1214,6 +1216,7 @@ mod tests {
     use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_functions_aggregate::first_last::{first_value_udaf, 
last_value_udaf};
     use datafusion_functions_aggregate::median::median_udaf;
+    use datafusion_functions_aggregate::sum::sum_udaf;
     use datafusion_physical_expr::expressions::lit;
     use datafusion_physical_expr::PhysicalSortExpr;
 
@@ -2316,6 +2319,94 @@ mod tests {
         Ok(())
     }

+    #[tokio::test]
+    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
+        // Group by a struct whose children are dictionary-encoded. The row
+        // converter emits plain `Utf8` children, so the grouped output must
+        // be re-encoded to match the declared schema.
+        let dict_field = |name: &str| {
+            Field::new_dict(
+                name.to_string(),
+                DataType::Dictionary(
+                    Box::new(DataType::Int32),
+                    Box::new(DataType::Utf8),
+                ),
+                true,
+                0,
+                false,
+            )
+        };
+        let dict_column = |values: Vec<Option<&str>>| -> ArrayRef {
+            Arc::new(values.into_iter().collect::<DictionaryArray<Int32Type>>())
+        };
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new(
+                "labels".to_string(),
+                DataType::Struct(vec![dict_field("a"), dict_field("b")].into()),
+                false,
+            ),
+            Field::new("value", DataType::UInt64, false),
+        ]));
+        let labels = StructArray::from(vec![
+            (
+                Arc::new(dict_field("a")),
+                dict_column(vec![Some("a"), None, Some("a")]),
+            ),
+            (
+                Arc::new(dict_field("b")),
+                dict_column(vec![Some("b"), Some("c"), Some("b")]),
+            ),
+        ]);
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(labels), Arc::new(UInt64Array::from(vec![1, 1, 1]))],
+        )
+        .expect("Failed to create RecordBatch");
+
+        let group_by = PhysicalGroupBy::new_single(vec![(
+            col("labels", &schema)?,
+            "labels".to_string(),
+        )]);
+
+        let aggr_expr = vec![AggregateExprBuilder::new(
+            sum_udaf(),
+            vec![col("value", &schema)?],
+        )
+        .schema(Arc::clone(&schema))
+        .alias(String::from("SUM(value)"))
+        .build()?];
+
+        let input: Arc<dyn ExecutionPlan> = Arc::new(MemoryExec::try_new(
+            &[vec![batch]],
+            Arc::clone(&schema),
+            None,
+        )?);
+        let aggregate_exec = AggregateExec::try_new(
+            AggregateMode::FinalPartitioned,
+            group_by,
+            aggr_expr,
+            vec![None],
+            input,
+            Arc::clone(&schema),
+        )?;
+
+        let ctx = TaskContext::default().with_session_config(SessionConfig::default());
+        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
+
+        let expected = [
+            "+--------------+------------+",
+            "| labels       | SUM(value) |",
+            "+--------------+------------+",
+            "| {a: a, b: b} | 2          |",
+            "| {a: , b: c}  | 1          |",
+            "+--------------+------------+",
+        ];
+        assert_batches_eq!(expected, &output);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_skip_aggregation_after_first_batch() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to