This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ea9144e659 fix: inconsistent scalar types in
`DistinctArrayAggAccumulator` state (#7385)
ea9144e659 is described below
commit ea9144e6597593c09a4ef0b71a4da1cdfaca8249
Author: Eduard Karacharov <[email protected]>
AuthorDate: Thu Aug 24 21:24:52 2023 +0300
fix: inconsistent scalar types in `DistinctArrayAggAccumulator` state
(#7385)
* fix: inconsistent types in array_agg_distinct merge_batch
* Apply suggestions from code review
Co-authored-by: Metehan Yıldırım
<[email protected]>
* filtering NULLs & validating sqllogictest output
---------
Co-authored-by: Metehan Yıldırım
<[email protected]>
---
.../src/aggregate/array_agg_distinct.rs | 182 ++++++++++++++++++---
datafusion/sqllogictest/test_files/aggregate.slt | 29 +++-
2 files changed, 178 insertions(+), 33 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 2d7a6e5b0e..422eecd201 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -28,8 +28,7 @@ use std::collections::HashSet;
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
+use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::Accumulator;
/// Expression for a ARRAY_AGG(DISTINCT) aggregation.
@@ -135,11 +134,13 @@ impl Accumulator for DistinctArrayAggAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
assert_eq!(values.len(), 1, "batch input should only include 1
column!");
- let arr = &values[0];
- for i in 0..arr.len() {
- self.values.insert(ScalarValue::try_from_array(arr, i)?);
- }
- Ok(())
+ let array = &values[0];
+ (0..array.len()).try_for_each(|i| {
+ if !array.is_null(i) {
+ self.values.insert(ScalarValue::try_from_array(array, i)?);
+ }
+ Ok(())
+ })
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
@@ -147,11 +148,22 @@ impl Accumulator for DistinctArrayAggAccumulator {
return Ok(());
}
- for array in states {
- for j in 0..array.len() {
- self.values.insert(ScalarValue::try_from_array(array, j)?);
+ assert_eq!(
+ states.len(),
+ 1,
+ "array_agg_distinct states must contain single array"
+ );
+
+ let array = &states[0];
+ (0..array.len()).try_for_each(|i| {
+ let scalar = ScalarValue::try_from_array(array, i)?;
+ if let ScalarValue::List(Some(values), _) = scalar {
+ self.values.extend(values);
+ Ok(())
+ } else {
+ internal_err!("array_agg_distinct state must be list")
}
- }
+ })?;
Ok(())
}
@@ -174,12 +186,35 @@ impl Accumulator for DistinctArrayAggAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
+ fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) ->
Result<()> {
+ match (expected, actual) {
+ (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a),
_)) => {
+ // workaround lack of Ord of ScalarValue
+ let cmp = |a: &ScalarValue, b: &ScalarValue| {
+ a.partial_cmp(b).expect("Can compare ScalarValues")
+ };
+
+ e.sort_by(cmp);
+ a.sort_by(cmp);
+ // Check that the inputs are the same
+ assert_eq!(e, a);
+ }
+ _ => {
+ return Err(DataFusionError::Internal(
+ "Expected scalar lists as inputs".to_string(),
+ ));
+ }
+ }
+ Ok(())
+ }
+
fn check_distinct_array_agg(
input: ArrayRef,
expected: ScalarValue,
@@ -195,24 +230,34 @@ mod tests {
));
let actual = aggregate(&batch, agg)?;
- match (expected, actual) {
- (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a),
_)) => {
- // workaround lack of Ord of ScalarValue
- let cmp = |a: &ScalarValue, b: &ScalarValue| {
- a.partial_cmp(b).expect("Can compare ScalarValues")
- };
+ compare_list_contents(expected, actual)
+ }
- e.sort_by(cmp);
- a.sort_by(cmp);
- // Check that the inputs are the same
- assert_eq!(e, a);
- }
- _ => {
- unreachable!()
- }
- }
+ fn check_merge_distinct_array_agg(
+ input1: ArrayRef,
+ input2: ArrayRef,
+ expected: ScalarValue,
+ datatype: DataType,
+ ) -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", datatype.clone(),
false)]);
+ let agg = Arc::new(DistinctArrayAgg::new(
+ col("a", &schema)?,
+ "bla".to_string(),
+ datatype,
+ ));
- Ok(())
+ let mut accum1 = agg.create_accumulator()?;
+ let mut accum2 = agg.create_accumulator()?;
+
+ accum1.update_batch(&[input1])?;
+ accum2.update_batch(&[input2])?;
+
+ let state = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
+ accum1.merge_batch(&state)?;
+
+ let actual = accum1.evaluate()?;
+
+ compare_list_contents(expected, actual)
}
#[test]
@@ -233,6 +278,27 @@ mod tests {
check_distinct_array_agg(col, out, DataType::Int32)
}
+ #[test]
+ fn merge_distinct_array_agg_i32() -> Result<()> {
+ let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5,
2]));
+ let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4]));
+
+ let out = ScalarValue::new_list(
+ Some(vec![
+ ScalarValue::Int32(Some(1)),
+ ScalarValue::Int32(Some(2)),
+ ScalarValue::Int32(Some(3)),
+ ScalarValue::Int32(Some(4)),
+ ScalarValue::Int32(Some(5)),
+ ScalarValue::Int32(Some(7)),
+ ScalarValue::Int32(Some(8)),
+ ]),
+ DataType::Int32,
+ );
+
+ check_merge_distinct_array_agg(col1, col2, out, DataType::Int32)
+ }
+
#[test]
fn distinct_array_agg_nested() -> Result<()> {
// [[1, 2, 3], [4, 5]]
@@ -296,4 +362,66 @@ mod tests {
))),
)
}
+
+ #[test]
+ fn merge_distinct_array_agg_nested() -> Result<()> {
+ // [[1, 2], [3, 4]]
+ let l1 = ScalarValue::new_list(
+ Some(vec![
+ ScalarValue::new_list(
+ Some(vec![ScalarValue::from(1i32),
ScalarValue::from(2i32)]),
+ DataType::Int32,
+ ),
+ ScalarValue::new_list(
+ Some(vec![ScalarValue::from(3i32),
ScalarValue::from(4i32)]),
+ DataType::Int32,
+ ),
+ ]),
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ );
+
+ // [[5]]
+ let l2 = ScalarValue::new_list(
+ Some(vec![ScalarValue::new_list(
+ Some(vec![ScalarValue::from(5i32)]),
+ DataType::Int32,
+ )]),
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ );
+
+ // [[6, 7], [8]]
+ let l3 = ScalarValue::new_list(
+ Some(vec![
+ ScalarValue::new_list(
+ Some(vec![ScalarValue::from(6i32),
ScalarValue::from(7i32)]),
+ DataType::Int32,
+ ),
+ ScalarValue::new_list(
+ Some(vec![ScalarValue::from(8i32)]),
+ DataType::Int32,
+ ),
+ ]),
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ );
+
+ let expected = ScalarValue::new_list(
+ Some(vec![l1.clone(), l2.clone(), l3.clone()]),
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ );
+
+ // Duplicate l1 in the input array and check that it is deduped in the
output.
+ let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap();
+ let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap();
+
+ check_merge_distinct_array_agg(
+ input1,
+ input2,
+ expected,
+ DataType::List(Arc::new(Field::new_list(
+ "item",
+ Field::new("item", DataType::Int32, true),
+ true,
+ ))),
+ )
+ }
}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index e881acf575..c89362a2f2 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1271,14 +1271,31 @@ NULL 4 29 1.260869565217 123 -117 23
NULL 5 -194 -13.857142857143 118 -101 14
NULL NULL 781 7.81 125 -117 100
-# TODO this querys output is non determinisitic (the order of the elements
-# differs run to run
+# TODO: array_agg_distinct output is non-deterministic -- rewrite with
array_sort(list_sort)
+# unnest is also not available, so manually unnesting via CROSS JOIN
+# additional count(1) forces array_agg_distinct instead of array_agg over
data aggregated by c2
#
# csv_query_array_agg_distinct
-# query T
-# SELECT array_agg(distinct c2) FROM aggregate_test_100
-# ----
-# [4, 2, 3, 5, 1]
+query III
+WITH indices AS (
+ SELECT 1 AS idx UNION ALL
+ SELECT 2 AS idx UNION ALL
+ SELECT 3 AS idx UNION ALL
+ SELECT 4 AS idx UNION ALL
+ SELECT 5 AS idx
+)
+SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len,
dummy
+FROM (
+ SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM
aggregate_test_100
+) data
+ CROSS JOIN indices
+ORDER BY 1
+----
+1 5 100
+2 5 100
+3 5 100
+4 5 100
+5 5 100
# aggregate_time_min_and_max
query TT