This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a96bb5e  Implement ARRAY_AGG(DISTINCT ...) (#1579)
a96bb5e is described below

commit a96bb5eea8aa0dddc2bdf30de737b76ccb132054
Author: James Katz <[email protected]>
AuthorDate: Thu Jan 20 14:53:44 2022 -0500

    Implement ARRAY_AGG(DISTINCT ...) (#1579)
    
    * Implement ARRAY_AGG(DISTINCT)
    
    * Add integration test
    
    * Move distinct_expressions into physical_plan/expressions
    
    * Add physical plan unit tests
    
    * Fix clippy
    
    * Fix rebase import mistake
    
    * Clean up distinct tests
    
    * Fix clippy
---
 datafusion/src/physical_plan/aggregates.rs         |  77 +++---
 .../{ => expressions}/distinct_expressions.rs      | 274 ++++++++++++++++++++-
 datafusion/src/physical_plan/expressions/mod.rs    |   2 +
 datafusion/src/physical_plan/mod.rs                |   1 -
 datafusion/tests/sql/aggregates.rs                 |  51 ++++
 5 files changed, 361 insertions(+), 44 deletions(-)

diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index 1495e05..f7beb76 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -32,7 +32,6 @@ use super::{
 };
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types};
-use crate::physical_plan::distinct_expressions;
 use crate::physical_plan::expressions;
 use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use expressions::{
@@ -198,14 +197,12 @@ pub fn create_aggregate_expr(
             name,
             return_type,
         )),
-        (AggregateFunction::Count, true) => {
-            Arc::new(distinct_expressions::DistinctCount::new(
-                coerced_exprs_types,
-                coerced_phy_exprs,
-                name,
-                return_type,
-            ))
-        }
+        (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new(
+            coerced_exprs_types,
+            coerced_phy_exprs,
+            name,
+            return_type,
+        )),
         (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
             coerced_phy_exprs[0].clone(),
             name,
@@ -229,9 +226,11 @@ pub fn create_aggregate_expr(
             coerced_exprs_types[0].clone(),
         )),
         (AggregateFunction::ArrayAgg, true) => {
-            return Err(DataFusionError::NotImplemented(
-                "ARRAY_AGG(DISTINCT) aggregations are not available".to_string(),
-            ));
+            Arc::new(expressions::DistinctArrayAgg::new(
+                coerced_phy_exprs[0].clone(),
+                name,
+                coerced_exprs_types[0].clone(),
+            ))
         }
         (AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
             coerced_phy_exprs[0].clone(),
@@ -396,12 +395,10 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::error::DataFusionError::NotImplemented;
     use crate::error::Result;
-    use crate::physical_plan::distinct_expressions::DistinctCount;
     use crate::physical_plan::expressions::{
-        ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, Max, Min, Stddev,
-        Sum, Variance,
+        ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg,
+        DistinctCount, Max, Min, Stddev, Sum, Variance,
     };
 
     #[test]
@@ -475,42 +472,40 @@ mod tests {
                     &input_phy_exprs[0..1],
                     &input_schema,
                     "c1",
-                );
+                )?;
                 match fun {
                     AggregateFunction::Count => {
-                        let result_agg_phy_exprs_distinct = result_distinct?;
-                        assert!(result_agg_phy_exprs_distinct
-                            .as_any()
-                            .is::<DistinctCount>());
-                        assert_eq!("c1", result_agg_phy_exprs_distinct.name());
+                        assert!(result_distinct.as_any().is::<DistinctCount>());
+                        assert_eq!("c1", result_distinct.name());
                         assert_eq!(
                             Field::new("c1", DataType::UInt64, true),
-                            result_agg_phy_exprs_distinct.field().unwrap()
+                            result_distinct.field().unwrap()
                         );
                     }
                     AggregateFunction::ApproxDistinct => {
-                        let result_agg_phy_exprs_distinct = result_distinct?;
-                        assert!(result_agg_phy_exprs_distinct
-                            .as_any()
-                            .is::<ApproxDistinct>());
-                        assert_eq!("c1", result_agg_phy_exprs_distinct.name());
+                        assert!(result_distinct.as_any().is::<ApproxDistinct>());
+                        assert_eq!("c1", result_distinct.name());
                         assert_eq!(
                             Field::new("c1", DataType::UInt64, false),
-                            result_agg_phy_exprs_distinct.field().unwrap()
+                            result_distinct.field().unwrap()
+                        );
+                    }
+                    AggregateFunction::ArrayAgg => {
+                        assert!(result_distinct.as_any().is::<DistinctArrayAgg>());
+                        assert_eq!("c1", result_distinct.name());
+                        assert_eq!(
+                            Field::new(
+                                "c1",
+                                DataType::List(Box::new(Field::new(
+                                    "item",
+                                    data_type.clone(),
+                                    true
+                                ))),
+                                false
+                            ),
+                            result_agg_phy_exprs.field().unwrap()
                         );
                     }
-                    AggregateFunction::ArrayAgg => match result_distinct {
-                        Err(NotImplemented(s)) => {
-                            assert_eq!(
-                                s,
-                                "ARRAY_AGG(DISTINCT) aggregations are not available"
-                                    .to_string()
-                            );
-                        }
-                        _ => {
-                            unreachable!()
-                        }
-                    },
                     _ => {}
                 };
             }
diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/expressions/distinct_expressions.rs
similarity index 75%
rename from datafusion/src/physical_plan/distinct_expressions.rs
rename to datafusion/src/physical_plan/expressions/distinct_expressions.rs
index 080308a..bdbd82d 100644
--- a/datafusion/src/physical_plan/distinct_expressions.rs
+++ b/datafusion/src/physical_plan/expressions/distinct_expressions.rs
@@ -17,7 +17,6 @@
 
 //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`
 
-use arrow::array::ArrayRef;
 use arrow::datatypes::{DataType, Field};
 use std::any::Any;
 use std::fmt::Debug;
@@ -25,6 +24,7 @@ use std::hash::Hash;
 use std::sync::Arc;
 
 use ahash::RandomState;
+use arrow::array::{Array, ArrayRef};
 use std::collections::HashSet;
 
 use crate::error::{DataFusionError, Result};
@@ -232,17 +232,148 @@ impl Accumulator for DistinctCountAccumulator {
     }
 }
 
+/// Expression for a ARRAY_AGG(DISTINCT) aggregation.
+#[derive(Debug)]
+pub struct DistinctArrayAgg {
+    /// Column name
+    name: String,
+    /// The DataType for the input expression
+    input_data_type: DataType,
+    /// The input expression
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl DistinctArrayAgg {
+    /// Create a new DistinctArrayAgg aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        input_data_type: DataType,
+    ) -> Self {
+        let name = name.into();
+        Self {
+            name,
+            expr,
+            input_data_type,
+        }
+    }
+}
+
+impl AggregateExpr for DistinctArrayAgg {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(
+            &self.name,
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        ))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(DistinctArrayAggAccumulator::try_new(
+            &self.input_data_type,
+        )?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            &format_state_name(&self.name, "distinct_array_agg"),
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.input_data_type.clone(),
+                true,
+            ))),
+            false,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[derive(Debug)]
+struct DistinctArrayAggAccumulator {
+    values: HashSet<ScalarValue>,
+    datatype: DataType,
+}
+
+impl DistinctArrayAggAccumulator {
+    pub fn try_new(datatype: &DataType) -> Result<Self> {
+        Ok(Self {
+            values: HashSet::new(),
+            datatype: datatype.clone(),
+        })
+    }
+}
+
+impl Accumulator for DistinctArrayAggAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::List(
+            Some(Box::new(self.values.clone().into_iter().collect())),
+            Box::new(self.datatype.clone()),
+        )])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        assert_eq!(values.len(), 1, "batch input should only include 1 column!");
+
+        let arr = &values[0];
+        for i in 0..arr.len() {
+            self.values.insert(ScalarValue::try_from_array(arr, i)?);
+        }
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        };
+
+        for array in states {
+            for j in 0..array.len() {
+                self.values.insert(ScalarValue::try_from_array(array, j)?);
+            }
+        }
+
+        Ok(())
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::List(
+            Some(Box::new(self.values.clone().into_iter().collect())),
+            Box::new(self.datatype.clone()),
+        ))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::from_slice::FromSlice;
+    use crate::physical_plan::expressions::col;
+    use crate::physical_plan::expressions::tests::aggregate;
+
     use arrow::array::{
         ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
         Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array,
         UInt8Array,
     };
     use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
-    use arrow::datatypes::DataType;
+    use arrow::datatypes::{DataType, Schema};
+    use arrow::record_batch::RecordBatch;
 
     macro_rules! build_list {
         ($LISTS:expr, $BUILDER_TYPE:ident) => {{
@@ -741,4 +872,143 @@ mod tests {
 
         Ok(())
     }
+
+    fn check_distinct_array_agg(
+        input: ArrayRef,
+        expected: ScalarValue,
+        datatype: DataType,
+    ) -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]);
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?;
+
+        let agg = Arc::new(DistinctArrayAgg::new(
+            col("a", &schema)?,
+            "bla".to_string(),
+            datatype,
+        ));
+        let actual = aggregate(&batch, agg)?;
+
+        match (expected, actual) {
+            (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => {
+                // workaround lack of Ord of ScalarValue
+                let cmp = |a: &ScalarValue, b: &ScalarValue| {
+                    a.partial_cmp(b).expect("Can compare ScalarValues")
+                };
+
+                e.sort_by(cmp);
+                a.sort_by(cmp);
+                // Check that the inputs are the same
+                assert_eq!(e, a);
+            }
+            _ => {
+                unreachable!()
+            }
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn distinct_array_agg_i32() -> Result<()> {
+        let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
+
+        let out = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::Int32(Some(1)),
+                ScalarValue::Int32(Some(2)),
+                ScalarValue::Int32(Some(7)),
+                ScalarValue::Int32(Some(4)),
+                ScalarValue::Int32(Some(5)),
+            ])),
+            Box::new(DataType::Int32),
+        );
+
+        check_distinct_array_agg(col, out, DataType::Int32)
+    }
+
+    #[test]
+    fn distinct_array_agg_nested() -> Result<()> {
+        // [[1, 2, 3], [4, 5]]
+        let l1 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(1i32),
+                        ScalarValue::from(2i32),
+                        ScalarValue::from(3i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(4i32),
+                        ScalarValue::from(5i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        // [[6], [7, 8]]
+        let l2 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![ScalarValue::from(6i32)])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(7i32),
+                        ScalarValue::from(8i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        // [[9]]
+        let l3 = ScalarValue::List(
+            Some(Box::new(vec![ScalarValue::List(
+                Some(Box::new(vec![ScalarValue::from(9i32)])),
+                Box::new(DataType::Int32),
+            )])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let list = ScalarValue::List(
+            Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        // Duplicate l1 in the input array and check that it is deduped in the output.
+        let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap();
+
+        check_distinct_array_agg(
+            array,
+            list,
+            DataType::List(Box::new(Field::new(
+                "item",
+                DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
+                true,
+            ))),
+        )
+    }
 }
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index 9ed1693..ca14d7f 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -45,6 +45,7 @@ mod literal;
 mod min_max;
 mod correlation;
 mod covariance;
+mod distinct_expressions;
 mod negative;
 mod not;
 mod nth_value;
@@ -80,6 +81,7 @@ pub(crate) use covariance::{
     covariance_return_type, is_covariance_support_arg_type, Covariance, CovariancePop,
 };
 pub use cume_dist::cume_dist;
+pub use distinct_expressions::{DistinctArrayAgg, DistinctCount};
 pub use get_indexed_field::GetIndexedFieldExpr;
 pub use in_list::{in_list, InListExpr};
 pub use is_not_null::{is_not_null, IsNotNullExpr};
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index be59968..ce12722 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -598,7 +598,6 @@ pub mod cross_join;
 pub mod crypto_expressions;
 pub mod datetime_expressions;
 pub mod display;
-pub mod distinct_expressions;
 pub mod empty;
 pub mod explain;
 pub mod expressions;
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs
index 785d308..9d72752 100644
--- a/datafusion/tests/sql/aggregates.rs
+++ b/datafusion/tests/sql/aggregates.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use super::*;
+use datafusion::scalar::ScalarValue;
 
 #[tokio::test]
 async fn csv_query_avg_multi_batch() -> Result<()> {
@@ -422,3 +423,53 @@ async fn csv_query_array_agg_one() -> Result<()> {
     assert_batches_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn csv_query_array_agg_distinct() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql = "SELECT array_agg(distinct c2) FROM aggregate_test_100";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    // The results for this query should be something like the following:
+    //    +------------------------------------------+
+    //    | ARRAYAGG(DISTINCT aggregate_test_100.c2) |
+    //    +------------------------------------------+
+    //    | [4, 2, 3, 5, 1]                          |
+    //    +------------------------------------------+
+    // Since ARRAY_AGG(DISTINCT) ordering is nondeterministic, check the schema and contents.
+    assert_eq!(
+        *actual[0].schema(),
+        Schema::new(vec![Field::new(
+            "ARRAYAGG(DISTINCT aggregate_test_100.c2)",
+            DataType::List(Box::new(Field::new("item", DataType::UInt32, true))),
+            false
+        ),])
+    );
+
+    // We should have 1 row containing a list
+    let column = actual[0].column(0);
+    assert_eq!(column.len(), 1);
+
+    if let ScalarValue::List(Some(mut v), _) = ScalarValue::try_from_array(column, 0)? {
+        // workaround lack of Ord of ScalarValue
+        let cmp = |a: &ScalarValue, b: &ScalarValue| {
+            a.partial_cmp(b).expect("Can compare ScalarValues")
+        };
+        v.sort_by(cmp);
+        assert_eq!(
+            *v,
+            vec![
+                ScalarValue::UInt32(Some(1)),
+                ScalarValue::UInt32(Some(2)),
+                ScalarValue::UInt32(Some(3)),
+                ScalarValue::UInt32(Some(4)),
+                ScalarValue::UInt32(Some(5))
+            ]
+        );
+    } else {
+        unreachable!();
+    }
+
+    Ok(())
+}

Reply via email to