This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a96bb5e Implement ARRAY_AGG(DISTINCT ...) (#1579)
a96bb5e is described below
commit a96bb5eea8aa0dddc2bdf30de737b76ccb132054
Author: James Katz <[email protected]>
AuthorDate: Thu Jan 20 14:53:44 2022 -0500
Implement ARRAY_AGG(DISTINCT ...) (#1579)
* Implement ARRAY_AGG(DISTINCT)
* Add integration test
* Move distinct_expressions into physical_plan/expressions
* Add physical plan unit tests
* Fix clippy
* Fix rebase import mistake
* Clean up distinct tests
* Fix clippy
---
datafusion/src/physical_plan/aggregates.rs | 77 +++---
.../{ => expressions}/distinct_expressions.rs | 274 ++++++++++++++++++++-
datafusion/src/physical_plan/expressions/mod.rs | 2 +
datafusion/src/physical_plan/mod.rs | 1 -
datafusion/tests/sql/aggregates.rs | 51 ++++
5 files changed, 361 insertions(+), 44 deletions(-)
diff --git a/datafusion/src/physical_plan/aggregates.rs
b/datafusion/src/physical_plan/aggregates.rs
index 1495e05..f7beb76 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -32,7 +32,6 @@ use super::{
};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs,
coerce_types};
-use crate::physical_plan::distinct_expressions;
use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use expressions::{
@@ -198,14 +197,12 @@ pub fn create_aggregate_expr(
name,
return_type,
)),
- (AggregateFunction::Count, true) => {
- Arc::new(distinct_expressions::DistinctCount::new(
- coerced_exprs_types,
- coerced_phy_exprs,
- name,
- return_type,
- ))
- }
+ (AggregateFunction::Count, true) =>
Arc::new(expressions::DistinctCount::new(
+ coerced_exprs_types,
+ coerced_phy_exprs,
+ name,
+ return_type,
+ )),
(AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
coerced_phy_exprs[0].clone(),
name,
@@ -229,9 +226,11 @@ pub fn create_aggregate_expr(
coerced_exprs_types[0].clone(),
)),
(AggregateFunction::ArrayAgg, true) => {
- return Err(DataFusionError::NotImplemented(
- "ARRAY_AGG(DISTINCT) aggregations are not
available".to_string(),
- ));
+ Arc::new(expressions::DistinctArrayAgg::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ coerced_exprs_types[0].clone(),
+ ))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
coerced_phy_exprs[0].clone(),
@@ -396,12 +395,10 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
#[cfg(test)]
mod tests {
use super::*;
- use crate::error::DataFusionError::NotImplemented;
use crate::error::Result;
- use crate::physical_plan::distinct_expressions::DistinctCount;
use crate::physical_plan::expressions::{
- ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, Max,
Min, Stddev,
- Sum, Variance,
+ ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance,
DistinctArrayAgg,
+ DistinctCount, Max, Min, Stddev, Sum, Variance,
};
#[test]
@@ -475,42 +472,40 @@ mod tests {
&input_phy_exprs[0..1],
&input_schema,
"c1",
- );
+ )?;
match fun {
AggregateFunction::Count => {
- let result_agg_phy_exprs_distinct = result_distinct?;
- assert!(result_agg_phy_exprs_distinct
- .as_any()
- .is::<DistinctCount>());
- assert_eq!("c1", result_agg_phy_exprs_distinct.name());
+
assert!(result_distinct.as_any().is::<DistinctCount>());
+ assert_eq!("c1", result_distinct.name());
assert_eq!(
Field::new("c1", DataType::UInt64, true),
- result_agg_phy_exprs_distinct.field().unwrap()
+ result_distinct.field().unwrap()
);
}
AggregateFunction::ApproxDistinct => {
- let result_agg_phy_exprs_distinct = result_distinct?;
- assert!(result_agg_phy_exprs_distinct
- .as_any()
- .is::<ApproxDistinct>());
- assert_eq!("c1", result_agg_phy_exprs_distinct.name());
+
assert!(result_distinct.as_any().is::<ApproxDistinct>());
+ assert_eq!("c1", result_distinct.name());
assert_eq!(
Field::new("c1", DataType::UInt64, false),
- result_agg_phy_exprs_distinct.field().unwrap()
+ result_distinct.field().unwrap()
+ );
+ }
+ AggregateFunction::ArrayAgg => {
+
assert!(result_distinct.as_any().is::<DistinctArrayAgg>());
+ assert_eq!("c1", result_distinct.name());
+ assert_eq!(
+ Field::new(
+ "c1",
+ DataType::List(Box::new(Field::new(
+ "item",
+ data_type.clone(),
+ true
+ ))),
+ false
+ ),
+ result_agg_phy_exprs.field().unwrap()
);
}
- AggregateFunction::ArrayAgg => match result_distinct {
- Err(NotImplemented(s)) => {
- assert_eq!(
- s,
- "ARRAY_AGG(DISTINCT) aggregations are not
available"
- .to_string()
- );
- }
- _ => {
- unreachable!()
- }
- },
_ => {}
};
}
diff --git a/datafusion/src/physical_plan/distinct_expressions.rs
b/datafusion/src/physical_plan/expressions/distinct_expressions.rs
similarity index 75%
rename from datafusion/src/physical_plan/distinct_expressions.rs
rename to datafusion/src/physical_plan/expressions/distinct_expressions.rs
index 080308a..bdbd82d 100644
--- a/datafusion/src/physical_plan/distinct_expressions.rs
+++ b/datafusion/src/physical_plan/expressions/distinct_expressions.rs
@@ -17,7 +17,6 @@
//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`
-use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::fmt::Debug;
@@ -25,6 +24,7 @@ use std::hash::Hash;
use std::sync::Arc;
use ahash::RandomState;
+use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;
use crate::error::{DataFusionError, Result};
@@ -232,17 +232,148 @@ impl Accumulator for DistinctCountAccumulator {
}
}
+/// Expression for an ARRAY_AGG(DISTINCT) aggregation.
+#[derive(Debug)]
+pub struct DistinctArrayAgg {
+ /// Column name
+ name: String,
+ /// The DataType for the input expression
+ input_data_type: DataType,
+ /// The input expression
+ expr: Arc<dyn PhysicalExpr>,
+}
+
+impl DistinctArrayAgg {
+ /// Create a new DistinctArrayAgg aggregate function
+ pub fn new(
+ expr: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ input_data_type: DataType,
+ ) -> Self {
+ let name = name.into();
+ Self {
+ name,
+ expr,
+ input_data_type,
+ }
+ }
+}
+
+impl AggregateExpr for DistinctArrayAgg {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(
+ &self.name,
+ DataType::List(Box::new(Field::new(
+ "item",
+ self.input_data_type.clone(),
+ true,
+ ))),
+ false,
+ ))
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(DistinctArrayAggAccumulator::try_new(
+ &self.input_data_type,
+ )?))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![Field::new(
+ &format_state_name(&self.name, "distinct_array_agg"),
+ DataType::List(Box::new(Field::new(
+ "item",
+ self.input_data_type.clone(),
+ true,
+ ))),
+ false,
+ )])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+#[derive(Debug)]
+struct DistinctArrayAggAccumulator {
+ values: HashSet<ScalarValue>,
+ datatype: DataType,
+}
+
+impl DistinctArrayAggAccumulator {
+ pub fn try_new(datatype: &DataType) -> Result<Self> {
+ Ok(Self {
+ values: HashSet::new(),
+ datatype: datatype.clone(),
+ })
+ }
+}
+
+impl Accumulator for DistinctArrayAggAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::List(
+ Some(Box::new(self.values.clone().into_iter().collect())),
+ Box::new(self.datatype.clone()),
+ )])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ assert_eq!(values.len(), 1, "batch input should only include 1
column!");
+
+ let arr = &values[0];
+ for i in 0..arr.len() {
+ self.values.insert(ScalarValue::try_from_array(arr, i)?);
+ }
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ };
+
+ for array in states {
+ for j in 0..array.len() {
+ self.values.insert(ScalarValue::try_from_array(array, j)?);
+ }
+ }
+
+ Ok(())
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::List(
+ Some(Box::new(self.values.clone().into_iter().collect())),
+ Box::new(self.datatype.clone()),
+ ))
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
use crate::from_slice::FromSlice;
+ use crate::physical_plan::expressions::col;
+ use crate::physical_plan::expressions::tests::aggregate;
+
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
Int32Array,
Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array,
UInt64Array,
UInt8Array,
};
use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
- use arrow::datatypes::DataType;
+ use arrow::datatypes::{DataType, Schema};
+ use arrow::record_batch::RecordBatch;
macro_rules! build_list {
($LISTS:expr, $BUILDER_TYPE:ident) => {{
@@ -741,4 +872,143 @@ mod tests {
Ok(())
}
+
+ fn check_distinct_array_agg(
+ input: ArrayRef,
+ expected: ScalarValue,
+ datatype: DataType,
+ ) -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", datatype.clone(),
false)]);
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![input])?;
+
+ let agg = Arc::new(DistinctArrayAgg::new(
+ col("a", &schema)?,
+ "bla".to_string(),
+ datatype,
+ ));
+ let actual = aggregate(&batch, agg)?;
+
+ match (expected, actual) {
+ (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a),
_)) => {
+ // work around ScalarValue not implementing Ord
+ let cmp = |a: &ScalarValue, b: &ScalarValue| {
+ a.partial_cmp(b).expect("Can compare ScalarValues")
+ };
+
+ e.sort_by(cmp);
+ a.sort_by(cmp);
+ // Check that the sorted expected and actual values are the same
+ assert_eq!(e, a);
+ }
+ _ => {
+ unreachable!()
+ }
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn distinct_array_agg_i32() -> Result<()> {
+ let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
+
+ let out = ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::Int32(Some(1)),
+ ScalarValue::Int32(Some(2)),
+ ScalarValue::Int32(Some(7)),
+ ScalarValue::Int32(Some(4)),
+ ScalarValue::Int32(Some(5)),
+ ])),
+ Box::new(DataType::Int32),
+ );
+
+ check_distinct_array_agg(col, out, DataType::Int32)
+ }
+
+ #[test]
+ fn distinct_array_agg_nested() -> Result<()> {
+ // [[1, 2, 3], [4, 5]]
+ let l1 = ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::from(1i32),
+ ScalarValue::from(2i32),
+ ScalarValue::from(3i32),
+ ])),
+ Box::new(DataType::Int32),
+ ),
+ ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::from(4i32),
+ ScalarValue::from(5i32),
+ ])),
+ Box::new(DataType::Int32),
+ ),
+ ])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ // [[6], [7, 8]]
+ let l2 = ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::List(
+ Some(Box::new(vec![ScalarValue::from(6i32)])),
+ Box::new(DataType::Int32),
+ ),
+ ScalarValue::List(
+ Some(Box::new(vec![
+ ScalarValue::from(7i32),
+ ScalarValue::from(8i32),
+ ])),
+ Box::new(DataType::Int32),
+ ),
+ ])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ // [[9]]
+ let l3 = ScalarValue::List(
+ Some(Box::new(vec![ScalarValue::List(
+ Some(Box::new(vec![ScalarValue::from(9i32)])),
+ Box::new(DataType::Int32),
+ )])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ let list = ScalarValue::List(
+ Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
+ Box::new(DataType::List(Box::new(Field::new(
+ "item",
+ DataType::Int32,
+ true,
+ )))),
+ );
+
+ // Duplicate l1 in the input array and check that it is deduped in the
output.
+ let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3,
l1]).unwrap();
+
+ check_distinct_array_agg(
+ array,
+ list,
+ DataType::List(Box::new(Field::new(
+ "item",
+ DataType::List(Box::new(Field::new("item", DataType::Int32,
true))),
+ true,
+ ))),
+ )
+ }
}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs
b/datafusion/src/physical_plan/expressions/mod.rs
index 9ed1693..ca14d7f 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -45,6 +45,7 @@ mod literal;
mod min_max;
mod correlation;
mod covariance;
+mod distinct_expressions;
mod negative;
mod not;
mod nth_value;
@@ -80,6 +81,7 @@ pub(crate) use covariance::{
covariance_return_type, is_covariance_support_arg_type, Covariance,
CovariancePop,
};
pub use cume_dist::cume_dist;
+pub use distinct_expressions::{DistinctArrayAgg, DistinctCount};
pub use get_indexed_field::GetIndexedFieldExpr;
pub use in_list::{in_list, InListExpr};
pub use is_not_null::{is_not_null, IsNotNullExpr};
diff --git a/datafusion/src/physical_plan/mod.rs
b/datafusion/src/physical_plan/mod.rs
index be59968..ce12722 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -598,7 +598,6 @@ pub mod cross_join;
pub mod crypto_expressions;
pub mod datetime_expressions;
pub mod display;
-pub mod distinct_expressions;
pub mod empty;
pub mod explain;
pub mod expressions;
diff --git a/datafusion/tests/sql/aggregates.rs
b/datafusion/tests/sql/aggregates.rs
index 785d308..9d72752 100644
--- a/datafusion/tests/sql/aggregates.rs
+++ b/datafusion/tests/sql/aggregates.rs
@@ -16,6 +16,7 @@
// under the License.
use super::*;
+use datafusion::scalar::ScalarValue;
#[tokio::test]
async fn csv_query_avg_multi_batch() -> Result<()> {
@@ -422,3 +423,53 @@ async fn csv_query_array_agg_one() -> Result<()> {
assert_batches_eq!(expected, &actual);
Ok(())
}
+
+#[tokio::test]
+async fn csv_query_array_agg_distinct() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx).await?;
+ let sql = "SELECT array_agg(distinct c2) FROM aggregate_test_100";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+
+ // The results for this query should be something like the following:
+ // +------------------------------------------+
+ // | ARRAYAGG(DISTINCT aggregate_test_100.c2) |
+ // +------------------------------------------+
+ // | [4, 2, 3, 5, 1] |
+ // +------------------------------------------+
+ // Since ARRAY_AGG(DISTINCT) ordering is nondeterministic, check the
schema and contents.
+ assert_eq!(
+ *actual[0].schema(),
+ Schema::new(vec![Field::new(
+ "ARRAYAGG(DISTINCT aggregate_test_100.c2)",
+ DataType::List(Box::new(Field::new("item", DataType::UInt32,
true))),
+ false
+ ),])
+ );
+
+ // We should have 1 row containing a list
+ let column = actual[0].column(0);
+ assert_eq!(column.len(), 1);
+
+ if let ScalarValue::List(Some(mut v), _) =
ScalarValue::try_from_array(column, 0)? {
+ // work around ScalarValue not implementing Ord
+ let cmp = |a: &ScalarValue, b: &ScalarValue| {
+ a.partial_cmp(b).expect("Can compare ScalarValues")
+ };
+ v.sort_by(cmp);
+ assert_eq!(
+ *v,
+ vec![
+ ScalarValue::UInt32(Some(1)),
+ ScalarValue::UInt32(Some(2)),
+ ScalarValue::UInt32(Some(3)),
+ ScalarValue::UInt32(Some(4)),
+ ScalarValue::UInt32(Some(5))
+ ]
+ );
+ } else {
+ unreachable!();
+ }
+
+ Ok(())
+}