This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ff27d90734 Fix regression by reverting Materialize dictionaries in group keys (#8740)
ff27d90734 is described below

commit ff27d9073421d527439e6e338f31fb568227bbb2
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Jan 8 09:29:33 2024 -0500

    Fix regression by reverting Materialize dictionaries in group keys (#8740)
    
    * revert eb8aff7becaf5d4a44c723b29445deb958fbe3b4 / Materialize dictionaries in group keys
    
    * Update tests
    
    * Update tests
---
 datafusion/core/tests/path_partition.rs            | 15 +++-
 .../src/aggregates/group_values/row.rs             | 27 ++++++--
 datafusion/physical-plan/src/aggregates/mod.rs     | 35 +---------
 .../physical-plan/src/aggregates/row_hash.rs       |  4 +-
 datafusion/sqllogictest/test_files/aggregate.slt   | 10 +--
 datafusion/sqllogictest/test_files/dictionary.slt  | 81 +++++++++++++++++++++-
 6 files changed, 124 insertions(+), 48 deletions(-)

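The regression being reverted surfaced as a RowConverter schema mismatch (see the dictionary.slt reproducer below): the converter had been built for the materialized value type (Utf8) while the incoming group column was still Dictionary(Int32, Utf8). A minimal sketch of the invariant at play, assuming current arrow-rs row-format APIs and hypothetical array contents:

    use std::sync::Arc;
    use arrow::row::{RowConverter, SortField};
    use arrow_array::{ArrayRef, StringArray};
    use arrow_schema::DataType;

    fn main() {
        // A RowConverter's SortFields must match the data types of the
        // columns it is fed exactly.
        let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
        let col: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
        let rows = converter.convert_columns(&[col]).unwrap();
        assert_eq!(rows.num_rows(), 2);
        // Feeding a Dictionary(Int32, Utf8) column to this Utf8 converter
        // would instead fail with "RowConverter column schema mismatch".
    }
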
diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs
index abe6ab283a..dd8eb52f67 100644
--- a/datafusion/core/tests/path_partition.rs
+++ b/datafusion/core/tests/path_partition.rs
@@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     assert_eq!(min_limit, resulting_limit);
 
     let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
-    let month = match s {
-        ScalarValue::Utf8(Some(month)) => month,
-        s => panic!("Expected month as Utf8 found {s:?}"),
+    let month = match extract_as_utf(&s) {
+        Some(month) => month,
+        s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
     };
 
     let sql_on_partition_boundary = format!(
@@ -191,6 +191,15 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     Ok(())
 }
 
+fn extract_as_utf(v: &ScalarValue) -> Option<String> {
+    if let ScalarValue::Dictionary(_, v) = v {
+        if let ScalarValue::Utf8(v) = v.as_ref() {
+            return v.clone();
+        }
+    }
+    None
+}
+
 #[tokio::test]
 async fn csv_filter_with_file_col() -> Result<()> {
     let ctx = SessionContext::new();
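
For context, partition columns come back from the listing table as dictionary scalars (typically Dictionary(UInt16, Utf8)), which is why the test now unwraps the nested Utf8 value. A small sketch of that unwrapping, using a made-up month value:

    use arrow_schema::DataType;
    use datafusion_common::ScalarValue;

    fn main() {
        // A dictionary scalar nests its value one level deep.
        let s = ScalarValue::Dictionary(
            Box::new(DataType::UInt16),
            Box::new(ScalarValue::Utf8(Some("10".to_string()))),
        );
        let month = match &s {
            ScalarValue::Dictionary(_, v) => match v.as_ref() {
                ScalarValue::Utf8(m) => m.clone(),
                _ => None,
            },
            _ => None,
        };
        assert_eq!(month.as_deref(), Some("10"));
    }
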
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index e7c7a42cf9..10ff9edb89 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,18 +17,22 @@
 
 use crate::aggregates::group_values::GroupValues;
 use ahash::RandomState;
+use arrow::compute::cast;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::ArrayRef;
-use arrow_schema::SchemaRef;
+use arrow_array::{Array, ArrayRef};
+use arrow_schema::{DataType, SchemaRef};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_physical_expr::EmitTo;
 use hashbrown::raw::RawTable;
 
 /// A [`GroupValues`] making use of [`Rows`]
 pub struct GroupValuesRows {
+    /// The output schema
+    schema: SchemaRef,
+
     /// Converter for the group values
     row_converter: RowConverter,
 
@@ -75,6 +79,7 @@ impl GroupValuesRows {
         let map = RawTable::with_capacity(0);
 
         Ok(Self {
+            schema,
             row_converter,
             map,
             map_size: 0,
@@ -165,7 +170,7 @@ impl GroupValues for GroupValuesRows {
             .take()
             .expect("Can not emit from empty rows");
 
-        let output = match emit_to {
+        let mut output = match emit_to {
             EmitTo::All => {
                 let output = self.row_converter.convert_rows(&group_values)?;
                 group_values.clear();
@@ -198,6 +203,20 @@ impl GroupValues for GroupValuesRows {
             }
         };
 
+        // TODO: Materialize dictionaries in group keys (#7647)
+        for (field, array) in self.schema.fields.iter().zip(&mut output) {
+            let expected = field.data_type();
+            if let DataType::Dictionary(_, v) = expected {
+                let actual = array.data_type();
+                if v.as_ref() != actual {
+                    return Err(DataFusionError::Internal(format!(
+                        "Converted group rows expected dictionary of {v} got 
{actual}"
+                    )));
+                }
+                *array = cast(array.as_ref(), expected)?;
+            }
+        }
+
         self.group_values = Some(group_values);
         Ok(output)
     }
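
The cast added in this hunk is the heart of the fix: group values still travel through the row format as plain value types and are converted back to the schema's dictionary type on emit. A standalone sketch of that cast using arrow's cast kernel (inputs hypothetical):

    use std::sync::Arc;
    use arrow::compute::cast;
    use arrow_array::{Array, ArrayRef, StringArray};
    use arrow_schema::DataType;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Row conversion emits group keys as value arrays (e.g. Utf8)...
        let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a"]));
        // ...but the output schema may declare Dictionary(Int32, Utf8), so
        // the emitted column is cast back to the dictionary type.
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
        let dict = cast(values.as_ref(), &dict_type)?;
        assert_eq!(dict.data_type(), &dict_type);
        Ok(())
    }
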
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index a38044de02..0b94dd01cf 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -36,7 +36,6 @@ use crate::{
 use arrow::array::ArrayRef;
 use arrow::datatypes::{Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
-use arrow_schema::DataType;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
 use datafusion_execution::TaskContext;
@@ -254,9 +253,6 @@ pub struct AggregateExec {
     limit: Option<usize>,
     /// Input plan, could be a partial aggregate or the input to the aggregate
     pub input: Arc<dyn ExecutionPlan>,
-    /// Original aggregation schema, could be different from `schema` before dictionary group
-    /// keys get materialized
-    original_schema: SchemaRef,
     /// Schema after the aggregate is applied
     schema: SchemaRef,
     /// Input schema before any aggregation is applied. For partial aggregate this will be the
@@ -287,7 +283,7 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let original_schema = create_schema(
+        let schema = create_schema(
             &input.schema(),
             &group_by.expr,
             &aggr_expr,
@@ -295,11 +291,7 @@ impl AggregateExec {
             mode,
         )?;
 
-        let schema = Arc::new(materialize_dict_group_keys(
-            &original_schema,
-            group_by.expr.len(),
-        ));
-        let original_schema = Arc::new(original_schema);
+        let schema = Arc::new(schema);
         AggregateExec::try_new_with_schema(
             mode,
             group_by,
@@ -308,7 +300,6 @@ impl AggregateExec {
             input,
             input_schema,
             schema,
-            original_schema,
         )
     }
 
@@ -329,7 +320,6 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
         schema: SchemaRef,
-        original_schema: SchemaRef,
     ) -> Result<Self> {
         let input_eq_properties = input.equivalence_properties();
         // Get GROUP BY expressions:
@@ -382,7 +372,6 @@ impl AggregateExec {
             aggr_expr,
             filter_expr,
             input,
-            original_schema,
             schema,
             input_schema,
             projection_mapping,
@@ -693,7 +682,7 @@ impl ExecutionPlan for AggregateExec {
             children[0].clone(),
             self.input_schema.clone(),
             self.schema.clone(),
-            self.original_schema.clone(),
+            //self.original_schema.clone(),
         )?;
         me.limit = self.limit;
         Ok(Arc::new(me))
@@ -800,24 +789,6 @@ fn create_schema(
     Ok(Schema::new(fields))
 }
 
-/// returns schema with dictionary group keys materialized as their value types
-/// The actual convertion happens in `RowConverter` and we don't do unnecessary
-/// conversion back into dictionaries
-fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema {
-    let fields = schema
-        .fields
-        .iter()
-        .enumerate()
-        .map(|(i, field)| match field.data_type() {
-            DataType::Dictionary(_, value_data_type) if i < group_count => {
-                Field::new(field.name(), *value_data_type.clone(), field.is_nullable())
-            }
-            _ => Field::clone(field),
-        })
-        .collect::<Vec<_>>();
-    Schema::new(fields)
-}
-
 fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
     let group_fields = schema.fields()[0..group_count].to_vec();
     Arc::new(Schema::new(group_fields))
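
With materialize_dict_group_keys removed, group_schema (kept just above) is the only derivation left: it takes the first group_count fields of the aggregate schema, dictionary types intact. A restatement against a hypothetical schema:

    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Schema, SchemaRef};

    // Same shape as the retained group_schema above.
    fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
        let group_fields = schema.fields()[0..group_count].to_vec();
        Arc::new(Schema::new(group_fields))
    }

    fn main() {
        let schema = Schema::new(vec![
            Field::new(
                "tag",
                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                true,
            ),
            Field::new("count", DataType::Int64, true),
        ]);
        let groups = group_schema(&schema, 1);
        assert_eq!(groups.fields().len(), 1);
        assert!(matches!(groups.field(0).data_type(), DataType::Dictionary(_, _)));
    }
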
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 89614fd302..6a0c02f5ca 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -324,9 +324,7 @@ impl GroupedHashAggregateStream {
             .map(create_group_accumulator)
             .collect::<Result<_>>()?;
 
-        // we need to use original schema so RowConverter in group_values below
-        // will do the proper coversion of dictionaries into value types
-        let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len());
+        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
         let spill_expr = group_schema
             .fields
             .into_iter()
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 78575c9dff..aa512f6e26 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2469,11 +2469,11 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
 query T
 select arrow_typeof(x_dict) from value_dict group by x_dict;
 ----
-Int32
-Int32
-Int32
-Int32
-Int32
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
 
 statement ok
 drop table value
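
The expected output changes because group keys now keep their dictionary encoding, so arrow_typeof reports the full dictionary type. The string is just arrow's Display for DataType (a minimal check, assuming Display follows the derived Debug form):

    use arrow_schema::DataType;

    fn main() {
        let dt = DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Int32));
        assert_eq!(dt.to_string(), "Dictionary(Int64, Int32)");
    }
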
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt
index b7f375dd6c..002aade252 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -169,7 +169,7 @@ order by date_bin('30 minutes', time) DESC
 
 # Reproducer for https://github.com/apache/arrow-datafusion/issues/8738
 # This query should work correctly
-query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Utf8 got Dictionary\(Int32, Utf8\)
+query P?TT rowsort
 SELECT
   "data"."timestamp" as "time",
   "data"."tag_id",
@@ -201,3 +201,82 @@ ORDER BY
   "time",
   "data"."tag_id"
 ;
+----
+2023-12-20T00:00:00 1000 f1 32.0
+2023-12-20T00:00:00 1000 f2 foo
+2023-12-20T00:10:00 1000 f1 32.0
+2023-12-20T00:10:00 1000 f2 foo
+2023-12-20T00:20:00 1000 f1 32.0
+2023-12-20T00:20:00 1000 f2 foo
+2023-12-20T00:30:00 1000 f1 32.0
+2023-12-20T00:30:00 1000 f2 foo
+2023-12-20T00:40:00 1000 f1 32.0
+2023-12-20T00:40:00 1000 f2 foo
+2023-12-20T00:50:00 1000 f1 32.0
+2023-12-20T00:50:00 1000 f2 foo
+2023-12-20T01:00:00 1000 f1 32.0
+2023-12-20T01:00:00 1000 f2 foo
+2023-12-20T01:10:00 1000 f1 32.0
+2023-12-20T01:10:00 1000 f2 foo
+2023-12-20T01:20:00 1000 f1 32.0
+2023-12-20T01:20:00 1000 f2 foo
+2023-12-20T01:30:00 1000 f1 32.0
+2023-12-20T01:30:00 1000 f2 foo
+
+
+# deterministic sort (so we can avoid rowsort)
+query P?TT
+SELECT
+  "data"."timestamp" as "time",
+  "data"."tag_id",
+  "data"."field",
+  "data"."value"
+FROM (
+  (
+      SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as 
"field", "m2"."f5" as "value"
+        FROM "m2"
+       WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < 
'2024-01-03T14:46:35+01:00'
+         AND "m2"."f5" IS NOT NULL
+         AND "m2"."type" IN ('active')
+         AND "m2"."tag_id" IN ('1000')
+  ) UNION (
+      SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", 
"m1"."f1" as "value"
+        FROM "m1"
+       WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < 
'2024-01-03T14:46:35+01:00'
+         AND "m1"."f1" IS NOT NULL
+         AND "m1"."tag_id" IN ('1000')
+  ) UNION (
+      SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", 
"m1"."f2" as "value"
+        FROM "m1"
+       WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < 
'2024-01-03T14:46:35+01:00'
+         AND "m1"."f2" IS NOT NULL
+         AND "m1"."tag_id" IN ('1000')
+  )
+) as "data"
+ORDER BY
+  "time",
+  "data"."tag_id",
+  "data"."field",
+  "data"."value"
+;
+----
+2023-12-20T00:00:00 1000 f1 32.0
+2023-12-20T00:00:00 1000 f2 foo
+2023-12-20T00:10:00 1000 f1 32.0
+2023-12-20T00:10:00 1000 f2 foo
+2023-12-20T00:20:00 1000 f1 32.0
+2023-12-20T00:20:00 1000 f2 foo
+2023-12-20T00:30:00 1000 f1 32.0
+2023-12-20T00:30:00 1000 f2 foo
+2023-12-20T00:40:00 1000 f1 32.0
+2023-12-20T00:40:00 1000 f2 foo
+2023-12-20T00:50:00 1000 f1 32.0
+2023-12-20T00:50:00 1000 f2 foo
+2023-12-20T01:00:00 1000 f1 32.0
+2023-12-20T01:00:00 1000 f2 foo
+2023-12-20T01:10:00 1000 f1 32.0
+2023-12-20T01:10:00 1000 f2 foo
+2023-12-20T01:20:00 1000 f1 32.0
+2023-12-20T01:20:00 1000 f2 foo
+2023-12-20T01:30:00 1000 f1 32.0
+2023-12-20T01:30:00 1000 f2 foo
