This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 77c68b2 Implement array_agg aggregate function (#1300)
77c68b2 is described below
commit 77c68b26c534bec0cd314cc24bbbc7b4e3e33868
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Nov 16 03:37:19 2021 -0800
Implement array_agg aggregate function (#1300)
* Implement array_agg aggregate function.
* Avoid copying.
* Fix clippy.
* For review comment.
* Add e2e tests.
* Add assert and order by.
---
ballista/rust/core/proto/ballista.proto | 1 +
.../rust/core/src/serde/logical_plan/to_proto.rs | 2 +
ballista/rust/core/src/serde/mod.rs | 1 +
datafusion/src/physical_plan/aggregates.rs | 21 +-
.../src/physical_plan/expressions/array_agg.rs | 257 +++++++++++++++++++++
datafusion/src/physical_plan/expressions/mod.rs | 2 +
datafusion/tests/sql.rs | 54 +++++
7 files changed, 333 insertions(+), 5 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 1815811..493fb97 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -168,6 +168,7 @@ enum AggregateFunction {
AVG = 3;
COUNT = 4;
APPROX_DISTINCT = 5;
+ ARRAY_AGG = 6;
}
message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index e4c7656..805fe31 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1124,6 +1124,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
}
+ AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
@@ -1358,6 +1359,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::Avg => Self::Avg,
AggregateFunction::Count => Self::Count,
AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
+ AggregateFunction::ArrayAgg => Self::ArrayAgg,
}
}
}
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index 4a32b24..b5c3c3c 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -117,6 +117,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::ApproxDistinct => {
AggregateFunction::ApproxDistinct
}
+ protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
}
}
}
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index eb3f6ca..0c99c4f 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -34,7 +34,7 @@ use super::{
use crate::error::{DataFusionError, Result};
use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
-use arrow::datatypes::{DataType, Schema, TimeUnit};
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use expressions::{avg_return_type, sum_return_type};
use std::{fmt, str::FromStr, sync::Arc};
/// the implementation of an aggregate function
@@ -46,7 +46,7 @@ pub type AccumulatorFunctionImplementation =
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
-/// Enum of all built-in scalar functions
+/// Enum of all built-in aggregate functions
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
pub enum AggregateFunction {
/// count
@@ -61,6 +61,8 @@ pub enum AggregateFunction {
Avg,
/// Approximate aggregate function
ApproxDistinct,
+ /// array_agg
+ ArrayAgg,
}
impl fmt::Display for AggregateFunction {
@@ -80,6 +82,7 @@ impl FromStr for AggregateFunction {
"avg" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
"approx_distinct" => AggregateFunction::ApproxDistinct,
+ "array_agg" => AggregateFunction::ArrayAgg,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
@@ -105,6 +108,11 @@ pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<Da
AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
AggregateFunction::Sum => sum_return_type(&arg_types[0]),
AggregateFunction::Avg => avg_return_type(&arg_types[0]),
+ AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
+ "item",
+ arg_types[0].clone(),
+ true,
+ )))),
}
}
@@ -157,6 +165,9 @@ pub fn create_aggregate_expr(
(AggregateFunction::ApproxDistinct, _) => Arc::new(
expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
),
+ (AggregateFunction::ArrayAgg, _) => {
+ Arc::new(expressions::ArrayAgg::new(arg, name, arg_types[0].clone()))
+ }
(AggregateFunction::Min, _) => {
Arc::new(expressions::Min::new(arg, name, return_type))
}
@@ -202,9 +213,9 @@ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match fun {
- AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
- Signature::any(1, Volatility::Immutable)
- }
+ AggregateFunction::Count
+ | AggregateFunction::ApproxDistinct
+ | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
diff --git a/datafusion/src/physical_plan/expressions/array_agg.rs b/datafusion/src/physical_plan/expressions/array_agg.rs
new file mode 100644
index 0000000..213b392
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/array_agg.rs
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use super::format_state_name;
+use crate::error::Result;
+use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
+use crate::scalar::ScalarValue;
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::sync::Arc;
+
+/// ARRAY_AGG aggregate expression
+#[derive(Debug)]
+pub struct ArrayAgg {
+ name: String,
+ input_data_type: DataType,
+ expr: Arc<dyn PhysicalExpr>,
+}
+
+impl ArrayAgg {
+ /// Create a new ArrayAgg aggregate function
+ pub fn new(
+ expr: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ data_type: DataType,
+ ) -> Self {
+ Self {
+ name: name.into(),
+ expr,
+ input_data_type: data_type,
+ }
+ }
+}
+
+impl AggregateExpr for ArrayAgg {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(
+ &self.name,
+ DataType::List(Box::new(Field::new(
+ "item",
+ self.input_data_type.clone(),
+ true,
+ ))),
+ false,
+ ))
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(ArrayAggAccumulator::try_new(
+ &self.input_data_type,
+ )?))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![Field::new(
+ &format_state_name(&self.name, "array_agg"),
+ DataType::List(Box::new(Field::new(
+ "item",
+ self.input_data_type.clone(),
+ true,
+ ))),
+ false,
+ )])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct ArrayAggAccumulator {
+ array: Vec<ScalarValue>,
+ datatype: DataType,
+}
+
+impl ArrayAggAccumulator {
+ /// new array_agg accumulator based on given item data type
+ pub fn try_new(datatype: &DataType) -> Result<Self> {
+ Ok(Self {
+ array: vec![],
+ datatype: datatype.clone(),
+ })
+ }
+}
+
+impl Accumulator for ArrayAggAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::List(
+ Some(Box::new(self.array.clone())),
+ Box::new(self.datatype.clone()),
+ )])
+ }
+
+ fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+ let value = &values[0];
+ self.array.push(value.clone());
+
+ Ok(())
+ }
+
+ fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ };
+
+ assert!(states.len() == 1, "states length should be 1!");
+ match &states[0] {
+ ScalarValue::List(Some(array), _) => {
+ self.array.extend((&**array).clone());
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::List(
+ Some(Box::new(self.array.clone())),
+ Box::new(self.datatype.clone()),
+ ))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::physical_plan::expressions::col;
+ use crate::physical_plan::expressions::tests::aggregate;
+ use crate::{error::Result, generic_test_op};
+ use arrow::array::ArrayRef;
+ use arrow::array::Int32Array;
+ use arrow::datatypes::*;
+ use arrow::record_batch::RecordBatch;
+
+ #[test]
+ fn array_agg_i32() -> Result<()> {
+ let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+
+ let list = ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::Int32(Some(1)),
+ ScalarValue::Int32(Some(2)),
+ ScalarValue::Int32(Some(3)),
+ ScalarValue::Int32(Some(4)),
+ ScalarValue::Int32(Some(5)),
+ ])),
+ Box::new(DataType::Int32),
+ );
+
+ generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
+ }
+
+ #[test]
+ fn array_agg_nested() -> Result<()> {
+ let l1 = ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::from(1i32),
+ ScalarValue::from(2i32),
+ ScalarValue::from(3i32),
+ ])),
+ Box::new(DataType::Int32),
+ ),
+ ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::from(4i32),
+ ScalarValue::from(5i32),
+ ])),
+ Box::new(DataType::Int32),
+ ),
+ ])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ let l2 = ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::List(
+ Some(Box::new(vec![ScalarValue::from(6i32)])),
+ Box::new(DataType::Int32),
+ ),
+ ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::from(7i32),
+ ScalarValue::from(8i32),
+ ])),
+ Box::new(DataType::Int32),
+ ),
+ ])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ let l3 = ScalarValue::List(
+ Some(Box::new(vec![ScalarValue::List(
+ Some(Box::new(vec![ScalarValue::from(9i32)])),
+ Box::new(DataType::Int32),
+ )])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ let list = ScalarValue::List(
+ Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
+
+ generic_test_op!(
+ array,
+ DataType::List(Box::new(Field::new(
+ "item",
+ DataType::List(Box::new(Field::new("item", DataType::Int32, true,))),
+ true,
+ ))),
+ ArrayAgg,
+ list,
+ DataType::List(Box::new(Field::new("item", DataType::Int32, true,)))
+ )
+ }
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index dba3bde..5647ee0 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::record_batch::RecordBatch;
mod approx_distinct;
+mod array_agg;
mod average;
#[macro_use]
mod binary;
@@ -58,6 +59,7 @@ pub mod helpers {
}
pub use approx_distinct::ApproxDistinct;
+pub use array_agg::ArrayAgg;
pub use average::{avg_return_type, Avg, AvgAccumulator};
pub use binary::{binary, binary_operator_data_type, BinaryExpr};
pub use case::{case, CaseExpr};
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index eeb6c10..15241ee 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1281,6 +1281,60 @@ async fn csv_query_approx_count() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn csv_query_array_agg() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx).await?;
+ let sql =
+ "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+------------------------------------------------------------------+",
+ "| ARRAYAGG(test.c13)                                               |",
+ "+------------------------------------------------------------------+",
+ "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB] |",
+ "+------------------------------------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_array_agg_empty() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx).await?;
+ let sql =
+ "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+--------------------+",
+ "| ARRAYAGG(test.c13) |",
+ "+--------------------+",
+ "| []                 |",
+ "+--------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_array_agg_one() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx).await?;
+ let sql =
+ "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+----------------------------------+",
+ "| ARRAYAGG(test.c13)               |",
+ "+----------------------------------+",
+ "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm] |",
+ "+----------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
/// for window functions without order by the first, last, and nth function call does not make sense
#[tokio::test]
async fn csv_query_window_with_empty_over() -> Result<()> {