This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 91a44c1e5a Fix ArrayAgg schema mismatch issue (#8055)
91a44c1e5a is described below

commit 91a44c1e5aaed6b3037eb620c7a753b51755b187
Author: Jay Zhan <[email protected]>
AuthorDate: Fri Nov 10 01:44:54 2023 +0800

    Fix ArrayAgg schema mismatch issue (#8055)
    
    * fix schema
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * upd parquet-testing
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * avoid parquet file
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * reset parquet-testing
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * remove file
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rename and upd docstring
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/core/src/dataframe/mod.rs               | 86 ++++++++++++++++++++++
 .../physical-expr/src/aggregate/array_agg.rs       | 43 +++++++++--
 .../src/aggregate/array_agg_distinct.rs            | 13 +++-
 .../src/aggregate/array_agg_ordered.rs             | 20 +++--
 datafusion/physical-expr/src/aggregate/build_in.rs | 18 +++--
 5 files changed, 160 insertions(+), 20 deletions(-)

diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 0a99c33182..89e82fa952 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1340,6 +1340,92 @@ mod tests {
 
     use super::*;
 
+    async fn assert_logical_expr_schema_eq_physical_expr_schema(
+        df: DataFrame,
+    ) -> Result<()> {
+        let logical_expr_dfschema = df.schema();
+        let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned());
+        let batches = df.collect().await?;
+        let physical_expr_schema = batches[0].schema();
+        assert_eq!(logical_expr_schema, physical_expr_schema);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_array_agg_ord_schema() -> Result<()> {
+        let ctx = SessionContext::new();
+
+        let create_table_query = r#"
+            CREATE TABLE test_table (
+                "double_field" DOUBLE,
+                "string_field" VARCHAR
+            ) AS VALUES
+                (1.0, 'a'),
+                (2.0, 'b'),
+                (3.0, 'c')
+        "#;
+        ctx.sql(create_table_query).await?;
+
+        let query = r#"SELECT
+        array_agg("double_field" ORDER BY "string_field") as "double_field",
+        array_agg("string_field" ORDER BY "string_field") as "string_field"
+    FROM test_table"#;
+
+        let result = ctx.sql(query).await?;
+        assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_array_agg_schema() -> Result<()> {
+        let ctx = SessionContext::new();
+
+        let create_table_query = r#"
+            CREATE TABLE test_table (
+                "double_field" DOUBLE,
+                "string_field" VARCHAR
+            ) AS VALUES
+                (1.0, 'a'),
+                (2.0, 'b'),
+                (3.0, 'c')
+        "#;
+        ctx.sql(create_table_query).await?;
+
+        let query = r#"SELECT
+        array_agg("double_field") as "double_field",
+        array_agg("string_field") as "string_field"
+    FROM test_table"#;
+
+        let result = ctx.sql(query).await?;
+        assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_array_agg_distinct_schema() -> Result<()> {
+        let ctx = SessionContext::new();
+
+        let create_table_query = r#"
+            CREATE TABLE test_table (
+                "double_field" DOUBLE,
+                "string_field" VARCHAR
+            ) AS VALUES
+                (1.0, 'a'),
+                (2.0, 'b'),
+                (2.0, 'a')
+        "#;
+        ctx.sql(create_table_query).await?;
+
+        let query = r#"SELECT
+        array_agg(distinct "double_field") as "double_field",
+        array_agg(distinct "string_field") as "string_field"
+    FROM test_table"#;
+
+        let result = ctx.sql(query).await?;
+        assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
+        Ok(())
+    }
+
     #[tokio::test]
     async fn select_columns() -> Result<()> {
         // build plan using Table API
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs
index 4dccbfef07..91d5c867d3 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -34,9 +34,14 @@ use std::sync::Arc;
 /// ARRAY_AGG aggregate expression
 #[derive(Debug)]
 pub struct ArrayAgg {
+    /// Column name
     name: String,
+    /// The DataType for the input expression
     input_data_type: DataType,
+    /// The input expression
     expr: Arc<dyn PhysicalExpr>,
+    /// If the input expression can have NULLs
+    nullable: bool,
 }
 
 impl ArrayAgg {
@@ -45,11 +50,13 @@ impl ArrayAgg {
         expr: Arc<dyn PhysicalExpr>,
         name: impl Into<String>,
         data_type: DataType,
+        nullable: bool,
     ) -> Self {
         Self {
             name: name.into(),
-            expr,
             input_data_type: data_type,
+            expr,
+            nullable,
         }
     }
 }
@@ -62,8 +69,9 @@ impl AggregateExpr for ArrayAgg {
     fn field(&self) -> Result<Field> {
         Ok(Field::new_list(
             &self.name,
+            // This should be the same as return type of AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), true),
-            false,
+            self.nullable,
         ))
     }
 
@@ -77,7 +85,7 @@ impl AggregateExpr for ArrayAgg {
         Ok(vec![Field::new_list(
             format_state_name(&self.name, "array_agg"),
             Field::new("item", self.input_data_type.clone(), true),
-            false,
+            self.nullable,
         )])
     }
 
@@ -184,7 +192,6 @@ mod tests {
     use super::*;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
-    use crate::generic_test_op;
     use arrow::array::ArrayRef;
     use arrow::array::Int32Array;
     use arrow::datatypes::*;
@@ -195,6 +202,30 @@ mod tests {
     use datafusion_common::DataFusionError;
     use datafusion_common::Result;
 
+    macro_rules! test_op {
+        ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => {
+            test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type())
+        };
+        ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
+            let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]);
+
+            let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
+
+            let agg = Arc::new(<$OP>::new(
+                col("a", &schema)?,
+                "bla".to_string(),
+                $EXPECTED_DATATYPE,
+                true,
+            ));
+            let actual = aggregate(&batch, agg)?;
+            let expected = ScalarValue::from($EXPECTED);
+
+            assert_eq!(expected, actual);
+
+            Ok(()) as Result<(), DataFusionError>
+        }};
+    }
+
     #[test]
     fn array_agg_i32() -> Result<()> {
         let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
@@ -208,7 +239,7 @@ mod tests {
         ])]);
         let list = ScalarValue::List(Arc::new(list));
 
-        generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
+        test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
     }
 
     #[test]
@@ -264,7 +295,7 @@ mod tests {
 
         let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
 
-        generic_test_op!(
+        test_op!(
             array,
             DataType::List(Arc::new(Field::new_list(
                 "item",
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 9b391b0c42..1efae424cc 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -40,6 +40,8 @@ pub struct DistinctArrayAgg {
     input_data_type: DataType,
     /// The input expression
     expr: Arc<dyn PhysicalExpr>,
+    /// If the input expression can have NULLs
+    nullable: bool,
 }
 
 impl DistinctArrayAgg {
@@ -48,12 +50,14 @@ impl DistinctArrayAgg {
         expr: Arc<dyn PhysicalExpr>,
         name: impl Into<String>,
         input_data_type: DataType,
+        nullable: bool,
     ) -> Self {
         let name = name.into();
         Self {
             name,
-            expr,
             input_data_type,
+            expr,
+            nullable,
         }
     }
 }
@@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg {
     fn field(&self) -> Result<Field> {
         Ok(Field::new_list(
             &self.name,
+            // This should be the same as return type of AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), true),
-            false,
+            self.nullable,
         ))
     }
 
@@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg {
         Ok(vec![Field::new_list(
             format_state_name(&self.name, "distinct_array_agg"),
             Field::new("item", self.input_data_type.clone(), true),
-            false,
+            self.nullable,
         )])
     }
 
@@ -238,6 +243,7 @@ mod tests {
             col("a", &schema)?,
             "bla".to_string(),
             datatype,
+            true,
         ));
         let actual = aggregate(&batch, agg)?;
 
@@ -255,6 +261,7 @@ mod tests {
             col("a", &schema)?,
             "bla".to_string(),
             datatype,
+            true,
         ));
 
         let mut accum1 = agg.create_accumulator()?;
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
index a53d53107a..9ca83a781a 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
@@ -48,10 +48,17 @@ use itertools::izip;
 /// and that can merge aggregations from multiple partitions.
 #[derive(Debug)]
 pub struct OrderSensitiveArrayAgg {
+    /// Column name
     name: String,
+    /// The DataType for the input expression
     input_data_type: DataType,
-    order_by_data_types: Vec<DataType>,
+    /// The input expression
     expr: Arc<dyn PhysicalExpr>,
+    /// If the input expression can have NULLs
+    nullable: bool,
+    /// Ordering data types
+    order_by_data_types: Vec<DataType>,
+    /// Ordering requirement
     ordering_req: LexOrdering,
 }
 
@@ -61,13 +68,15 @@ impl OrderSensitiveArrayAgg {
         expr: Arc<dyn PhysicalExpr>,
         name: impl Into<String>,
         input_data_type: DataType,
+        nullable: bool,
         order_by_data_types: Vec<DataType>,
         ordering_req: LexOrdering,
     ) -> Self {
         Self {
             name: name.into(),
-            expr,
             input_data_type,
+            expr,
+            nullable,
             order_by_data_types,
             ordering_req,
         }
@@ -82,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
     fn field(&self) -> Result<Field> {
         Ok(Field::new_list(
             &self.name,
+            // This should be the same as return type of AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), true),
-            false,
+            self.nullable,
         ))
     }
 
@@ -99,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
         let mut fields = vec![Field::new_list(
             format_state_name(&self.name, "array_agg"),
             Field::new("item", self.input_data_type.clone(), true),
-            false,
+            self.nullable, // This should be the same as field()
         )];
         let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
         fields.push(Field::new_list(
             format_state_name(&self.name, "array_agg_orderings"),
             Field::new("item", DataType::Struct(Fields::from(orderings)), true),
-            false,
+            self.nullable,
         ));
         Ok(fields)
     }
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 6568457bc2..596197b4ee 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -114,13 +114,16 @@ pub fn create_aggregate_expr(
         ),
         (AggregateFunction::ArrayAgg, false) => {
             let expr = input_phy_exprs[0].clone();
+            let nullable = expr.nullable(input_schema)?;
+
             if ordering_req.is_empty() {
-                Arc::new(expressions::ArrayAgg::new(expr, name, data_type))
+                Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable))
             } else {
                 Arc::new(expressions::OrderSensitiveArrayAgg::new(
                     expr,
                     name,
                     data_type,
+                    nullable,
                     ordering_types,
                     ordering_req.to_vec(),
                 ))
@@ -132,10 +135,13 @@ pub fn create_aggregate_expr(
                     "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available"
                 );
             }
+            let expr = input_phy_exprs[0].clone();
+            let is_expr_nullable = expr.nullable(input_schema)?;
             Arc::new(expressions::DistinctArrayAgg::new(
-                input_phy_exprs[0].clone(),
+                expr,
                 name,
                 data_type,
+                is_expr_nullable,
             ))
         }
         (AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
@@ -432,8 +438,8 @@ mod tests {
                         assert_eq!(
                             Field::new_list(
                                 "c1",
-                                Field::new("item", data_type.clone(), true,),
-                                false,
+                                Field::new("item", data_type.clone(), true),
+                                true,
                             ),
                             result_agg_phy_exprs.field().unwrap()
                         );
@@ -471,8 +477,8 @@ mod tests {
                         assert_eq!(
                             Field::new_list(
                                 "c1",
-                                Field::new("item", data_type.clone(), true,),
-                                false,
+                                Field::new("item", data_type.clone(), true),
+                                true,
                             ),
                             result_agg_phy_exprs.field().unwrap()
                         );

Reply via email to