This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 77c68b2  Implement array_agg aggregate function (#1300)
77c68b2 is described below

commit 77c68b26c534bec0cd314cc24bbbc7b4e3e33868
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Nov 16 03:37:19 2021 -0800

    Implement array_agg aggregate function (#1300)
    
    * Implement array_agg aggregate function.
    
    * Avoid copying.
    
    * Fix clippy.
    
    * For review comment.
    
    * Add e2e tests.
    
    * Add assert and order by.
---
 ballista/rust/core/proto/ballista.proto            |   1 +
 .../rust/core/src/serde/logical_plan/to_proto.rs   |   2 +
 ballista/rust/core/src/serde/mod.rs                |   1 +
 datafusion/src/physical_plan/aggregates.rs         |  21 +-
 .../src/physical_plan/expressions/array_agg.rs     | 257 +++++++++++++++++++++
 datafusion/src/physical_plan/expressions/mod.rs    |   2 +
 datafusion/tests/sql.rs                            |  54 +++++
 7 files changed, 333 insertions(+), 5 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 1815811..493fb97 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -168,6 +168,7 @@ enum AggregateFunction {
   AVG = 3;
   COUNT = 4;
   APPROX_DISTINCT = 5;
+  ARRAY_AGG = 6;
 }
 
 message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index e4c7656..805fe31 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1124,6 +1124,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
                     AggregateFunction::ApproxDistinct => {
                         protobuf::AggregateFunction::ApproxDistinct
                     }
+                    AggregateFunction::ArrayAgg => 
protobuf::AggregateFunction::ArrayAgg,
                     AggregateFunction::Min => protobuf::AggregateFunction::Min,
                     AggregateFunction::Max => protobuf::AggregateFunction::Max,
                     AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
@@ -1358,6 +1359,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::Avg => Self::Avg,
             AggregateFunction::Count => Self::Count,
             AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
+            AggregateFunction::ArrayAgg => Self::ArrayAgg,
         }
     }
 }
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index 4a32b24..b5c3c3c 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -117,6 +117,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::ApproxDistinct => {
                 AggregateFunction::ApproxDistinct
             }
+            protobuf::AggregateFunction::ArrayAgg => 
AggregateFunction::ArrayAgg,
         }
     }
 }
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index eb3f6ca..0c99c4f 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -34,7 +34,7 @@ use super::{
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::distinct_expressions;
 use crate::physical_plan::expressions;
-use arrow::datatypes::{DataType, Schema, TimeUnit};
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use expressions::{avg_return_type, sum_return_type};
 use std::{fmt, str::FromStr, sync::Arc};
 /// the implementation of an aggregate function
@@ -46,7 +46,7 @@ pub type AccumulatorFunctionImplementation =
 pub type StateTypeFunction =
     Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
 
-/// Enum of all built-in scalar functions
+/// Enum of all built-in aggregate functions
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
 pub enum AggregateFunction {
     /// count
@@ -61,6 +61,8 @@ pub enum AggregateFunction {
     Avg,
     /// Approximate aggregate function
     ApproxDistinct,
+    /// array_agg
+    ArrayAgg,
 }
 
 impl fmt::Display for AggregateFunction {
@@ -80,6 +82,7 @@ impl FromStr for AggregateFunction {
             "avg" => AggregateFunction::Avg,
             "sum" => AggregateFunction::Sum,
             "approx_distinct" => AggregateFunction::ApproxDistinct,
+            "array_agg" => AggregateFunction::ArrayAgg,
             _ => {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {}",
@@ -105,6 +108,11 @@ pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<Da
         AggregateFunction::Max | AggregateFunction::Min => 
Ok(arg_types[0].clone()),
         AggregateFunction::Sum => sum_return_type(&arg_types[0]),
         AggregateFunction::Avg => avg_return_type(&arg_types[0]),
+        AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
+            "item",
+            arg_types[0].clone(),
+            true,
+        )))),
     }
 }
 
@@ -157,6 +165,9 @@ pub fn create_aggregate_expr(
         (AggregateFunction::ApproxDistinct, _) => Arc::new(
             expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
         ),
+        (AggregateFunction::ArrayAgg, _) => {
+            Arc::new(expressions::ArrayAgg::new(arg, name, 
arg_types[0].clone()))
+        }
         (AggregateFunction::Min, _) => {
             Arc::new(expressions::Min::new(arg, name, return_type))
         }
@@ -202,9 +213,9 @@ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
 pub fn signature(fun: &AggregateFunction) -> Signature {
     // note: the physical expression must accept the type returned by this function or the execution panics.
     match fun {
-        AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
-            Signature::any(1, Volatility::Immutable)
-        }
+        AggregateFunction::Count
+        | AggregateFunction::ApproxDistinct
+        | AggregateFunction::ArrayAgg => Signature::any(1, 
Volatility::Immutable),
         AggregateFunction::Min | AggregateFunction::Max => {
             let valid = STRINGS
                 .iter()
diff --git a/datafusion/src/physical_plan/expressions/array_agg.rs b/datafusion/src/physical_plan/expressions/array_agg.rs
new file mode 100644
index 0000000..213b392
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/array_agg.rs
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use super::format_state_name;
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
+use crate::scalar::ScalarValue;
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::sync::Arc;
+
+/// ARRAY_AGG aggregate expression
+///
+/// Collects the values produced by `expr` into a single list whose item
+/// type is `input_data_type`.
+#[derive(Debug)]
+pub struct ArrayAgg {
+    /// Output column name; also used to derive the state field name.
+    name: String,
+    /// Data type of each collected value (the list's item type).
+    input_data_type: DataType,
+    /// Expression evaluated to produce the values being aggregated.
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl ArrayAgg {
+    /// Create a new ArrayAgg aggregate function
+    ///
+    /// `data_type` is the type of the values produced by `expr`; it becomes
+    /// the item type of the resulting list.
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            input_data_type: data_type,
+        }
+    }
+}
+
+impl AggregateExpr for ArrayAgg {
+    /// Returns `self` as `&dyn Any`, allowing downcasts to `ArrayAgg`.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// The output field: a non-nullable `List` whose "item" children have
+    /// the input data type and are themselves nullable.
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(
+            &self.name,
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        ))
+    }
+
+    /// Creates a new, empty `ArrayAggAccumulator` for the input data type.
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(ArrayAggAccumulator::try_new(
+            &self.input_data_type,
+        )?))
+    }
+
+    /// The intermediate state is a single list field with the same shape
+    /// as the output, named via `format_state_name(name, "array_agg")`.
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            &format_state_name(&self.name, "array_agg"),
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        )])
+    }
+
+    /// The single input expression whose values are aggregated.
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+}
+
+/// Accumulator for `ARRAY_AGG` that buffers the collected values in a `Vec`.
+#[derive(Debug)]
+pub(crate) struct ArrayAggAccumulator {
+    /// Values collected so far, in the order they were received.
+    array: Vec<ScalarValue>,
+    /// Item data type of the list produced by `state` / `evaluate`.
+    datatype: DataType,
+}
+
+impl ArrayAggAccumulator {
+    /// new array_agg accumulator based on given item data type
+    ///
+    /// Starts with no collected values; this implementation always
+    /// returns `Ok`.
+    pub fn try_new(datatype: &DataType) -> Result<Self> {
+        Ok(Self {
+            array: vec![],
+            datatype: datatype.clone(),
+        })
+    }
+}
+
+impl Accumulator for ArrayAggAccumulator {
+    /// Partial state: a single `List` holding every value collected so far.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::List(
+            Some(Box::new(self.array.clone())),
+            Box::new(self.datatype.clone()),
+        )])
+    }
+
+    /// Appends the single input value. Null values are kept, so they
+    /// appear in the resulting list.
+    fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+        let value = &values[0];
+        self.array.push(value.clone());
+
+        Ok(())
+    }
+
+    /// Appends the elements of a partial state produced by `state`.
+    ///
+    /// A null list contributes nothing. A wrong number of state values, or
+    /// a state that is not a `List`, is reported as an internal error
+    /// instead of panicking.
+    fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+
+        // `state_fields` declares exactly one field, so exactly one state
+        // value is expected per merge.
+        if states.len() != 1 {
+            return Err(DataFusionError::Internal(format!(
+                "ArrayAggAccumulator expects 1 state value, got {}",
+                states.len()
+            )));
+        }
+
+        match &states[0] {
+            // Clone the elements directly instead of cloning the whole Vec
+            // into a temporary first.
+            ScalarValue::List(Some(items), _) => {
+                self.array.extend(items.iter().cloned());
+            }
+            ScalarValue::List(None, _) => {}
+            other => {
+                return Err(DataFusionError::Internal(format!(
+                    "ArrayAggAccumulator expects a List state, got {:?}",
+                    other
+                )));
+            }
+        }
+        Ok(())
+    }
+
+    /// Final result: a `List` of all collected values (an empty list when
+    /// nothing was collected).
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::List(
+            Some(Box::new(self.array.clone())),
+            Box::new(self.datatype.clone()),
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::physical_plan::expressions::col;
+    use crate::physical_plan::expressions::tests::aggregate;
+    use crate::{error::Result, generic_test_op};
+    use arrow::array::ArrayRef;
+    use arrow::array::Int32Array;
+    use arrow::datatypes::*;
+    use arrow::record_batch::RecordBatch;
+
+    /// Aggregating five Int32 rows yields one list holding those values in
+    /// input order.
+    #[test]
+    fn array_agg_i32() -> Result<()> {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+
+        let list = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::Int32(Some(1)),
+                ScalarValue::Int32(Some(2)),
+                ScalarValue::Int32(Some(3)),
+                ScalarValue::Int32(Some(4)),
+                ScalarValue::Int32(Some(5)),
+            ])),
+            Box::new(DataType::Int32),
+        );
+
+        generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
+    }
+
+    /// Aggregating list-typed rows yields a list of lists, one element per
+    /// input row.
+    #[test]
+    fn array_agg_nested() -> Result<()> {
+        let l1 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(1i32),
+                        ScalarValue::from(2i32),
+                        ScalarValue::from(3i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(4i32),
+                        ScalarValue::from(5i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l2 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![ScalarValue::from(6i32)])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(7i32),
+                        ScalarValue::from(8i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l3 = ScalarValue::List(
+            Some(Box::new(vec![ScalarValue::List(
+                Some(Box::new(vec![ScalarValue::from(9i32)])),
+                Box::new(DataType::Int32),
+            )])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let list = ScalarValue::List(
+            Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
+
+        generic_test_op!(
+            array,
+            DataType::List(Box::new(Field::new(
+                "item",
+                DataType::List(Box::new(Field::new("item", DataType::Int32, true,))),
+                true,
+            ))),
+            ArrayAgg,
+            list,
+            DataType::List(Box::new(Field::new("item", DataType::Int32, true,)))
+        )
+    }
+}
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index dba3bde..5647ee0 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions};
 use arrow::record_batch::RecordBatch;
 
 mod approx_distinct;
+mod array_agg;
 mod average;
 #[macro_use]
 mod binary;
@@ -58,6 +59,7 @@ pub mod helpers {
 }
 
 pub use approx_distinct::ApproxDistinct;
+pub use array_agg::ArrayAgg;
 pub use average::{avg_return_type, Avg, AvgAccumulator};
 pub use binary::{binary, binary_operator_data_type, BinaryExpr};
 pub use case::{case, CaseExpr};
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index eeb6c10..15241ee 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1281,6 +1281,60 @@ async fn csv_query_approx_count() -> Result<()> {
     Ok(())
 }
 
+/// array_agg over the first two rows of a subquery ordered by c13 returns
+/// both values in a single list.
+#[tokio::test]
+async fn csv_query_array_agg() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql =
+        "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+------------------------------------------------------------------+",
+        "| ARRAYAGG(test.c13)                                               |",
+        "+------------------------------------------------------------------+",
+        "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB] |",
+        "+------------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+/// array_agg over an empty input returns an empty list.
+#[tokio::test]
+async fn csv_query_array_agg_empty() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql =
+        "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+--------------------+",
+        "| ARRAYAGG(test.c13) |",
+        "+--------------------+",
+        "| []                 |",
+        "+--------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+/// array_agg over a single row returns a one-element list.
+#[tokio::test]
+async fn csv_query_array_agg_one() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql =
+        "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+----------------------------------+",
+        "| ARRAYAGG(test.c13)               |",
+        "+----------------------------------+",
+        "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm] |",
+        "+----------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 /// for window functions without order by the first, last, and nth function call does not make sense
 #[tokio::test]
 async fn csv_query_window_with_empty_over() -> Result<()> {

Reply via email to