This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ea9144e659 fix: inconsistent scalar types in `DistinctArrayAggAccumulator` state (#7385)
ea9144e659 is described below

commit ea9144e6597593c09a4ef0b71a4da1cdfaca8249
Author: Eduard Karacharov <[email protected]>
AuthorDate: Thu Aug 24 21:24:52 2023 +0300

    fix: inconsistent scalar types in `DistinctArrayAggAccumulator` state (#7385)
    
    * fix: inconsistent types in array_agg_distinct merge_batch
    
    * Apply suggestions from code review
    
    Co-authored-by: Metehan Yıldırım <[email protected]>
    
    * filtering NULLs & validating sqllogictest output
    
    ---------
    
    Co-authored-by: Metehan Yıldırım <[email protected]>
---
 .../src/aggregate/array_agg_distinct.rs            | 182 ++++++++++++++++++---
 datafusion/sqllogictest/test_files/aggregate.slt   |  29 +++-
 2 files changed, 178 insertions(+), 33 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 2d7a6e5b0e..422eecd201 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -28,8 +28,7 @@ use std::collections::HashSet;
 use crate::aggregate::utils::down_cast_any_ref;
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
+use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::Accumulator;
 
 /// Expression for a ARRAY_AGG(DISTINCT) aggregation.
@@ -135,11 +134,13 @@ impl Accumulator for DistinctArrayAggAccumulator {
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         assert_eq!(values.len(), 1, "batch input should only include 1 column!");
 
-        let arr = &values[0];
-        for i in 0..arr.len() {
-            self.values.insert(ScalarValue::try_from_array(arr, i)?);
-        }
-        Ok(())
+        let array = &values[0];
+        (0..array.len()).try_for_each(|i| {
+            if !array.is_null(i) {
+                self.values.insert(ScalarValue::try_from_array(array, i)?);
+            }
+            Ok(())
+        })
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
@@ -147,11 +148,22 @@ impl Accumulator for DistinctArrayAggAccumulator {
             return Ok(());
         }
 
-        for array in states {
-            for j in 0..array.len() {
-                self.values.insert(ScalarValue::try_from_array(array, j)?);
+        assert_eq!(
+            states.len(),
+            1,
+            "array_agg_distinct states must contain single array"
+        );
+
+        let array = &states[0];
+        (0..array.len()).try_for_each(|i| {
+            let scalar = ScalarValue::try_from_array(array, i)?;
+            if let ScalarValue::List(Some(values), _) = scalar {
+                self.values.extend(values);
+                Ok(())
+            } else {
+                internal_err!("array_agg_distinct state must be list")
             }
-        }
+        })?;
 
         Ok(())
     }
@@ -174,12 +186,35 @@ impl Accumulator for DistinctArrayAggAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use arrow::array::{ArrayRef, Int32Array};
     use arrow::datatypes::{DataType, Schema};
     use arrow::record_batch::RecordBatch;
 
+    fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> Result<()> {
+        match (expected, actual) {
+            (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => {
+                // workaround lack of Ord of ScalarValue
+                let cmp = |a: &ScalarValue, b: &ScalarValue| {
+                    a.partial_cmp(b).expect("Can compare ScalarValues")
+                };
+
+                e.sort_by(cmp);
+                a.sort_by(cmp);
+                // Check that the inputs are the same
+                assert_eq!(e, a);
+            }
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "Expected scalar lists as inputs".to_string(),
+                ));
+            }
+        }
+        Ok(())
+    }
+
     fn check_distinct_array_agg(
         input: ArrayRef,
         expected: ScalarValue,
@@ -195,24 +230,34 @@ mod tests {
         ));
         let actual = aggregate(&batch, agg)?;
 
-        match (expected, actual) {
-            (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), 
_)) => {
-                // workaround lack of Ord of ScalarValue
-                let cmp = |a: &ScalarValue, b: &ScalarValue| {
-                    a.partial_cmp(b).expect("Can compare ScalarValues")
-                };
+        compare_list_contents(expected, actual)
+    }
 
-                e.sort_by(cmp);
-                a.sort_by(cmp);
-                // Check that the inputs are the same
-                assert_eq!(e, a);
-            }
-            _ => {
-                unreachable!()
-            }
-        }
+    fn check_merge_distinct_array_agg(
+        input1: ArrayRef,
+        input2: ArrayRef,
+        expected: ScalarValue,
+        datatype: DataType,
+    ) -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]);
+        let agg = Arc::new(DistinctArrayAgg::new(
+            col("a", &schema)?,
+            "bla".to_string(),
+            datatype,
+        ));
 
-        Ok(())
+        let mut accum1 = agg.create_accumulator()?;
+        let mut accum2 = agg.create_accumulator()?;
+
+        accum1.update_batch(&[input1])?;
+        accum2.update_batch(&[input2])?;
+
+        let state = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
+        accum1.merge_batch(&state)?;
+
+        let actual = accum1.evaluate()?;
+
+        compare_list_contents(expected, actual)
     }
 
     #[test]
@@ -233,6 +278,27 @@ mod tests {
         check_distinct_array_agg(col, out, DataType::Int32)
     }
 
+    #[test]
+    fn merge_distinct_array_agg_i32() -> Result<()> {
+        let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
+        let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4]));
+
+        let out = ScalarValue::new_list(
+            Some(vec![
+                ScalarValue::Int32(Some(1)),
+                ScalarValue::Int32(Some(2)),
+                ScalarValue::Int32(Some(3)),
+                ScalarValue::Int32(Some(4)),
+                ScalarValue::Int32(Some(5)),
+                ScalarValue::Int32(Some(7)),
+                ScalarValue::Int32(Some(8)),
+            ]),
+            DataType::Int32,
+        );
+
+        check_merge_distinct_array_agg(col1, col2, out, DataType::Int32)
+    }
+
     #[test]
     fn distinct_array_agg_nested() -> Result<()> {
         // [[1, 2, 3], [4, 5]]
@@ -296,4 +362,66 @@ mod tests {
             ))),
         )
     }
+
+    #[test]
+    fn merge_distinct_array_agg_nested() -> Result<()> {
+        // [[1, 2], [3, 4]]
+        let l1 = ScalarValue::new_list(
+            Some(vec![
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(1i32), ScalarValue::from(2i32)]),
+                    DataType::Int32,
+                ),
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(3i32), ScalarValue::from(4i32)]),
+                    DataType::Int32,
+                ),
+            ]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        // [[5]]
+        let l2 = ScalarValue::new_list(
+            Some(vec![ScalarValue::new_list(
+                Some(vec![ScalarValue::from(5i32)]),
+                DataType::Int32,
+            )]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        // [[6, 7], [8]]
+        let l3 = ScalarValue::new_list(
+            Some(vec![
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(6i32), ScalarValue::from(7i32)]),
+                    DataType::Int32,
+                ),
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(8i32)]),
+                    DataType::Int32,
+                ),
+            ]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        let expected = ScalarValue::new_list(
+            Some(vec![l1.clone(), l2.clone(), l3.clone()]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        // Duplicate l1 in the input array and check that it is deduped in the output.
+        let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap();
+        let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap();
+
+        check_merge_distinct_array_agg(
+            input1,
+            input2,
+            expected,
+            DataType::List(Arc::new(Field::new_list(
+                "item",
+                Field::new("item", DataType::Int32, true),
+                true,
+            ))),
+        )
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index e881acf575..c89362a2f2 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1271,14 +1271,31 @@ NULL 4 29 1.260869565217 123 -117 23
 NULL 5 -194 -13.857142857143 118 -101 14
 NULL NULL 781 7.81 125 -117 100
 
-# TODO this querys output is non determinisitic (the order of the elements
-# differs run to run
+# TODO: array_agg_distinct output is non-determinisitic -- rewrite with array_sort(list_sort)
+#       unnest is also not available, so manually unnesting via CROSS JOIN
+# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data
 #
 # csv_query_array_agg_distinct
-# query T
-# SELECT array_agg(distinct c2) FROM aggregate_test_100
-# ----
-# [4, 2, 3, 5, 1]
+query III
+WITH indices AS (
+  SELECT 1 AS idx UNION ALL
+  SELECT 2 AS idx UNION ALL
+  SELECT 3 AS idx UNION ALL
+  SELECT 4 AS idx UNION ALL
+  SELECT 5 AS idx
+)
+SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy
+FROM (
+  SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100
+) data
+  CROSS JOIN indices
+ORDER BY 1
+----
+1 5 100
+2 5 100
+3 5 100
+4 5 100
+5 5 100
 
 # aggregate_time_min_and_max
 query TT

Reply via email to