This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ee59dcc730 Support array `flatten` sql function (#7239)
ee59dcc730 is described below
commit ee59dcc730ce6dda2340cde31a6168ff7a5e9c7a
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Aug 10 01:14:09 2023 +0800
Support array `flatten` sql function (#7239)
* Support array flatten sql function
Signed-off-by: jayzhan211 <[email protected]>
* add null and float
Signed-off-by: jayzhan211 <[email protected]>
* add alias, 1d test and docs
Signed-off-by: jayzhan211 <[email protected]>
* pretty
Signed-off-by: jayzhan211 <[email protected]>
* rename
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
.../core/tests/sqllogictests/test_files/array.slt | 34 ++++++++++++++++
datafusion/expr/src/built_in_function.rs | 25 +++++++++++-
datafusion/expr/src/expr_fn.rs | 6 +++
datafusion/expr/src/expr_schema.rs | 1 +
datafusion/physical-expr/src/array_expressions.rs | 47 ++++++++++++++++++++++
datafusion/physical-expr/src/functions.rs | 4 ++
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 ++
datafusion/proto/src/generated/prost.rs | 3 ++
datafusion/proto/src/logical_plan/from_proto.rs | 1 +
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
docs/source/user-guide/expressions.md | 1 +
docs/source/user-guide/sql/scalar_functions.md | 18 +++++++++
13 files changed, 143 insertions(+), 2 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt
b/datafusion/core/tests/sqllogictests/test_files/array.slt
index 218817fc16..569f14f99a 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -110,6 +110,13 @@ AS VALUES
(NULL, NULL, NULL, NULL)
;
+statement ok
+CREATE TABLE flatten_table
+AS VALUES
+ (make_array([1], [2], [3]), make_array([[1, 2, 3]], [[4, 5]], [[6]]),
make_array([[[1]]], [[[2, 3]]]), make_array([1.0], [2.1, 2.2], [3.2, 3.3,
3.4])),
+ (make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]),
make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
+;
+
statement ok
CREATE TABLE array_has_table_1D
AS VALUES
@@ -2330,6 +2337,30 @@ select array_concat(column1, [7]) from arrays_values_v2;
[11, 12, 7]
[7]
+# flatten
+query ???
+select flatten(make_array(1, 2, 1, 3, 2)),
+ flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
+ flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
+----
+[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]
+
+query ????
+select column1, column2, column3, column4 from flatten_table;
+----
+[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0],
[2.1, 2.2], [3.2, 3.3, 3.4]]
+[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0,
4.0], [5.0, 6.0]]
+
+query ????
+select flatten(column1),
+ flatten(column2),
+ flatten(column3),
+ flatten(column4)
+from flatten_table;
+----
+[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
+[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+
### Delete tables
statement ok
@@ -2382,3 +2413,6 @@ drop table arrays_with_repeating_elements;
statement ok
drop table nested_arrays_with_repeating_elements;
+
+statement ok
+drop table flatten_table;
diff --git a/datafusion/expr/src/built_in_function.rs
b/datafusion/expr/src/built_in_function.rs
index f239155a92..703d41cbee 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -165,6 +165,8 @@ pub enum BuiltinScalarFunction {
Cardinality,
/// construct an array from columns
MakeArray,
+ /// converts an array of arrays to a flat array
+ Flatten,
// struct functions
/// struct
@@ -368,6 +370,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplace => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
+ BuiltinScalarFunction::Flatten => Volatility::Immutable,
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
@@ -501,6 +504,22 @@ impl BuiltinScalarFunction {
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match self {
+ BuiltinScalarFunction::Flatten => {
+ fn get_base_type(data_type: &DataType) -> Result<DataType> {
+ match data_type {
+ DataType::List(field) => match field.data_type() {
+ DataType::List(_) =>
get_base_type(field.data_type()),
+ _ => Ok(data_type.to_owned()),
+ },
+ _ => Err(DataFusionError::Internal(
+ "Not reachable, data_type should be
List".to_string(),
+ )),
+ }
+ }
+
+ let data_type = get_base_type(&input_expr_types[0])?;
+ Ok(data_type)
+ }
BuiltinScalarFunction::ArrayAppend =>
Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayConcat => {
let mut expr_type = Null;
@@ -819,11 +838,12 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
+ BuiltinScalarFunction::ArrayDims => Signature::any(1,
self.volatility()),
+ BuiltinScalarFunction::ArrayElement => Signature::any(2,
self.volatility()),
+ BuiltinScalarFunction::Flatten => Signature::any(1,
self.volatility()),
BuiltinScalarFunction::ArrayHasAll
| BuiltinScalarFunction::ArrayHasAny
| BuiltinScalarFunction::ArrayHas => Signature::any(2,
self.volatility()),
- BuiltinScalarFunction::ArrayDims => Signature::any(1,
self.volatility()),
- BuiltinScalarFunction::ArrayElement => Signature::any(2,
self.volatility()),
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
@@ -1307,6 +1327,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static
[&'static str] {
"list_element",
"list_extract",
],
+ BuiltinScalarFunction::Flatten => &["flatten"],
BuiltinScalarFunction::ArrayHasAll => &["array_has_all",
"list_has_all"],
BuiltinScalarFunction::ArrayHasAny => &["array_has_any",
"list_has_any"],
BuiltinScalarFunction::ArrayHas => {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index ef6ce81711..47767c23b3 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -564,6 +564,12 @@ scalar_expr!(
first_array second_array,
"Returns true if at least one element of the second array appears in the first
array; otherwise, it returns false."
);
+scalar_expr!(
+ Flatten,
+ flatten,
+ array,
+ "flattens an array of arrays into a single array."
+);
scalar_expr!(
ArrayDims,
array_dims,
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index 1d26485b4e..d7bc86158b 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -92,6 +92,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
+
fun.return_type(&data_types)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index fcd9adf19d..819b389cff 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -1789,6 +1789,53 @@ pub fn cardinality(args: &[ArrayRef]) ->
Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
+// Create new offsets that are equivalent to flattening the array.
+fn get_offsets_for_flatten(
+ offsets: OffsetBuffer<i32>,
+ indexes: OffsetBuffer<i32>,
+) -> OffsetBuffer<i32> {
+ let buffer = offsets.into_inner();
+ let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as
usize]).collect();
+ OffsetBuffer::new(offsets.into())
+}
+
+fn flatten_internal(
+ array: &dyn Array,
+ indexes: Option<OffsetBuffer<i32>>,
+) -> Result<ListArray> {
+ let list_arr = as_list_array(array)?;
+ let (field, offsets, values, nulls) = list_arr.clone().into_parts();
+ let data_type = field.data_type();
+
+ match data_type {
+ // Recursively get the base offsets for flattened array
+ DataType::List(_) => {
+ if let Some(indexes) = indexes {
+ let offsets = get_offsets_for_flatten(offsets, indexes);
+ flatten_internal(&values, Some(offsets))
+ } else {
+ flatten_internal(&values, Some(offsets))
+ }
+ }
+ // Reach the base level, create a new list array
+ _ => {
+ if let Some(indexes) = indexes {
+ let offsets = get_offsets_for_flatten(offsets, indexes);
+ let list_arr = ListArray::new(field, offsets, values, nulls);
+ Ok(list_arr)
+ } else {
+ Ok(list_arr.clone())
+ }
+ }
+ }
+}
+
+/// Flatten SQL function
+pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
+ let flattened_array = flatten_internal(&args[0], None)?;
+ Ok(Arc::new(flattened_array) as ArrayRef)
+}
+
/// Array_length SQL function
pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
diff --git a/datafusion/physical-expr/src/functions.rs
b/datafusion/physical-expr/src/functions.rs
index df76d55bfc..d1a5119ee8 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -437,6 +437,10 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayLength => {
Arc::new(|args|
make_scalar_function(array_expressions::array_length)(args))
}
+ BuiltinScalarFunction::Flatten => {
+ Arc::new(|args|
make_scalar_function(array_expressions::flatten)(args))
+ }
+
BuiltinScalarFunction::ArrayNdims => {
Arc::new(|args|
make_scalar_function(array_expressions::array_ndims)(args))
}
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 1081fca2e1..1a8ad093b9 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -594,6 +594,7 @@ enum ScalarFunction {
ArrayRemoveAll = 109;
ArrayReplaceAll = 110;
Nanvl = 111;
+ Flatten = 112;
}
message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 8691487c72..a33d80be9d 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -18944,6 +18944,7 @@ impl serde::Serialize for ScalarFunction {
Self::ArrayRemoveAll => "ArrayRemoveAll",
Self::ArrayReplaceAll => "ArrayReplaceAll",
Self::Nanvl => "Nanvl",
+ Self::Flatten => "Flatten",
};
serializer.serialize_str(variant)
}
@@ -19067,6 +19068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ArrayRemoveAll",
"ArrayReplaceAll",
"Nanvl",
+ "Flatten",
];
struct GeneratedVisitor;
@@ -19221,6 +19223,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll),
"ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll),
"Nanvl" => Ok(ScalarFunction::Nanvl),
+ "Flatten" => Ok(ScalarFunction::Flatten),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 87371ba277..519cd002df 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2374,6 +2374,7 @@ pub enum ScalarFunction {
ArrayRemoveAll = 109,
ArrayReplaceAll = 110,
Nanvl = 111,
+ Flatten = 112,
}
impl ScalarFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2494,6 +2495,7 @@ impl ScalarFunction {
ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll",
ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll",
ScalarFunction::Nanvl => "Nanvl",
+ ScalarFunction::Flatten => "Flatten",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2611,6 +2613,7 @@ impl ScalarFunction {
"ArrayRemoveAll" => Some(Self::ArrayRemoveAll),
"ArrayReplaceAll" => Some(Self::ArrayReplaceAll),
"Nanvl" => Some(Self::Nanvl),
+ "Flatten" => Some(Self::Flatten),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index c17d8dbd8c..d2e037aa4b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -457,6 +457,7 @@ impl From<&protobuf::ScalarFunction> for
BuiltinScalarFunction {
ScalarFunction::ArrayHas => Self::ArrayHas,
ScalarFunction::ArrayDims => Self::ArrayDims,
ScalarFunction::ArrayElement => Self::ArrayElement,
+ ScalarFunction::Flatten => Self::Flatten,
ScalarFunction::ArrayLength => Self::ArrayLength,
ScalarFunction::ArrayNdims => Self::ArrayNdims,
ScalarFunction::ArrayPosition => Self::ArrayPosition,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index aa1132e8b1..cdb9b00803 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1456,6 +1456,7 @@ impl TryFrom<&BuiltinScalarFunction> for
protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
+ BuiltinScalarFunction::Flatten => Self::Flatten,
BuiltinScalarFunction::ArrayLength => Self::ArrayLength,
BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims,
BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition,
diff --git a/docs/source/user-guide/expressions.md
b/docs/source/user-guide/expressions.md
index a04f43fd4b..88a5a73a6d 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -188,6 +188,7 @@ Unlike to some databases the math functions in Datafusion
works the same way as
| array_has_any(array, sub-array) | Returns true if any elements exist
in both arrays `array_has_any([1,2,3], [1,4]) -> true`
|
| array_dims(array) | Returns an array of the array's
dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]`
|
| array_element(array, index) | Extracts the element with the index
n from the array `array_element([1, 2, 3, 4], 3) -> 3`
|
+| flatten(array) | Converts an array of arrays to a
flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]`
|
| array_length(array, dimension) | Returns the length of the array
dimension. `array_length([1, 2, 3, 4, 5]) -> 5`
|
| array_ndims(array) | Returns the number of dimensions of
the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2`
|
| array_position(array, element) | Searches for an element in the
array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2`
|
diff --git a/docs/source/user-guide/sql/scalar_functions.md
b/docs/source/user-guide/sql/scalar_functions.md
index dec120db18..9bcf2ae0b0 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -1685,6 +1685,24 @@ array_fill(element, array)
Can be a constant, column, or function, and any combination of array
operators.
- **element**: Element to copy to the array.
+### `flatten`
+
+Converts an array of arrays to a flat array.
+
+- Applies to any depth of nested arrays
+- Does not change arrays that are already flat
+
+The flattened array contains all the elements from all source arrays.
+
+#### Arguments
+
+- **array**: Array expression
+ Can be a constant, column, or function, and any combination of array
operators.
+
+```
+flatten(array)
+```
+
### `array_indexof`
_Alias of [array_position](#array_position)._