This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ee59dcc730 Support array `flatten` sql function (#7239)
ee59dcc730 is described below

commit ee59dcc730ce6dda2340cde31a6168ff7a5e9c7a
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Aug 10 01:14:09 2023 +0800

    Support array `flatten` sql function (#7239)
    
    * Support array flatten sql function
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add null and float
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add alias, 1d test and docs
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * pretty
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rename
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 .../core/tests/sqllogictests/test_files/array.slt  | 34 ++++++++++++++++
 datafusion/expr/src/built_in_function.rs           | 25 +++++++++++-
 datafusion/expr/src/expr_fn.rs                     |  6 +++
 datafusion/expr/src/expr_schema.rs                 |  1 +
 datafusion/physical-expr/src/array_expressions.rs  | 47 ++++++++++++++++++++++
 datafusion/physical-expr/src/functions.rs          |  4 ++
 datafusion/proto/proto/datafusion.proto            |  1 +
 datafusion/proto/src/generated/pbjson.rs           |  3 ++
 datafusion/proto/src/generated/prost.rs            |  3 ++
 datafusion/proto/src/logical_plan/from_proto.rs    |  1 +
 datafusion/proto/src/logical_plan/to_proto.rs      |  1 +
 docs/source/user-guide/expressions.md              |  1 +
 docs/source/user-guide/sql/scalar_functions.md     | 18 +++++++++
 13 files changed, 143 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt 
b/datafusion/core/tests/sqllogictests/test_files/array.slt
index 218817fc16..569f14f99a 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -110,6 +110,13 @@ AS VALUES
   (NULL, NULL, NULL, NULL)
 ;
 
+statement ok
+CREATE TABLE flatten_table
+AS VALUES
+  (make_array([1], [2], [3]), make_array([[1, 2, 3]], [[4, 5]], [[6]]), 
make_array([[[1]]], [[[2, 3]]]), make_array([1.0], [2.1, 2.2], [3.2, 3.3, 
3.4])),
+  (make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), 
make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
+;
+
 statement ok
 CREATE TABLE array_has_table_1D
 AS VALUES
@@ -2330,6 +2337,30 @@ select array_concat(column1, [7]) from arrays_values_v2;
 [11, 12, 7]
 [7]
 
+# flatten
+query ???
+select flatten(make_array(1, 2, 1, 3, 2)),
+       flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
+       flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
+----
+[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]
+
+query ????
+select column1, column2, column3, column4 from flatten_table;
+----
+[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], 
[2.1, 2.2], [3.2, 3.3, 3.4]]
+[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 
4.0], [5.0, 6.0]]
+
+query ????
+select flatten(column1),
+       flatten(column2),
+       flatten(column3),
+       flatten(column4)
+from flatten_table;
+----
+[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
+[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+
 ### Delete tables
 
 statement ok
@@ -2382,3 +2413,6 @@ drop table arrays_with_repeating_elements;
 
 statement ok
 drop table nested_arrays_with_repeating_elements;
+
+statement ok
+drop table flatten_table;
diff --git a/datafusion/expr/src/built_in_function.rs 
b/datafusion/expr/src/built_in_function.rs
index f239155a92..703d41cbee 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -165,6 +165,8 @@ pub enum BuiltinScalarFunction {
     Cardinality,
     /// construct an array from columns
     MakeArray,
+    /// flatten an array of arrays into a single flat array
+    Flatten,
 
     // struct functions
     /// struct
@@ -368,6 +370,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArrayReplace => Volatility::Immutable,
             BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
             BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
+            BuiltinScalarFunction::Flatten => Volatility::Immutable,
             BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
             BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
             BuiltinScalarFunction::Cardinality => Volatility::Immutable,
@@ -501,6 +504,22 @@ impl BuiltinScalarFunction {
         // the return type of the built in function.
         // Some built-in functions' return type depends on the incoming type.
         match self {
+            BuiltinScalarFunction::Flatten => {
+                fn get_base_type(data_type: &DataType) -> Result<DataType> {
+                    match data_type {
+                        DataType::List(field) => match field.data_type() {
+                            DataType::List(_) => 
get_base_type(field.data_type()),
+                            _ => Ok(data_type.to_owned()),
+                        },
+                        _ => Err(DataFusionError::Internal(
+                            "Not reachable, data_type should be 
List".to_string(),
+                        )),
+                    }
+                }
+
+                let data_type = get_base_type(&input_expr_types[0])?;
+                Ok(data_type)
+            }
             BuiltinScalarFunction::ArrayAppend => 
Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::ArrayConcat => {
                 let mut expr_type = Null;
@@ -819,11 +838,12 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArrayConcat => {
                 Signature::variadic_any(self.volatility())
             }
+            BuiltinScalarFunction::ArrayDims => Signature::any(1, 
self.volatility()),
+            BuiltinScalarFunction::ArrayElement => Signature::any(2, 
self.volatility()),
+            BuiltinScalarFunction::Flatten => Signature::any(1, 
self.volatility()),
             BuiltinScalarFunction::ArrayHasAll
             | BuiltinScalarFunction::ArrayHasAny
             | BuiltinScalarFunction::ArrayHas => Signature::any(2, 
self.volatility()),
-            BuiltinScalarFunction::ArrayDims => Signature::any(1, 
self.volatility()),
-            BuiltinScalarFunction::ArrayElement => Signature::any(2, 
self.volatility()),
             BuiltinScalarFunction::ArrayLength => {
                 Signature::variadic_any(self.volatility())
             }
@@ -1307,6 +1327,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static 
[&'static str] {
             "list_element",
             "list_extract",
         ],
+        BuiltinScalarFunction::Flatten => &["flatten"],
         BuiltinScalarFunction::ArrayHasAll => &["array_has_all", 
"list_has_all"],
         BuiltinScalarFunction::ArrayHasAny => &["array_has_any", 
"list_has_any"],
         BuiltinScalarFunction::ArrayHas => {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index ef6ce81711..47767c23b3 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -564,6 +564,12 @@ scalar_expr!(
     first_array second_array,
 "Returns true if at least one element of the second array appears in the first 
array; otherwise, it returns false."
 );
+scalar_expr!(
+    Flatten,
+    flatten,
+    array,
+    "flattens an array of arrays into a single array."
+);
 scalar_expr!(
     ArrayDims,
     array_dims,
diff --git a/datafusion/expr/src/expr_schema.rs 
b/datafusion/expr/src/expr_schema.rs
index 1d26485b4e..d7bc86158b 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -92,6 +92,7 @@ impl ExprSchemable for Expr {
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
+
                 fun.return_type(&data_types)
             }
             Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
diff --git a/datafusion/physical-expr/src/array_expressions.rs 
b/datafusion/physical-expr/src/array_expressions.rs
index fcd9adf19d..819b389cff 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -1789,6 +1789,53 @@ pub fn cardinality(args: &[ArrayRef]) -> 
Result<ArrayRef> {
     Ok(Arc::new(result) as ArrayRef)
 }
 
+// New offsets equivalent to flattening one level: out[k] = offsets[indexes[k]].
+fn get_offsets_for_flatten(
+    offsets: OffsetBuffer<i32>,
+    indexes: OffsetBuffer<i32>,
+) -> OffsetBuffer<i32> {
+    let buffer = offsets.into_inner();
+    let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as 
usize]).collect();
+    OffsetBuffer::new(offsets.into())
+}
+
+fn flatten_internal(
+    array: &dyn Array,
+    indexes: Option<OffsetBuffer<i32>>,
+) -> Result<ListArray> {
+    let list_arr = as_list_array(array)?;
+    let (field, offsets, values, nulls) = list_arr.clone().into_parts();
+    let data_type = field.data_type();
+
+    match data_type {
+        // Child is still a list: carry the composed offsets down one level
+        DataType::List(_) => {
+            if let Some(indexes) = indexes {
+                let offsets = get_offsets_for_flatten(offsets, indexes);
+                flatten_internal(&values, Some(offsets))
+            } else {
+                flatten_internal(&values, Some(offsets))
+            }
+        }
+        // Innermost list reached: rebuild it using the accumulated offsets
+        _ => {
+            if let Some(indexes) = indexes {
+                let offsets = get_offsets_for_flatten(offsets, indexes);
+                let list_arr = ListArray::new(field, offsets, values, nulls);
+                Ok(list_arr)
+            } else {
+                Ok(list_arr.clone())
+            }
+        }
+    }
+}
+
+/// `flatten` SQL function: flattens the nested list array in `args[0]`
+pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let flattened_array = flatten_internal(&args[0], None)?;
+    Ok(Arc::new(flattened_array) as ArrayRef)
+}
+
 /// Array_length SQL function
 pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
     let list_array = as_list_array(&args[0])?;
diff --git a/datafusion/physical-expr/src/functions.rs 
b/datafusion/physical-expr/src/functions.rs
index df76d55bfc..d1a5119ee8 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -437,6 +437,10 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::ArrayLength => {
             Arc::new(|args| 
make_scalar_function(array_expressions::array_length)(args))
         }
+        BuiltinScalarFunction::Flatten => {
+            Arc::new(|args| 
make_scalar_function(array_expressions::flatten)(args))
+        }
+
         BuiltinScalarFunction::ArrayNdims => {
             Arc::new(|args| 
make_scalar_function(array_expressions::array_ndims)(args))
         }
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 1081fca2e1..1a8ad093b9 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -594,6 +594,7 @@ enum ScalarFunction {
   ArrayRemoveAll = 109;
   ArrayReplaceAll = 110;
   Nanvl = 111;
+  Flatten = 112;
 }
 
 message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 8691487c72..a33d80be9d 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -18944,6 +18944,7 @@ impl serde::Serialize for ScalarFunction {
             Self::ArrayRemoveAll => "ArrayRemoveAll",
             Self::ArrayReplaceAll => "ArrayReplaceAll",
             Self::Nanvl => "Nanvl",
+            Self::Flatten => "Flatten",
         };
         serializer.serialize_str(variant)
     }
@@ -19067,6 +19068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "ArrayRemoveAll",
             "ArrayReplaceAll",
             "Nanvl",
+            "Flatten",
         ];
 
         struct GeneratedVisitor;
@@ -19221,6 +19223,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
                     "ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll),
                     "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll),
                     "Nanvl" => Ok(ScalarFunction::Nanvl),
+                    "Flatten" => Ok(ScalarFunction::Flatten),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 87371ba277..519cd002df 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2374,6 +2374,7 @@ pub enum ScalarFunction {
     ArrayRemoveAll = 109,
     ArrayReplaceAll = 110,
     Nanvl = 111,
+    Flatten = 112,
 }
 impl ScalarFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2494,6 +2495,7 @@ impl ScalarFunction {
             ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll",
             ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll",
             ScalarFunction::Nanvl => "Nanvl",
+            ScalarFunction::Flatten => "Flatten",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2611,6 +2613,7 @@ impl ScalarFunction {
             "ArrayRemoveAll" => Some(Self::ArrayRemoveAll),
             "ArrayReplaceAll" => Some(Self::ArrayReplaceAll),
             "Nanvl" => Some(Self::Nanvl),
+            "Flatten" => Some(Self::Flatten),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index c17d8dbd8c..d2e037aa4b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -457,6 +457,7 @@ impl From<&protobuf::ScalarFunction> for 
BuiltinScalarFunction {
             ScalarFunction::ArrayHas => Self::ArrayHas,
             ScalarFunction::ArrayDims => Self::ArrayDims,
             ScalarFunction::ArrayElement => Self::ArrayElement,
+            ScalarFunction::Flatten => Self::Flatten,
             ScalarFunction::ArrayLength => Self::ArrayLength,
             ScalarFunction::ArrayNdims => Self::ArrayNdims,
             ScalarFunction::ArrayPosition => Self::ArrayPosition,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs
index aa1132e8b1..cdb9b00803 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1456,6 +1456,7 @@ impl TryFrom<&BuiltinScalarFunction> for 
protobuf::ScalarFunction {
             BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
             BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
             BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
+            BuiltinScalarFunction::Flatten => Self::Flatten,
             BuiltinScalarFunction::ArrayLength => Self::ArrayLength,
             BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims,
             BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition,
diff --git a/docs/source/user-guide/expressions.md 
b/docs/source/user-guide/expressions.md
index a04f43fd4b..88a5a73a6d 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -188,6 +188,7 @@ Unlike to some databases the math functions in Datafusion 
works the same way as
 | array_has_any(array, sub-array)       | Returns true if any elements exist 
in both arrays `array_has_any([1,2,3], [1,4]) -> true`                          
                                                      |
 | array_dims(array)                     | Returns an array of the array's 
dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]`                      
                                                         |
 | array_element(array, index)           | Extracts the element with the index 
n from the array `array_element([1, 2, 3, 4], 3) -> 3`                          
                                                     |
+| flatten(array)                        | Converts an array of arrays to a 
flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]`            
                                                        |
 | array_length(array, dimension)        | Returns the length of the array 
dimension. `array_length([1, 2, 3, 4, 5]) -> 5`                                 
                                                         |
 | array_ndims(array)                    | Returns the number of dimensions of 
the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2`                           
                                                     |
 | array_position(array, element)        | Searches for an element in the 
array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2`      
                                                          |
diff --git a/docs/source/user-guide/sql/scalar_functions.md 
b/docs/source/user-guide/sql/scalar_functions.md
index dec120db18..9bcf2ae0b0 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -1685,6 +1685,24 @@ array_fill(element, array)
   Can be a constant, column, or function, and any combination of array 
operators.
 - **element**: Element to copy to the array.
 
+### `flatten`
+
+Converts an array of arrays to a flat array
+
+- Applies to any depth of nested arrays
+- Does not change arrays that are already flat
+
+The flattened array contains all the elements from all source arrays.
+
+#### Arguments
+
+- **array**: Array expression
+  Can be a constant, column, or function, and any combination of array 
operators.
+
+```
+flatten(array)
+```
+
 ### `array_indexof`
 
 _Alias of [array_position](#array_position)._

Reply via email to