This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch branch-42
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/branch-42 by this push:
new 7fbc134c8b Patch for PR 12586 (#12976)
7fbc134c8b is described below
commit 7fbc134c8b7656bdd27bd9bf7df4e34588804fbb
Author: Matthew Turner <[email protected]>
AuthorDate: Wed Oct 16 17:00:52 2024 -0400
Patch for PR 12586 (#12976)
---
.../src/aggregates/group_values/row.rs | 60 ++++++++--
datafusion/physical-plan/src/aggregates/mod.rs | 128 ++++++++++++++++++++-
2 files changed, 175 insertions(+), 13 deletions(-)
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index dc948e28bb..93a3e04a90 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -20,13 +20,14 @@ use ahash::RandomState;
use arrow::compute::cast;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::{Array, ArrayRef};
+use arrow_array::{Array, ArrayRef, ListArray, StructArray};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use hashbrown::raw::RawTable;
+use std::sync::Arc;
/// A [`GroupValues`] making use of [`Rows`]
pub struct GroupValuesRows {
@@ -221,15 +222,10 @@ impl GroupValues for GroupValuesRows {
// TODO: Materialize dictionaries in group keys (#7647)
for (field, array) in self.schema.fields.iter().zip(&mut output) {
let expected = field.data_type();
- if let DataType::Dictionary(_, v) = expected {
- let actual = array.data_type();
- if v.as_ref() != actual {
- return Err(DataFusionError::Internal(format!(
- "Converted group rows expected dictionary of {v} got {actual}"
- )));
- }
- *array = cast(array.as_ref(), expected)?;
- }
+ *array = dictionary_encode_if_necessary(
+ Arc::<dyn arrow_array::Array>::clone(array),
+ expected,
+ )?;
}
self.group_values = Some(group_values);
@@ -249,3 +245,45 @@ impl GroupValues for GroupValuesRows {
self.hashes_buffer.shrink_to(count);
}
}
+
+fn dictionary_encode_if_necessary(
+ array: ArrayRef,
+ expected: &DataType,
+) -> Result<ArrayRef> {
+ match (expected, array.data_type()) {
+ (DataType::Struct(expected_fields), _) => {
+ let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
+ let arrays = expected_fields
+ .iter()
+ .zip(struct_array.columns())
+ .map(|(expected_field, column)| {
+ dictionary_encode_if_necessary(
+ Arc::<dyn arrow_array::Array>::clone(column),
+ expected_field.data_type(),
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(Arc::new(StructArray::try_new(
+ expected_fields.clone(),
+ arrays,
+ struct_array.nulls().cloned(),
+ )?))
+ }
+ (DataType::List(expected_field), &DataType::List(_)) => {
+ let list = array.as_any().downcast_ref::<ListArray>().unwrap();
+
+ Ok(Arc::new(ListArray::try_new(
+ Arc::<arrow_schema::Field>::clone(expected_field),
+ list.offsets().clone(),
+ dictionary_encode_if_necessary(
+ Arc::<dyn arrow_array::Array>::clone(list.values()),
+ expected_field.data_type(),
+ )?,
+ list.nulls().cloned(),
+ )?))
+ }
+ (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
+ (_, _) => Ok(Arc::<dyn arrow_array::Array>::clone(&array)),
+ }
+}
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index c3bc7b042e..617f1da3ab 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1200,8 +1200,10 @@ mod tests {
use arrow::array::{Float64Array, UInt32Array};
use arrow::compute::{concat_batches, SortOptions};
- use arrow::datatypes::DataType;
- use arrow_array::{Float32Array, Int32Array};
+ use arrow::datatypes::{DataType, Int32Type};
+ use arrow_array::{
+ DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
+ };
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
ScalarValue,
};
@@ -1214,6 +1216,7 @@ mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
use datafusion_functions_aggregate::median::median_udaf;
+ use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::lit;
use datafusion_physical_expr::PhysicalSortExpr;
@@ -2316,6 +2319,127 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_agg_exec_struct_of_dicts() -> Result<()> {
+ let batch = RecordBatch::try_new(
+ Arc::new(Schema::new(vec![
+ Field::new(
+ "labels".to_string(),
+ DataType::Struct(
+ vec![
+ Field::new_dict(
+ "a".to_string(),
+ DataType::Dictionary(
+ Box::new(DataType::Int32),
+ Box::new(DataType::Utf8),
+ ),
+ true,
+ 0,
+ false,
+ ),
+ Field::new_dict(
+ "b".to_string(),
+ DataType::Dictionary(
+ Box::new(DataType::Int32),
+ Box::new(DataType::Utf8),
+ ),
+ true,
+ 0,
+ false,
+ ),
+ ]
+ .into(),
+ ),
+ false,
+ ),
+ Field::new("value", DataType::UInt64, false),
+ ])),
+ vec![
+ Arc::new(StructArray::from(vec![
+ (
+ Arc::new(Field::new_dict(
+ "a".to_string(),
+ DataType::Dictionary(
+ Box::new(DataType::Int32),
+ Box::new(DataType::Utf8),
+ ),
+ true,
+ 0,
+ false,
+ )),
+ Arc::new(
+ vec![Some("a"), None, Some("a")]
+ .into_iter()
+ .collect::<DictionaryArray<Int32Type>>(),
+ ) as ArrayRef,
+ ),
+ (
+ Arc::new(Field::new_dict(
+ "b".to_string(),
+ DataType::Dictionary(
+ Box::new(DataType::Int32),
+ Box::new(DataType::Utf8),
+ ),
+ true,
+ 0,
+ false,
+ )),
+ Arc::new(
+ vec![Some("b"), Some("c"), Some("b")]
+ .into_iter()
+ .collect::<DictionaryArray<Int32Type>>(),
+ ) as ArrayRef,
+ ),
+ ])),
+ Arc::new(UInt64Array::from(vec![1, 1, 1])),
+ ],
+ )
+ .expect("Failed to create RecordBatch");
+
+ let group_by = PhysicalGroupBy::new_single(vec![(
+ col("labels", &batch.schema())?,
+ "labels".to_string(),
+ )]);
+
+ let aggr_expr = vec![AggregateExprBuilder::new(
+ sum_udaf(),
+ vec![col("value", &batch.schema())?],
+ )
+ .schema(Arc::clone(&batch.schema()))
+ .alias(String::from("SUM(value)"))
+ .build()?];
+
+ let input = Arc::new(MemoryExec::try_new(
+ &[vec![batch.clone()]],
+ Arc::<arrow_schema::Schema>::clone(&batch.schema()),
+ None,
+ )?);
+ let aggregate_exec = Arc::new(AggregateExec::try_new(
+ AggregateMode::FinalPartitioned,
+ group_by,
+ aggr_expr,
+ vec![None],
+ Arc::clone(&input) as Arc<dyn ExecutionPlan>,
+ batch.schema(),
+ )?);
+
+ let session_config = SessionConfig::default();
+ let ctx = TaskContext::default().with_session_config(session_config);
+ let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
+
+ let expected = [
+ "+--------------+------------+",
+ "| labels | SUM(value) |",
+ "+--------------+------------+",
+ "| {a: a, b: b} | 2 |",
+ "| {a: , b: c} | 1 |",
+ "+--------------+------------+",
+ ];
+ assert_batches_eq!(expected, &output);
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_skip_aggregation_after_first_batch() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]