This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0624378160 feat: array functions treat an array as an element (#6986)
0624378160 is described below
commit 0624378160b1d19de12a29b1374dea1930c9faaa
Author: Igor Izvekov <[email protected]>
AuthorDate: Wed Jul 19 00:13:45 2023 +0300
feat: array functions treat an array as an element (#6986)
---
.../core/tests/sqllogictests/test_files/array.slt | 196 +++++++++++++++------
datafusion/physical-expr/src/array_expressions.rs | 59 ++++++-
2 files changed, 205 insertions(+), 50 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt
b/datafusion/core/tests/sqllogictests/test_files/array.slt
index f0f50ccc93..1e9b32414b 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -55,6 +55,13 @@ AS VALUES
(make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7,
18.8), NULL)
;
+statement ok
+CREATE TABLE nested_arrays
+AS VALUES
+ (make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9),
make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), make_array(7,
8, 9), 2, make_array([[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]])),
+ (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9,
8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)),
make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18,
19], [20, 21, 22]]))
+;
+
statement ok
CREATE TABLE arrays_values
AS VALUES
@@ -100,6 +107,13 @@ NULL [13.3, 14.4, 15.5] [a, m, e, t]
[[11, 12], [13, 14]] NULL [,]
[[15, 16], [, 18]] [16.6, 17.7, 18.8] NULL
+# nested_arrays table
+query ??I?
+select column1, column2, column3, column4 from nested_arrays;
+----
+[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [7, 8, 9] 2
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
+[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [10,
11, 12] 3 [[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]]
+
# values table
query IIIRT
select a, b, c, d, e from values;
@@ -292,7 +306,13 @@ select array_append(make_array(1, 2, 3), 4),
array_append(make_array(1.0, 2.0, 3
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
-# array_append with columns
+# array_append scalar function #4 (element is list)
+query ???
+select array_append(make_array([1], [2], [3]), make_array(4)),
array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)),
array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o'));
+----
+[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
+# array_append with columns #1
query ?
select array_append(column1, column2) from arrays_values;
----
@@ -305,7 +325,14 @@ select array_append(column1, column2) from arrays_values;
[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55]
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66]
-# array_append with columns and scalars
+# array_append with columns #2 (element is list)
+query ?
+select array_append(column1, column2) from nested_arrays;
+----
+[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]]
+[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10,
11, 12]]
+
+# array_append with columns and scalars #1
query ??
select array_append(column2, 100.1), array_append(column3, '.') from arrays;
----
@@ -317,6 +344,13 @@ select array_append(column2, 100.1), array_append(column3,
'.') from arrays;
[100.1] [,, .]
[16.6, 17.7, 18.8, 100.1] [.]
+# array_append with columns and scalars #2 (element is list)
+query ??
+select array_append(column1, make_array(1, 11, 111)),
array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2)
from nested_arrays;
+----
+[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11,
111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]]
+[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1,
11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]]
+
## array_prepend
# array_prepend scalar function #1
@@ -337,7 +371,13 @@ select array_prepend(1, make_array(2, 3, 4)),
array_prepend(1.0, make_array(2.0,
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
-# array_prepend with columns
+# array_prepend scalar function #4 (element is list)
+query ???
+select array_prepend(make_array(1), make_array(make_array(2), make_array(3),
make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0],
[4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o']));
+----
+[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
+# array_prepend with columns #1
query ?
select array_prepend(column2, column1) from arrays_values;
----
@@ -350,7 +390,14 @@ select array_prepend(column2, column1) from arrays_values;
[55, 51, 52, , 54, 55, 56, 57, 58, 59, 60]
[66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
-# array_prepend with columns and scalars
+# array_prepend with columns #2 (element is list)
+query ?
+select array_prepend(column2, column1) from nested_arrays;
+----
+[[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]]
+[[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12],
[1, 8, 7]]
+
+# array_prepend with columns and scalars #1
query ??
select array_prepend(100.1, column2), array_prepend('.', column3) from arrays;
----
@@ -362,6 +409,13 @@ select array_prepend(100.1, column2), array_prepend('.',
column3) from arrays;
[100.1] [., ,]
[100.1, 16.6, 17.7, 18.8] [.]
+# array_prepend with columns and scalars #2 (element is list)
+query ??
+select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2,
make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays;
+----
+[[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5,
6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]]
+[[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12],
[1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]]
+
## array_fill
# array_fill scalar function #1
@@ -473,19 +527,6 @@ select array_concat(make_array(column2),
make_array(column3)) from arrays_values
# array_concat column-wise #4
query ?
-select array_concat(column1, column2) from arrays_values;
-----
-[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1]
-[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12]
-[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23]
-[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34]
-[44]
-[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ]
-[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55]
-[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66]
-
-# array_concat column-wise #5
-query ?
select array_concat(make_array(column2), make_array(0)) from arrays_values;
----
[1, 0]
@@ -497,7 +538,7 @@ select array_concat(make_array(column2), make_array(0))
from arrays_values;
[55, 0]
[66, 0]
-# array_concat column-wise #6
+# array_concat column-wise #5
query ???
select array_concat(column1, column1), array_concat(column2, column2),
array_concat(column3, column3) from arrays;
----
@@ -509,7 +550,7 @@ NULL [13.3, 14.4, 15.5, 13.3, 14.4, 15.5] [a, m, e, t, a,
m, e, t]
[[11, 12], [13, 14], [11, 12], [13, 14]] NULL [,, ,]
[[15, 16], [, 18], [15, 16], [, 18]] [16.6, 17.7, 18.8, 16.6, 17.7, 18.8] NULL
-# array_concat column-wise #7
+# array_concat column-wise #6
query ??
select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))),
array_concat(column2, make_array(1.1, 2.2, 3.3)) from arrays;
----
@@ -521,7 +562,7 @@ select array_concat(column1, make_array(make_array(1, 2),
make_array(3, 4))), ar
[[11, 12], [13, 14], [1, 2], [3, 4]] [1.1, 2.2, 3.3]
[[15, 16], [, 18], [1, 2], [3, 4]] [16.6, 17.7, 18.8, 1.1, 2.2, 3.3]
-# array_concat column-wise #8
+# array_concat column-wise #7
query ?
select array_concat(column3, make_array('.', '.', '.')) from arrays;
----
@@ -543,7 +584,7 @@ select array_concat(column3, make_array('.', '.', '.'))
from arrays;
# [11, 12] NULL NULL NULL
# NULL NULL NULL NULL
-# array_concat column-wise #9 (1D + 1D)
+# array_concat column-wise #8 (1D + 1D)
query ?
select array_concat(column1, column2) from arrays_values_v2;
----
@@ -554,28 +595,36 @@ select array_concat(column1, column2) from
arrays_values_v2;
[11, 12]
NULL
-# TODO: Concat columns with different dimensions fails
-# array_concat column-wise #10 (1D + 2D)
-# query error DataFusion error: Arrow error: Invalid argument error: column
types must match schema types, expected List\(Field \{ name: "item", data_type:
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
but found List\(Field \{ name: "item", data_type: List\(Field \{ name: "item",
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata:
\{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\}
\}\) at column index 0
-# select array_concat(make_array(column3), column4) from arrays_values_v2;
+# array_concat column-wise #9 (2D + 1D)
+query ?
+select array_concat(column4, make_array(column3)) from arrays_values_v2;
+----
+[[30, 40, 50], [12]]
+[[, , 60], [13]]
+[[70, , ], [14]]
+[[]]
+[[]]
+[[]]
+
+# array_concat column-wise #10 (3D + 2D + 1D)
+query ?
+select array_concat(column4, column1, column2) from nested_arrays;
+----
+[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], [[1, 2, 3], [2, 9, 1], [7,
8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]], [[7, 8, 9]]]
+[[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]], [[4, 5, 6], [10,
11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]], [[10, 11, 12]]]
-# array_concat column-wise #11 (1D + Integers)
+# array_concat column-wise #11 (2D + 1D)
query ?
-select array_concat(column2, column3) from arrays_values_v2;
+select array_concat(column4, column1) from arrays_values_v2;
----
-[4, 5, , 12]
-[7, , 8, 13]
-[14]
-[, 21, ]
+[[30, 40, 50], [, 2, 3]]
+[[, , 60], ]
+[[70, , ], [9, , 10]]
+[[, 1]]
+[[11, 12]]
[]
-[]
-
-# TODO: Panic at 'range end index 3 out of range for slice of length 2'
-# array_concat column-wise #12 (2D + 1D)
-# query
-# select array_concat(column4, column1) from arrays_values_v2;
-# array_concat column-wise #13 (1D + 1D + 1D)
+# array_concat column-wise #12 (1D + 1D + 1D)
query ?
select array_concat(make_array(column3), column1, column2) from
arrays_values_v2;
----
@@ -594,13 +643,25 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'),
array_position([1, 2, 3,
----
3 5 1
-# array_position scalar function #2
+# array_position scalar function #2 (with optional argument)
query III
select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1,
2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2);
----
4 5 2
-# array_position with columns
+# array_position scalar function #3 (element is list)
+query II
+select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6],
[7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3,
4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]);
+----
+2 2
+
+# array_position scalar function #4 (element is list; with optional argument)
+query II
+select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6],
[7, 8, 9]), [4, 5, 6], 3), array_position(make_array([1, 3, 2], [2, 3, 4], [2,
3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], 3);
+----
+4 3
+
+# array_position with columns #1
query II
select array_position(column1, column2), array_position(column1, column2,
column3) from arrays_values_without_nulls;
----
@@ -609,24 +670,44 @@ select array_position(column1, column2),
array_position(column1, column2, column
3 3
4 4
-# array_position with columns and scalars
+# array_position with columns #2 (element is list)
query II
-select array_position(column1, 3), array_position(column1, 3, 5) from
arrays_values_without_nulls;
+select array_position(column1, column2), array_position(column1, column2,
column3) from nested_arrays;
----
-3 NULL
-NULL NULL
-NULL NULL
-NULL NULL
+3 3
+2 5
+
+# array_position with columns and scalars #1
+query III
+select array_position(make_array(1, 2, 3, 4, 5), column2),
array_position(column1, 3), array_position(column1, 3, 5) from
arrays_values_without_nulls;
+----
+1 3 NULL
+NULL NULL NULL
+NULL NULL NULL
+NULL NULL NULL
+
+# array_position with columns and scalars #2 (element is list)
+query III
+select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]),
column2), array_position(column1, make_array(4, 5, 6)), array_position(column1,
make_array(1, 2, 3), 2) from nested_arrays;
+----
+NULL 6 4
+NULL 1 NULL
## array_positions
-# array_positions scalar function
+# array_positions scalar function #1
query ???
select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2,
3, 4, 5], 5), array_positions([1, 1, 1], 1);
----
[3, 4] [5] [1, 2, 3]
-# array_positions with columns
+# array_positions scalar function #2 (element is list)
+query ?
+select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3],
[4, 5, 6]), [2, 1, 3]);
+----
+[2, 4]
+
+# array_positions with columns #1
query ?
select array_positions(column1, column2) from arrays_values_without_nulls;
----
@@ -635,7 +716,14 @@ select array_positions(column1, column2) from
arrays_values_without_nulls;
[3]
[4]
-# array_positions with columns and scalars
+# array_positions with columns #2 (element is list)
+query ?
+select array_positions(column1, column2) from nested_arrays;
+----
+[3]
+[2, 5]
+
+# array_positions with columns and scalars #1
query ??
select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33,
45], column2) from arrays_values_without_nulls;
----
@@ -644,6 +732,13 @@ select array_positions(column1, 4),
array_positions(array[1, 2, 23, 13, 33, 45],
[] [3]
[] []
+# array_positions with columns and scalars #2 (element is list)
+query ??
+select array_positions(column1, make_array(4, 5, 6)),
array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from
nested_arrays;
+----
+[6] []
+[1] []
+
## array_replace
# array_replace scalar function
@@ -1053,6 +1148,9 @@ select make_array(f0) from fixed_size_list_array
statement ok
drop table values;
+statement ok
+drop table nested_arrays;
+
statement ok
drop table arrays;
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index 104d49e1c8..b16432b505 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -410,6 +410,7 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
let element = &args[1];
let res = match (arr.value_type(), element.data_type()) {
+ (DataType::List(_), DataType::List(_)) =>
concat_internal(args)?,
(DataType::Utf8, DataType::Utf8) => append!(arr, element,
StringArray),
(DataType::LargeUtf8, DataType::LargeUtf8) => append!(arr,
element, LargeStringArray),
(DataType::Boolean, DataType::Boolean) => append!(arr,
element, BooleanArray),
@@ -499,6 +500,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef>
{
let arr = as_list_array(&args[1])?;
let res = match (arr.value_type(), element.data_type()) {
+ (DataType::List(_), DataType::List(_)) =>
concat_internal(args)?,
(DataType::Utf8, DataType::Utf8) => prepend!(arr, element,
StringArray),
(DataType::LargeUtf8, DataType::LargeUtf8) => prepend!(arr,
element, LargeStringArray),
(DataType::Boolean, DataType::Boolean) => prepend!(arr,
element, BooleanArray),
@@ -543,7 +545,18 @@ fn align_array_dimensions(args: Vec<ArrayRef>) ->
Result<Vec<ArrayRef>> {
let mut aligned_array = array.clone();
for _ in 0..(max_ndim - ndim) {
let data_type = aligned_array.as_ref().data_type().clone();
- aligned_array = array_array(&[aligned_array], data_type)?;
+ let offsets: Vec<i32> =
+ (0..downcast_arg!(aligned_array,
ListArray).offsets().len())
+ .map(|i| i as i32)
+ .collect();
+ let field = Arc::new(Field::new("item", data_type, true));
+
+ aligned_array = Arc::new(ListArray::try_new(
+ field,
+ OffsetBuffer::new(offsets.into()),
+ Arc::new(aligned_array.clone()),
+ None,
+ )?)
}
Ok(aligned_array)
} else {
@@ -761,6 +774,7 @@ pub fn array_position(args: &[ArrayRef]) ->
Result<ArrayRef> {
let res = match arr.data_type() {
DataType::List(field) => match field.data_type() {
+ DataType::List(_) => position!(arr, element, index, ListArray),
DataType::Utf8 => position!(arr, element, index, StringArray),
DataType::LargeUtf8 => position!(arr, element, index,
LargeStringArray),
DataType::Boolean => position!(arr, element, index, BooleanArray),
@@ -846,6 +860,7 @@ pub fn array_positions(args: &[ArrayRef]) ->
Result<ArrayRef> {
let res = match arr.data_type() {
DataType::List(field) => match field.data_type() {
+ DataType::List(_) => positions!(arr, element, ListArray),
DataType::Utf8 => positions!(arr, element, StringArray),
DataType::LargeUtf8 => positions!(arr, element, LargeStringArray),
DataType::Boolean => positions!(arr, element, BooleanArray),
@@ -1617,6 +1632,48 @@ mod tests {
);
}
+ /// Exercises `array_concat` twice: first on two 1D lists of equal dimension,
+ /// then on a 2D list followed by a 1D list, where the 1D argument is
+ /// expected to appear in the result as a single (appended) element.
+ #[test]
+ fn test_nested_array_concat() {
+ // array_concat([1, 2, 3, 4], [1, 2, 3, 4]) = [1, 2, 3, 4, 1, 2, 3, 4]
+ let list_array = return_array().into_array(1);
+ let arr = array_concat(&[list_array.clone(), list_array.clone()])
+ .expect("failed to initialize function array_concat");
+ // NOTE(review): this `expect` guards the downcast of the result to a
+ // ListArray, not the function call; the message is reused from above.
+ let result =
+ as_list_array(&arr).expect("failed to initialize function
array_concat");
+
+ // Row 0 of the result should be the flat concatenation of both inputs.
+ assert_eq!(
+ &[1, 2, 3, 4, 1, 2, 3, 4],
+ result
+ .value(0)
+ .as_any()
+ .downcast_ref::<Int64Array>()
+ .unwrap()
+ .values()
+ );
+
+ // array_concat([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2, 3, 4]) = [[1, 2,
3, 4], [5, 6, 7, 8], [1, 2, 3, 4]]
+ let list_nested_array = return_nested_array().into_array(1);
+ let list_array = return_array().into_array(1);
+ // Mixed dimensions (2D + 1D): the 1D argument must be wrapped before it
+ // can be concatenated with the 2D one.
+ let arr = array_concat(&[list_nested_array, list_array])
+ .expect("failed to initialize function array_concat");
+ let result =
+ as_list_array(&arr).expect("failed to initialize function
array_concat");
+
+ // Only the third element (index 2) of row 0 is checked: it should be the
+ // 1D input appended as one inner list. The first two inner lists, taken
+ // from the nested input, are not asserted here.
+ assert_eq!(
+ &[1, 2, 3, 4],
+ result
+ .value(0)
+ .as_any()
+ .downcast_ref::<ListArray>()
+ .unwrap()
+ .value(2)
+ .as_any()
+ .downcast_ref::<Int64Array>()
+ .unwrap()
+ .values()
+ );
+ }
+
#[test]
fn test_array_fill() {
// array_fill(4, [5]) = [4, 4, 4, 4, 4]