This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fc84a639fc Support List for Array aggregate order and distinct (#9234)
fc84a639fc is described below
commit fc84a639fca7716e529384c0e919fb90b75139da
Author: Jay Zhan <[email protected]>
AuthorDate: Tue Feb 20 15:41:42 2024 +0800
Support List for Array aggregate order and distinct (#9234)
* first draft
Signed-off-by: jayzhan211 <[email protected]>
* fix convert_first_level_array_to_scalar_vec
Signed-off-by: jayzhan211 <[email protected]>
* add doc
Signed-off-by: jayzhan211 <[email protected]>
* fix nth
Signed-off-by: jayzhan211 <[email protected]>
* support distinct
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* rm convert_first_level_array_to_scalar_vec
Signed-off-by: jayzhan211 <[email protected]>
* add doc and assertion
Signed-off-by: jayzhan211 <[email protected]>
* fix doc
Signed-off-by: jayzhan211 <[email protected]>
* fix doc
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/common/src/scalar/mod.rs | 74 ++++++---
datafusion/core/tests/sql/aggregates.rs | 2 +-
.../src/aggregate/array_agg_distinct.rs | 171 ++++++++++-----------
.../src/aggregate/count_distinct/mod.rs | 10 +-
datafusion/sqllogictest/test_files/aggregate.slt | 57 ++++++-
.../sqllogictest/test_files/aggregates_topk.slt | 3 +
6 files changed, 194 insertions(+), 123 deletions(-)
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index 7e53415090..80dddd8a83 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -52,7 +52,6 @@ use arrow::{
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
DECIMAL128_MAX_PRECISION,
},
};
-use arrow_array::cast::as_list_array;
use arrow_array::{ArrowNativeTypeOp, Scalar};
pub use struct_builder::ScalarStructBuilder;
@@ -2138,28 +2137,67 @@ impl ScalarValue {
/// Retrieve ScalarValue for each row in `array`
///
- /// Example
+ /// Example 1: Array (ScalarValue::Int32)
/// ```
/// use datafusion_common::ScalarValue;
/// use arrow::array::ListArray;
/// use arrow::datatypes::{DataType, Int32Type};
///
+ /// // Equivalent to [[1,2,3], [4,5]]
/// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
/// Some(vec![Some(1), Some(2), Some(3)]),
- /// None,
/// Some(vec![Some(4), Some(5)])
/// ]);
///
+ /// // Convert the array into Scalar Values for each row
/// let scalar_vec =
ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
///
/// let expected = vec![
- /// vec![
+ /// vec![
/// ScalarValue::Int32(Some(1)),
/// ScalarValue::Int32(Some(2)),
/// ScalarValue::Int32(Some(3)),
+ /// ],
+ /// vec![
+ /// ScalarValue::Int32(Some(4)),
+ /// ScalarValue::Int32(Some(5)),
+ /// ],
+ /// ];
+ ///
+ /// assert_eq!(scalar_vec, expected);
+ /// ```
+ ///
+ /// Example 2: Nested array (ScalarValue::List)
+ /// ```
+ /// use datafusion_common::ScalarValue;
+ /// use arrow::array::ListArray;
+ /// use arrow::datatypes::{DataType, Int32Type};
+ /// use datafusion_common::utils::array_into_list_array;
+ /// use std::sync::Arc;
+ ///
+ /// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ /// Some(vec![Some(1), Some(2), Some(3)]),
+ /// Some(vec![Some(4), Some(5)])
+ /// ]);
+ ///
+ /// // Wrap into another layer of list, we got nested array as [ [[1,2,3],
[4,5]] ]
+ /// let list_arr = array_into_list_array(Arc::new(list_arr));
+ ///
+ /// // Convert the array into Scalar Values for each row, we got 1D arrays
in this example
+ /// let scalar_vec =
ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
+ ///
+ /// let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ /// Some(vec![Some(1), Some(2), Some(3)]),
+ /// ]);
+ /// let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ /// Some(vec![Some(4), Some(5)]),
+ /// ]);
+ ///
+ /// let expected = vec![
+ /// vec![
+ /// ScalarValue::List(Arc::new(l1)),
+ /// ScalarValue::List(Arc::new(l2)),
/// ],
- /// vec![],
- /// vec![ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5))]
/// ];
///
/// assert_eq!(scalar_vec, expected);
@@ -2168,27 +2206,13 @@ impl ScalarValue {
let mut scalars = Vec::with_capacity(array.len());
for index in 0..array.len() {
- let scalar_values = match array.data_type() {
- DataType::List(_) => {
- let list_array = as_list_array(array);
- match list_array.is_null(index) {
- true => Vec::new(),
- false => {
- let nested_array = list_array.value(index);
-
ScalarValue::convert_array_to_scalar_vec(&nested_array)?
- .into_iter()
- .flatten()
- .collect()
- }
- }
- }
- _ => {
- let scalar = ScalarValue::try_from_array(array, index)?;
- vec![scalar]
- }
- };
+ let nested_array = array.as_list::<i32>().value(index);
+ let scalar_values = (0..nested_array.len())
+ .map(|i| ScalarValue::try_from_array(&nested_array, i))
+ .collect::<Result<Vec<_>>>()?;
scalars.push(scalar_values);
}
+
Ok(scalars)
}
diff --git a/datafusion/core/tests/sql/aggregates.rs
b/datafusion/core/tests/sql/aggregates.rs
index af6d0d5f4e..84b791a3de 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -44,9 +44,9 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
// We should have 1 row containing a list
let column = actual[0].column(0);
assert_eq!(column.len(), 1);
-
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?;
let mut scalars = scalar_vec[0].clone();
+
// workaround lack of Ord of ScalarValue
let cmp = |a: &ScalarValue, b: &ScalarValue| {
a.partial_cmp(b).expect("Can compare ScalarValues")
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 2e9df477d5..b073b00578 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -24,6 +24,7 @@ use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
+use arrow_array::cast::AsArray;
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
@@ -138,9 +139,10 @@ impl Accumulator for DistinctArrayAggAccumulator {
assert_eq!(values.len(), 1, "batch input should only include 1
column!");
let array = &values[0];
- let scalar_vec = ScalarValue::convert_array_to_scalar_vec(array)?;
- for scalars in scalar_vec {
- self.values.extend(scalars);
+
+ for i in 0..array.len() {
+ let scalar = ScalarValue::try_from_array(&array, i)?;
+ self.values.insert(scalar);
}
Ok(())
@@ -151,7 +153,12 @@ impl Accumulator for DistinctArrayAggAccumulator {
return Ok(());
}
- self.update_batch(states)
+ let array = &states[0];
+
+ assert_eq!(array.len(), 1, "state array should only include 1 row!");
+ // Unwrap outer ListArray then do update batch
+ let inner_array = array.as_list::<i32>().value(0);
+ self.update_batch(&[inner_array])
}
fn evaluate(&mut self) -> Result<ScalarValue> {
@@ -181,47 +188,55 @@ mod tests {
use arrow_array::Array;
use arrow_array::ListArray;
use arrow_buffer::OffsetBuffer;
- use datafusion_common::utils::array_into_list_array;
use datafusion_common::{internal_err, DataFusionError};
- // arrow::compute::sort cann't sort ListArray directly, so we need to sort
the inner primitive array and wrap it back into ListArray.
- fn sort_list_inner(arr: ScalarValue) -> ScalarValue {
- let arr = match arr {
- ScalarValue::List(arr) => arr.value(0),
- _ => {
- panic!("Expected ScalarValue::List, got {:?}", arr)
- }
- };
+ // arrow::compute::sort can't sort nested ListArray directly, so we
compare the scalar values pair-wise.
+ fn compare_list_contents(
+ expected: Vec<ScalarValue>,
+ actual: ScalarValue,
+ ) -> Result<()> {
+ let array = actual.to_array()?;
+ let list_array = array.as_list::<i32>();
+ let inner_array = list_array.value(0);
+ let mut actual_scalars = vec![];
+ for index in 0..inner_array.len() {
+ let sv = ScalarValue::try_from_array(&inner_array, index)?;
+ actual_scalars.push(sv);
+ }
- let arr = arrow::compute::sort(&arr, None).unwrap();
- let list_arr = array_into_list_array(arr);
- ScalarValue::List(Arc::new(list_arr))
- }
+ if actual_scalars.len() != expected.len() {
+ return internal_err!(
+ "Expected and actual list lengths differ: expected={},
actual={}",
+ expected.len(),
+ actual_scalars.len()
+ );
+ }
- fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) ->
Result<()> {
- let actual = sort_list_inner(actual);
-
- match (&expected, &actual) {
- (ScalarValue::List(arr1), ScalarValue::List(arr2)) => {
- if arr1.eq(arr2) {
- Ok(())
- } else {
- internal_err!(
- "Actual value {:?} not found in expected values {:?}",
- actual,
- expected
- )
+ let mut seen = vec![false; expected.len()];
+ for v in expected {
+ let mut found = false;
+ for (i, sv) in actual_scalars.iter().enumerate() {
+ if sv == &v {
+ seen[i] = true;
+ found = true;
+ break;
}
}
- _ => {
- internal_err!("Expected scalar lists as inputs")
+ if !found {
+ return internal_err!(
+ "Expected value {:?} not found in actual values {:?}",
+ v,
+ actual_scalars
+ );
}
}
+
+ Ok(())
}
fn check_distinct_array_agg(
input: ArrayRef,
- expected: ScalarValue,
+ expected: Vec<ScalarValue>,
datatype: DataType,
) -> Result<()> {
let schema = Schema::new(vec![Field::new("a", datatype.clone(),
false)]);
@@ -234,14 +249,13 @@ mod tests {
true,
));
let actual = aggregate(&batch, agg)?;
-
compare_list_contents(expected, actual)
}
fn check_merge_distinct_array_agg(
input1: ArrayRef,
input2: ArrayRef,
- expected: ScalarValue,
+ expected: Vec<ScalarValue>,
datatype: DataType,
) -> Result<()> {
let schema = Schema::new(vec![Field::new("a", datatype.clone(),
false)]);
@@ -262,23 +276,20 @@ mod tests {
accum1.merge_batch(&[array])?;
let actual = accum1.evaluate()?;
-
compare_list_contents(expected, actual)
}
#[test]
fn distinct_array_agg_i32() -> Result<()> {
let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
- let expected =
- ScalarValue::List(Arc::new(
- ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![Some(vec![
- Some(1),
- Some(2),
- Some(4),
- Some(5),
- Some(7),
- ])]),
- ));
+
+ let expected = vec![
+ ScalarValue::Int32(Some(1)),
+ ScalarValue::Int32(Some(2)),
+ ScalarValue::Int32(Some(4)),
+ ScalarValue::Int32(Some(5)),
+ ScalarValue::Int32(Some(7)),
+ ];
check_distinct_array_agg(col, expected, DataType::Int32)
}
@@ -288,18 +299,15 @@ mod tests {
let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5,
2]));
let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4]));
- let expected =
- ScalarValue::List(Arc::new(
- ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![Some(vec![
- Some(1),
- Some(2),
- Some(3),
- Some(4),
- Some(5),
- Some(7),
- Some(8),
- ])]),
- ));
+ let expected = vec![
+ ScalarValue::Int32(Some(1)),
+ ScalarValue::Int32(Some(2)),
+ ScalarValue::Int32(Some(3)),
+ ScalarValue::Int32(Some(4)),
+ ScalarValue::Int32(Some(5)),
+ ScalarValue::Int32(Some(7)),
+ ScalarValue::Int32(Some(8)),
+ ];
check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32)
}
@@ -351,23 +359,16 @@ mod tests {
let l2 = ScalarValue::List(Arc::new(l2));
let l3 = ScalarValue::List(Arc::new(l3));
- // Duplicate l1 in the input array and check that it is deduped in the
output.
- let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3,
l1]).unwrap();
-
- let expected =
- ScalarValue::List(Arc::new(
- ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![Some(vec![
- Some(1),
- Some(2),
- Some(3),
- Some(4),
- Some(5),
- Some(6),
- Some(7),
- Some(8),
- Some(9),
- ])]),
- ));
+ // Duplicate l1 and l3 in the input array and check that they are deduped
in the output.
+ let array = ScalarValue::iter_to_array(vec![
+ l1.clone(),
+ l2.clone(),
+ l3.clone(),
+ l3.clone(),
+ l1.clone(),
+ ])
+ .unwrap();
+ let expected = vec![l1, l2, l3];
check_distinct_array_agg(
array,
@@ -426,22 +427,10 @@ mod tests {
let l3 = ScalarValue::List(Arc::new(l3));
// Duplicate l1 in the input array and check that it is deduped in the
output.
- let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap();
- let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap();
-
- let expected =
- ScalarValue::List(Arc::new(
- ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![Some(vec![
- Some(1),
- Some(2),
- Some(3),
- Some(4),
- Some(5),
- Some(6),
- Some(7),
- Some(8),
- ])]),
- ));
+ let input1 = ScalarValue::iter_to_array(vec![l1.clone(),
l2.clone()]).unwrap();
+ let input2 = ScalarValue::iter_to_array(vec![l1.clone(),
l3.clone()]).unwrap();
+
+ let expected = vec![l1, l2, l3];
check_merge_distinct_array_agg(input1, input2, expected,
DataType::Int32)
}
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index 52afd82d03..71782fcc5f 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -26,6 +26,7 @@ use std::sync::Arc;
use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field, TimeUnit};
+use arrow_array::cast::AsArray;
use arrow_array::types::{
Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type,
Float32Type,
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType,
@@ -250,11 +251,10 @@ impl Accumulator for DistinctCountAccumulator {
return Ok(());
}
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
- let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
- for scalars in scalar_vec.into_iter() {
- self.values.extend(scalars);
- }
- Ok(())
+ let array = &states[0];
+ let list_array = array.as_list::<i32>();
+ let inner_array = list_array.value(0);
+ self.update_batch(&[inner_array])
}
fn evaluate(&mut self) -> Result<ScalarValue> {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index fdd70a80ac..109c64f060 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -139,6 +139,61 @@ AggregateExec: mode=Final, gby=[],
aggr=[ARRAY_AGG(agg_order.c1)]
--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
----------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]},
projection=[c1, c2, c3], has_header=true
+# test array_agg_order with list data type
+statement ok
+CREATE TABLE array_agg_order_list_table AS VALUES
+ ('w', 2, [1,2,3], 10),
+ ('w', 1, [9,5,2], 20),
+ ('w', 1, [3,2,5], 30),
+ ('b', 2, [4,5,6], 20),
+ ('b', 1, [7,8,9], 30)
+;
+
+query T? rowsort
+select column1, array_agg(column3 order by column2, column4 desc) from
array_agg_order_list_table group by column1;
+----
+b [[7, 8, 9], [4, 5, 6]]
+w [[3, 2, 5], [9, 5, 2], [1, 2, 3]]
+
+query T?? rowsort
+select column1, first_value(column3 order by column2, column4 desc),
last_value(column3 order by column2, column4 desc) from
array_agg_order_list_table group by column1;
+----
+b [7, 8, 9] [4, 5, 6]
+w [3, 2, 5] [1, 2, 3]
+
+query T? rowsort
+select column1, nth_value(column3, 2 order by column2, column4 desc) from
array_agg_order_list_table group by column1;
+----
+b [4, 5, 6]
+w [9, 5, 2]
+
+statement ok
+drop table array_agg_order_list_table;
+
+# test array_agg_distinct with list data type
+statement ok
+CREATE TABLE array_agg_distinct_list_table AS VALUES
+ ('w', [0,1]),
+ ('w', [0,1]),
+ ('w', [1,0]),
+ ('b', [1,0]),
+ ('b', [1,0]),
+ ('b', [1,0]),
+ ('b', [0,1])
+;
+
+# Apply array_sort to have deterministic result, higher dimension nested
array also works but not for array sort,
+# so they are covered in
`datafusion/physical-expr/src/aggregate/array_agg_distinct.rs`
+query ??
+select array_sort(c1), array_sort(c2) from (
+ select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2
from array_agg_distinct_list_table
+);
+----
+[b, w] [[0, 1], [1, 0]]
+
+statement ok
+drop table array_agg_distinct_list_table;
+
statement error This feature is not implemented: LIMIT not supported in
ARRAY_AGG: 1
SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100
@@ -3259,4 +3314,4 @@ SELECT 0 AS "t.a" FROM t HAVING MAX(t.a) = 0;
----
statement ok
-DROP TABLE t;
\ No newline at end of file
+DROP TABLE t;
diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt
b/datafusion/sqllogictest/test_files/aggregates_topk.slt
index bd8f00e041..3f139ede8c 100644
--- a/datafusion/sqllogictest/test_files/aggregates_topk.slt
+++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt
@@ -212,3 +212,6 @@ b 0 -2
a -1 -1
NULL 0 0
c 1 2
+
+statement ok
+drop table traces;