This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0624378160 feat: array functions treat an array as an element (#6986)
0624378160 is described below

commit 0624378160b1d19de12a29b1374dea1930c9faaa
Author: Igor Izvekov <[email protected]>
AuthorDate: Wed Jul 19 00:13:45 2023 +0300

    feat: array functions treat an array as an element (#6986)
---
 .../core/tests/sqllogictests/test_files/array.slt  | 196 +++++++++++++++------
 datafusion/physical-expr/src/array_expressions.rs  |  59 ++++++-
 2 files changed, 205 insertions(+), 50 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt
index f0f50ccc93..1e9b32414b 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -55,6 +55,13 @@ AS VALUES
   (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 
18.8), NULL)
 ;
 
+statement ok
+CREATE TABLE nested_arrays
+AS VALUES
+  (make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), 
make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), make_array(7, 
8, 9), 2, make_array([[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]])),
+  (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 
8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), 
make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 
19], [20, 21, 22]]))
+;
+
 statement ok
 CREATE TABLE arrays_values
 AS VALUES
@@ -100,6 +107,13 @@ NULL [13.3, 14.4, 15.5] [a, m, e, t]
 [[11, 12], [13, 14]] NULL [,]
 [[15, 16], [, 18]] [16.6, 17.7, 18.8] NULL
 
+# nested_arrays table
+query ??I?
+select column1, column2, column3, column4 from nested_arrays;
+----
+[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [7, 8, 9] 2 
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
+[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [10, 
11, 12] 3 [[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]]
+
 # values table
 query IIIRT
 select a, b, c, d, e from values;
@@ -292,7 +306,13 @@ select array_append(make_array(1, 2, 3), 4), 
array_append(make_array(1.0, 2.0, 3
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
 
-# array_append with columns
+# array_append scalar function #4 (element is list)
+query ???
+select array_append(make_array([1], [2], [3]), make_array(4)), 
array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), 
array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o'));
+----
+[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
+# array_append with columns #1
 query ?
 select array_append(column1, column2) from arrays_values;
 ----
@@ -305,7 +325,14 @@ select array_append(column1, column2) from arrays_values;
 [51, 52, , 54, 55, 56, 57, 58, 59, 60, 55]
 [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66]
 
-# array_append with columns and scalars
+# array_append with columns #2 (element is list)
+query ?
+select array_append(column1, column2) from nested_arrays;
+----
+[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]]
+[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 
11, 12]]
+
+# array_append with columns and scalars #1
 query ??
 select array_append(column2, 100.1), array_append(column3, '.') from arrays;
 ----
@@ -317,6 +344,13 @@ select array_append(column2, 100.1), array_append(column3, 
'.') from arrays;
 [100.1] [,, .]
 [16.6, 17.7, 18.8, 100.1] [.]
 
+# array_append with columns and scalars #2
+query ??
+select array_append(column1, make_array(1, 11, 111)), 
array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) 
from nested_arrays;
+----
+[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 
111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]]
+[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 
11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]]
+
 ## array_prepend
 
 # array_prepend scalar function #1
@@ -337,7 +371,13 @@ select array_prepend(1, make_array(2, 3, 4)), 
array_prepend(1.0, make_array(2.0,
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
 
-# array_prepend with columns
+# array_prepend scalar function #4 (element is list)
+query ???
+select array_prepend(make_array(1), make_array(make_array(2), make_array(3), 
make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], 
[4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o']));
+----
+[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
+# array_prepend with columns #1
 query ?
 select array_prepend(column2, column1) from arrays_values;
 ----
@@ -350,7 +390,14 @@ select array_prepend(column2, column1) from arrays_values;
 [55, 51, 52, , 54, 55, 56, 57, 58, 59, 60]
 [66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
 
-# array_prepend with columns and scalars
+# array_prepend with columns #2 (element is list)
+query ?
+select array_prepend(column2, column1) from nested_arrays;
+----
+[[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]]
+[[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], 
[1, 8, 7]]
+
+# array_prepend with columns and scalars #1
 query ??
 select array_prepend(100.1, column2), array_prepend('.', column3) from arrays;
 ----
@@ -362,6 +409,13 @@ select array_prepend(100.1, column2), array_prepend('.', 
column3) from arrays;
 [100.1] [., ,]
 [100.1, 16.6, 17.7, 18.8] [.]
 
+# array_prepend with columns and scalars #2 (element is list)
+query ??
+select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, 
make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays;
+----
+[[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 
6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]]
+[[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], 
[1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]]
+
 ## array_fill
 
 # array_fill scalar function #1
@@ -473,19 +527,6 @@ select array_concat(make_array(column2), 
make_array(column3)) from arrays_values
 
 # array_concat column-wise #4
 query ?
-select array_concat(column1, column2) from arrays_values;
-----
-[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1]
-[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12]
-[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23]
-[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34]
-[44]
-[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ]
-[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55]
-[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66]
-
-# array_concat column-wise #5
-query ?
 select array_concat(make_array(column2), make_array(0)) from arrays_values;
 ----
 [1, 0]
@@ -497,7 +538,7 @@ select array_concat(make_array(column2), make_array(0)) 
from arrays_values;
 [55, 0]
 [66, 0]
 
-# array_concat column-wise #6
+# array_concat column-wise #5
 query ???
 select array_concat(column1, column1), array_concat(column2, column2), 
array_concat(column3, column3) from arrays;
 ----
@@ -509,7 +550,7 @@ NULL [13.3, 14.4, 15.5, 13.3, 14.4, 15.5] [a, m, e, t, a, 
m, e, t]
 [[11, 12], [13, 14], [11, 12], [13, 14]] NULL [,, ,]
 [[15, 16], [, 18], [15, 16], [, 18]] [16.6, 17.7, 18.8, 16.6, 17.7, 18.8] NULL
 
-# array_concat column-wise #7
+# array_concat column-wise #6
 query ??
 select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), 
array_concat(column2, make_array(1.1, 2.2, 3.3)) from arrays;
 ----
@@ -521,7 +562,7 @@ select array_concat(column1, make_array(make_array(1, 2), 
make_array(3, 4))), ar
 [[11, 12], [13, 14], [1, 2], [3, 4]] [1.1, 2.2, 3.3]
 [[15, 16], [, 18], [1, 2], [3, 4]] [16.6, 17.7, 18.8, 1.1, 2.2, 3.3]
 
-# array_concat column-wise #8
+# array_concat column-wise #7
 query ?
 select array_concat(column3, make_array('.', '.', '.')) from arrays;
 ----
@@ -543,7 +584,7 @@ select array_concat(column3, make_array('.', '.', '.')) 
from arrays;
 # [11, 12] NULL NULL NULL
 # NULL NULL NULL NULL
 
-# array_concat column-wise #9 (1D + 1D)
+# array_concat column-wise #8 (1D + 1D)
 query ?
 select array_concat(column1, column2) from arrays_values_v2;
 ----
@@ -554,28 +595,36 @@ select array_concat(column1, column2) from 
arrays_values_v2;
 [11, 12]
 NULL
 
-# TODO: Concat columns with different dimensions fails
-# array_concat column-wise #10 (1D + 2D)
-# query error DataFusion error: Arrow error: Invalid argument error: column 
types must match schema types, expected List\(Field \{ name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) 
but found List\(Field \{ name: "item", data_type: List\(Field \{ name: "item", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
\{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} 
\}\) at column index 0
-# select array_concat(make_array(column3), column4) from arrays_values_v2;
+# array_concat column-wise #9 (2D + 1D)
+query ?
+select array_concat(column4, make_array(column3)) from arrays_values_v2;
+----
+[[30, 40, 50], [12]]
+[[, , 60], [13]]
+[[70, , ], [14]]
+[[]]
+[[]]
+[[]]
+
+# array_concat column-wise #10 (3D + 2D + 1D)
+query ?
+select array_concat(column4, column1, column2) from nested_arrays;
+----
+[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], [[1, 2, 3], [2, 9, 1], [7, 
8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]], [[7, 8, 9]]]
+[[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]], [[4, 5, 6], [10, 
11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]], [[10, 11, 12]]]
 
-# array_concat column-wise #11 (1D + Integers)
+# array_concat column-wise #11 (2D + 1D)
 query ?
-select array_concat(column2, column3) from arrays_values_v2;
+select array_concat(column4, column1) from arrays_values_v2;
 ----
-[4, 5, , 12]
-[7, , 8, 13]
-[14]
-[, 21, ]
+[[30, 40, 50], [, 2, 3]]
+[[, , 60], ]
+[[70, , ], [9, , 10]]
+[[, 1]]
+[[11, 12]]
 []
-[]
-
-# TODO: Panic at 'range end index 3 out of range for slice of length 2'
-# array_concat column-wise #12 (2D + 1D)
-# query
-# select array_concat(column4, column1) from arrays_values_v2;
 
-# array_concat column-wise #13 (1D + 1D + 1D)
+# array_concat column-wise #12 (1D + 1D + 1D)
 query ?
 select array_concat(make_array(column3), column1, column2) from 
arrays_values_v2;
 ----
@@ -594,13 +643,25 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), 
array_position([1, 2, 3,
 ----
 3 5 1
 
-# array_position scalar function #2
+# array_position scalar function #2 (with optional argument)
 query III
 select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 
2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2);
 ----
 4 5 2
 
-# array_position with columns
+# array_position scalar function #3 (element is list)
+query II
+select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], 
[7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 
4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]);
+----
+2 2
+
+# array_position scalar function #4 (element is list; with optional argument)
+query II
+select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], 
[7, 8, 9]), [4, 5, 6], 3), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 
3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], 3);
+----
+4 3
+
+# array_position with columns #1
 query II
 select array_position(column1, column2), array_position(column1, column2, 
column3) from arrays_values_without_nulls;
 ----
@@ -609,24 +670,44 @@ select array_position(column1, column2), 
array_position(column1, column2, column
 3 3
 4 4
 
-# array_position with columns and scalars
+# array_position with columns #2 (element is list)
 query II
-select array_position(column1, 3), array_position(column1, 3, 5) from 
arrays_values_without_nulls;
+select array_position(column1, column2), array_position(column1, column2, 
column3) from nested_arrays;
 ----
-3 NULL
-NULL NULL
-NULL NULL
-NULL NULL
+3 3
+2 5
+
+# array_position with columns and scalars #1
+query III
+select array_position(make_array(1, 2, 3, 4, 5), column2), 
array_position(column1, 3), array_position(column1, 3, 5) from 
arrays_values_without_nulls;
+----
+1 3 NULL
+NULL NULL NULL
+NULL NULL NULL
+NULL NULL NULL
+
+# array_position with columns and scalars #2 (element is list)
+query III
+select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 
column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, 
make_array(1, 2, 3), 2) from nested_arrays;
+----
+NULL 6 4
+NULL 1 NULL
 
 ## array_positions
 
-# array_positions scalar function
+# array_positions scalar function #1
 query ???
 select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 
3, 4, 5], 5), array_positions([1, 1, 1], 1);
 ----
 [3, 4] [5] [1, 2, 3]
 
-# array_positions with columns
+# array_positions scalar function #2 (element is list)
+query ?
+select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], 
[4, 5, 6]), [2, 1, 3]);
+----
+[2, 4]
+
+# array_positions with columns #1
 query ?
 select array_positions(column1, column2) from arrays_values_without_nulls;
 ----
@@ -635,7 +716,14 @@ select array_positions(column1, column2) from 
arrays_values_without_nulls;
 [3]
 [4]
 
-# array_positions with columns and scalars
+# array_positions with columns #2 (element is list)
+query ?
+select array_positions(column1, column2) from nested_arrays;
+----
+[3]
+[2, 5]
+
+# array_positions with columns and scalars #1
 query ??
 select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 
45], column2) from arrays_values_without_nulls;
 ----
@@ -644,6 +732,13 @@ select array_positions(column1, 4), 
array_positions(array[1, 2, 23, 13, 33, 45],
 [] [3]
 [] []
 
+# array_positions with columns and scalars #2 (element is list)
+query ??
+select array_positions(column1, make_array(4, 5, 6)), 
array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from 
nested_arrays;
+----
+[6] []
+[1] []
+
 ## array_replace
 
 # array_replace scalar function
@@ -1053,6 +1148,9 @@ select make_array(f0) from fixed_size_list_array
 statement ok
 drop table values;
 
+statement ok
+drop table nested_arrays;
+
 statement ok
 drop table arrays;
 
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 104d49e1c8..b16432b505 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -410,6 +410,7 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
     let element = &args[1];
 
     let res = match (arr.value_type(), element.data_type()) {
+                (DataType::List(_), DataType::List(_)) => 
concat_internal(args)?,
                 (DataType::Utf8, DataType::Utf8) => append!(arr, element, 
StringArray),
                 (DataType::LargeUtf8, DataType::LargeUtf8) => append!(arr, 
element, LargeStringArray),
                 (DataType::Boolean, DataType::Boolean) => append!(arr, 
element, BooleanArray),
@@ -499,6 +500,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> 
{
     let arr = as_list_array(&args[1])?;
 
     let res = match (arr.value_type(), element.data_type()) {
+                (DataType::List(_), DataType::List(_)) => 
concat_internal(args)?,
                 (DataType::Utf8, DataType::Utf8) => prepend!(arr, element, 
StringArray),
                 (DataType::LargeUtf8, DataType::LargeUtf8) => prepend!(arr, 
element, LargeStringArray),
                 (DataType::Boolean, DataType::Boolean) => prepend!(arr, 
element, BooleanArray),
@@ -543,7 +545,18 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> 
Result<Vec<ArrayRef>> {
                 let mut aligned_array = array.clone();
                 for _ in 0..(max_ndim - ndim) {
                     let data_type = aligned_array.as_ref().data_type().clone();
-                    aligned_array = array_array(&[aligned_array], data_type)?;
+                    let offsets: Vec<i32> =
+                        (0..downcast_arg!(aligned_array, 
ListArray).offsets().len())
+                            .map(|i| i as i32)
+                            .collect();
+                    let field = Arc::new(Field::new("item", data_type, true));
+
+                    aligned_array = Arc::new(ListArray::try_new(
+                        field,
+                        OffsetBuffer::new(offsets.into()),
+                        Arc::new(aligned_array.clone()),
+                        None,
+                    )?)
                 }
                 Ok(aligned_array)
             } else {
@@ -761,6 +774,7 @@ pub fn array_position(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 
     let res = match arr.data_type() {
         DataType::List(field) => match field.data_type() {
+            DataType::List(_) => position!(arr, element, index, ListArray),
             DataType::Utf8 => position!(arr, element, index, StringArray),
             DataType::LargeUtf8 => position!(arr, element, index, 
LargeStringArray),
             DataType::Boolean => position!(arr, element, index, BooleanArray),
@@ -846,6 +860,7 @@ pub fn array_positions(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 
     let res = match arr.data_type() {
         DataType::List(field) => match field.data_type() {
+            DataType::List(_) => positions!(arr, element, ListArray),
             DataType::Utf8 => positions!(arr, element, StringArray),
             DataType::LargeUtf8 => positions!(arr, element, LargeStringArray),
             DataType::Boolean => positions!(arr, element, BooleanArray),
@@ -1617,6 +1632,48 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_nested_array_concat() {
+        // array_concat([1, 2, 3, 4], [1, 2, 3, 4]) = [1, 2, 3, 4, 1, 2, 3, 4]
+        let list_array = return_array().into_array(1);
+        let arr = array_concat(&[list_array.clone(), list_array.clone()])
+            .expect("failed to initialize function array_concat");
+        let result =
+            as_list_array(&arr).expect("failed to initialize function 
array_concat");
+
+        assert_eq!(
+            &[1, 2, 3, 4, 1, 2, 3, 4],
+            result
+                .value(0)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .values()
+        );
+
+        // array_concat([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2, 3, 4]) = [[1, 2, 
3, 4], [5, 6, 7, 8], [1, 2, 3, 4]]
+        let list_nested_array = return_nested_array().into_array(1);
+        let list_array = return_array().into_array(1);
+        let arr = array_concat(&[list_nested_array, list_array])
+            .expect("failed to initialize function array_concat");
+        let result =
+            as_list_array(&arr).expect("failed to initialize function 
array_concat");
+
+        assert_eq!(
+            &[1, 2, 3, 4],
+            result
+                .value(0)
+                .as_any()
+                .downcast_ref::<ListArray>()
+                .unwrap()
+                .value(2)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .values()
+        );
+    }
+
     #[test]
     fn test_array_fill() {
         // array_fill(4, [5]) = [4, 4, 4, 4, 4]

Reply via email to