This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 195f8256ce feat: support `LargeList` in `flatten` (#9110)
195f8256ce is described below
commit 195f8256ce5414844c297601f145757f87532edb
Author: Alex Huang <[email protected]>
AuthorDate: Mon Feb 5 20:25:03 2024 +0800
feat: support `LargeList` in `flatten` (#9110)
* support FixedSizeList in flatten
* Refactor flatten function and add test cases
* remove redundant tests
---
datafusion/expr/src/built_in_function.rs | 9 ++--
datafusion/physical-expr/src/array_expressions.rs | 52 ++++++++++++++++-------
datafusion/sqllogictest/test_files/array.slt | 42 ++++++++++++++++--
3 files changed, 78 insertions(+), 25 deletions(-)
diff --git a/datafusion/expr/src/built_in_function.rs
b/datafusion/expr/src/built_in_function.rs
index b1b74c1628..4cdf0c4a11 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -547,11 +547,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Flatten => {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
- DataType::List(field) => match field.data_type() {
- DataType::List(_) =>
get_base_type(field.data_type()),
- _ => Ok(data_type.to_owned()),
- },
- _ => internal_err!("Not reachable, data_type should be
List"),
+ DataType::List(field) if matches!(field.data_type(),
DataType::List(_)) => get_base_type(field.data_type()),
+ DataType::LargeList(field) if
matches!(field.data_type(), DataType::LargeList(_)) =>
get_base_type(field.data_type()),
+ DataType::Null | DataType::List(_) |
DataType::LargeList(_) => Ok(data_type.to_owned()),
+ _ => internal_err!("Not reachable, data_type should be
List or LargeList"),
}
}
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index 844dae0917..0709e66a35 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -2246,38 +2246,41 @@ fn generic_list_cardinality<O: OffsetSizeTrait>(
}
// Create new offsets that are equivalent to `flatten` the array.
-fn get_offsets_for_flatten(
- offsets: OffsetBuffer<i32>,
- indexes: OffsetBuffer<i32>,
-) -> OffsetBuffer<i32> {
+fn get_offsets_for_flatten<O: OffsetSizeTrait>(
+ offsets: OffsetBuffer<O>,
+ indexes: OffsetBuffer<O>,
+) -> OffsetBuffer<O> {
let buffer = offsets.into_inner();
- let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as
usize]).collect();
+ let offsets: Vec<O> = indexes
+ .iter()
+ .map(|i| buffer[i.to_usize().unwrap()])
+ .collect();
OffsetBuffer::new(offsets.into())
}
-fn flatten_internal(
- array: &dyn Array,
- indexes: Option<OffsetBuffer<i32>>,
-) -> Result<ListArray> {
- let list_arr = as_list_array(array)?;
+fn flatten_internal<O: OffsetSizeTrait>(
+ list_arr: GenericListArray<O>,
+ indexes: Option<OffsetBuffer<O>>,
+) -> Result<GenericListArray<O>> {
let (field, offsets, values, _) = list_arr.clone().into_parts();
let data_type = field.data_type();
match data_type {
// Recursively get the base offsets for flattened array
- DataType::List(_) => {
+ DataType::List(_) | DataType::LargeList(_) => {
+ let sub_list = as_generic_list_array::<O>(&values)?;
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
- flatten_internal(&values, Some(offsets))
+ flatten_internal::<O>(sub_list.clone(), Some(offsets))
} else {
- flatten_internal(&values, Some(offsets))
+ flatten_internal::<O>(sub_list.clone(), Some(offsets))
}
}
// Reach the base level, create a new list array
_ => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
- let list_arr = ListArray::new(field, offsets, values, None);
+ let list_arr = GenericListArray::<O>::new(field, offsets,
values, None);
Ok(list_arr)
} else {
Ok(list_arr.clone())
@@ -2292,8 +2295,25 @@ pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
return exec_err!("flatten expects one argument");
}
- let flattened_array = flatten_internal(&args[0], None)?;
- Ok(Arc::new(flattened_array) as ArrayRef)
+ let array_type = args[0].data_type();
+ match array_type {
+ DataType::List(_) => {
+ let list_arr = as_list_array(&args[0])?;
+ let flattened_array = flatten_internal::<i32>(list_arr.clone(),
None)?;
+ Ok(Arc::new(flattened_array) as ArrayRef)
+ }
+ DataType::LargeList(_) => {
+ let list_arr = as_large_list_array(&args[0])?;
+ let flattened_array = flatten_internal::<i64>(list_arr.clone(),
None)?;
+ Ok(Arc::new(flattened_array) as ArrayRef)
+ }
+ DataType::Null => Ok(args[0].clone()),
+ _ => {
+ exec_err!("flatten does not support type '{array_type:?}'")
+ }
+ }
+
+ // Ok(Arc::new(flattened_array) as ArrayRef)
}
/// Dispatch array length computation based on the offset type.
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 4fdc428d7a..36a656eb7f 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -202,6 +202,17 @@ AS VALUES
(make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]),
make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
;
+statement ok
+CREATE TABLE large_flatten_table
+AS
+ SELECT
+ arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1,
+ arrow_cast(column2, 'LargeList(LargeList(LargeList(Int64)))') AS column2,
+ arrow_cast(column3, 'LargeList(LargeList(LargeList(LargeList(Int64))))')
AS column3,
+ arrow_cast(column4, 'LargeList(LargeList(Float64))') AS column4
+ FROM flatten_table
+;
+
statement ok
CREATE TABLE array_has_table_1D
AS VALUES
@@ -5345,6 +5356,13 @@ select array_concat(column1, [7]) from arrays_values_v2;
[7]
# flatten
+# follow DuckDB
+query ?
+select flatten(NULL);
+----
+NULL
+
+# flatten with scalar values #1
query ???
select flatten(make_array(1, 2, 1, 3, 2)),
flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
@@ -5352,12 +5370,14 @@ select flatten(make_array(1, 2, 1, 3, 2)),
----
[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]
-query ????
-select column1, column2, column3, column4 from flatten_table;
+query ???
+select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')),
+ flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null,
5)), 'LargeList(LargeList(Int64))')),
+ flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]),
'LargeList(LargeList(LargeList(Float64)))'));
----
-[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0],
[2.1, 2.2], [3.2, 3.3, 3.4]]
-[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0,
4.0], [5.0, 6.0]]
+[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]
+# flatten with column values
query ????
select flatten(column1),
flatten(column2),
@@ -5368,6 +5388,17 @@ from flatten_table;
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+query ????
+select flatten(column1),
+ flatten(column2),
+ flatten(column3),
+ flatten(column4)
+from large_flatten_table;
+----
+[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
+[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+
+## empty
# empty scalar function #1
query B
select empty(make_array(1));
@@ -5746,6 +5777,9 @@ drop table
fixed_size_nested_arrays_with_repeating_elements;
statement ok
drop table flatten_table;
+statement ok
+drop table large_flatten_table;
+
statement ok
drop table arrays_values_without_nulls;