This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 54a02470fc make `array_union`/`array_except`/`array_intersect` handle 
empty/null arrays rightly (#8269)
54a02470fc is described below

commit 54a02470fc9304110a0995b5e540bc247e0a2c6e
Author: 谭巍 <[email protected]>
AuthorDate: Wed Nov 22 04:28:04 2023 +0800

    make `array_union`/`array_except`/`array_intersect` handle empty/null 
arrays rightly (#8269)
    
    * make array_union handle empty/null arrays rightly
    
    Signed-off-by: veeupup <[email protected]>
    
    * make array_except handle empty/null arrays rightly
    
    Signed-off-by: veeupup <[email protected]>
    
    * make array_intersect handle empty/null arrays rightly
    
    Signed-off-by: veeupup <[email protected]>
    
    * fix sql_array_literal
    
    Signed-off-by: veeupup <[email protected]>
    
    * fix comments
    
    ---------
    
    Signed-off-by: veeupup <[email protected]>
---
 datafusion/expr/src/built_in_function.rs          |  18 ++-
 datafusion/physical-expr/src/array_expressions.rs | 137 +++++++++++++---------
 datafusion/sql/src/expr/value.rs                  |  34 +++---
 datafusion/sql/tests/sql_integration.rs           |  22 ----
 datafusion/sqllogictest/test_files/aggregate.slt  |   4 +-
 datafusion/sqllogictest/test_files/array.slt      |  52 +++++++-
 6 files changed, 164 insertions(+), 103 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs 
b/datafusion/expr/src/built_in_function.rs
index e9030ebcc0..cbf5d400ba 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -599,12 +599,24 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArrayReplaceAll => 
Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::ArraySlice => 
Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::ArrayToString => Ok(Utf8),
-            BuiltinScalarFunction::ArrayIntersect => 
Ok(input_expr_types[0].clone()),
-            BuiltinScalarFunction::ArrayUnion => 
Ok(input_expr_types[0].clone()),
+            BuiltinScalarFunction::ArrayUnion | 
BuiltinScalarFunction::ArrayIntersect => {
+                match (input_expr_types[0].clone(), 
input_expr_types[1].clone()) {
+                    (DataType::Null, dt) => Ok(dt),
+                    (dt, DataType::Null) => Ok(dt),
+                    (dt, _) => Ok(dt),
+                }
+            }
             BuiltinScalarFunction::Range => {
                 Ok(List(Arc::new(Field::new("item", Int64, true))))
             }
-            BuiltinScalarFunction::ArrayExcept => 
Ok(input_expr_types[0].clone()),
+            BuiltinScalarFunction::ArrayExcept => {
+                match (input_expr_types[0].clone(), 
input_expr_types[1].clone()) {
+                    (DataType::Null, _) | (_, DataType::Null) => {
+                        Ok(input_expr_types[0].clone())
+                    }
+                    (dt, _) => Ok(dt),
+                }
+            }
             BuiltinScalarFunction::Cardinality => Ok(UInt64),
             BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
                 0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
diff --git a/datafusion/physical-expr/src/array_expressions.rs 
b/datafusion/physical-expr/src/array_expressions.rs
index c0f6c67263..8968bcf2ea 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option<ArrayRef>) -> 
Result<Option<Vec<Option<u64>>>>
 
 fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
     let data_type = args[0].data_type();
-    if !args
-        .iter()
-        .all(|arg| arg.data_type().equals_datatype(data_type))
-    {
+    if !args.iter().all(|arg| {
+        arg.data_type().equals_datatype(data_type)
+            || arg.data_type().equals_datatype(&DataType::Null)
+    }) {
         let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
         return plan_err!("{name} received incompatible types: '{types:?}'.");
     }
@@ -1512,19 +1512,29 @@ pub fn array_union(args: &[ArrayRef]) -> 
Result<ArrayRef> {
     match (array1.data_type(), array2.data_type()) {
         (DataType::Null, _) => Ok(array2.clone()),
         (_, DataType::Null) => Ok(array1.clone()),
-        (DataType::List(field_ref), DataType::List(_)) => {
-            check_datatypes("array_union", &[array1, array2])?;
-            let list1 = array1.as_list::<i32>();
-            let list2 = array2.as_list::<i32>();
-            let result = union_generic_lists::<i32>(list1, list2, field_ref)?;
-            Ok(Arc::new(result))
+        (DataType::List(l_field_ref), DataType::List(r_field_ref)) => {
+            match (l_field_ref.data_type(), r_field_ref.data_type()) {
+                (DataType::Null, _) => Ok(array2.clone()),
+                (_, DataType::Null) => Ok(array1.clone()),
+                (_, _) => {
+                    let list1 = array1.as_list::<i32>();
+                    let list2 = array2.as_list::<i32>();
+                    let result = union_generic_lists::<i32>(list1, list2, 
l_field_ref)?;
+                    Ok(Arc::new(result))
+                }
+            }
         }
-        (DataType::LargeList(field_ref), DataType::LargeList(_)) => {
-            check_datatypes("array_union", &[array1, array2])?;
-            let list1 = array1.as_list::<i64>();
-            let list2 = array2.as_list::<i64>();
-            let result = union_generic_lists::<i64>(list1, list2, field_ref)?;
-            Ok(Arc::new(result))
+        (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) 
=> {
+            match (l_field_ref.data_type(), r_field_ref.data_type()) {
+                (DataType::Null, _) => Ok(array2.clone()),
+                (_, DataType::Null) => Ok(array1.clone()),
+                (_, _) => {
+                    let list1 = array1.as_list::<i64>();
+                    let list2 = array2.as_list::<i64>();
+                    let result = union_generic_lists::<i64>(list1, list2, 
l_field_ref)?;
+                    Ok(Arc::new(result))
+                }
+            }
         }
         _ => {
             internal_err!(
@@ -1919,55 +1929,66 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: 
&[ArrayRef]) -> Result<ArrayRef
 pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
     assert_eq!(args.len(), 2);
 
-    let first_array = as_list_array(&args[0])?;
-    let second_array = as_list_array(&args[1])?;
+    let first_array = &args[0];
+    let second_array = &args[1];
 
-    if first_array.value_type() != second_array.value_type() {
-        return internal_err!("array_intersect is not implemented for 
'{first_array:?}' and '{second_array:?}'");
-    }
-    let dt = first_array.value_type();
+    match (first_array.data_type(), second_array.data_type()) {
+        (DataType::Null, _) => Ok(second_array.clone()),
+        (_, DataType::Null) => Ok(first_array.clone()),
+        _ => {
+            let first_array = as_list_array(&first_array)?;
+            let second_array = as_list_array(&second_array)?;
 
-    let mut offsets = vec![0];
-    let mut new_arrays = vec![];
-
-    let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
-    for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) 
{
-        if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
-            let l_values = converter.convert_columns(&[first_arr])?;
-            let r_values = converter.convert_columns(&[second_arr])?;
-
-            let values_set: HashSet<_> = l_values.iter().collect();
-            let mut rows = Vec::with_capacity(r_values.num_rows());
-            for r_val in r_values.iter().sorted().dedup() {
-                if values_set.contains(&r_val) {
-                    rows.push(r_val);
-                }
+            if first_array.value_type() != second_array.value_type() {
+                return internal_err!("array_intersect is not implemented for 
'{first_array:?}' and '{second_array:?}'");
             }
 
-            let last_offset: i32 = match offsets.last().copied() {
-                Some(offset) => offset,
-                None => return internal_err!("offsets should not be empty"),
-            };
-            offsets.push(last_offset + rows.len() as i32);
-            let arrays = converter.convert_rows(rows)?;
-            let array = match arrays.get(0) {
-                Some(array) => array.clone(),
-                None => {
-                    return internal_err!(
-                        "array_intersect: failed to get array from rows"
-                    )
+            let dt = first_array.value_type();
+
+            let mut offsets = vec![0];
+            let mut new_arrays = vec![];
+
+            let converter = 
RowConverter::new(vec![SortField::new(dt.clone())])?;
+            for (first_arr, second_arr) in 
first_array.iter().zip(second_array.iter()) {
+                if let (Some(first_arr), Some(second_arr)) = (first_arr, 
second_arr) {
+                    let l_values = converter.convert_columns(&[first_arr])?;
+                    let r_values = converter.convert_columns(&[second_arr])?;
+
+                    let values_set: HashSet<_> = l_values.iter().collect();
+                    let mut rows = Vec::with_capacity(r_values.num_rows());
+                    for r_val in r_values.iter().sorted().dedup() {
+                        if values_set.contains(&r_val) {
+                            rows.push(r_val);
+                        }
+                    }
+
+                    let last_offset: i32 = match offsets.last().copied() {
+                        Some(offset) => offset,
+                        None => return internal_err!("offsets should not be 
empty"),
+                    };
+                    offsets.push(last_offset + rows.len() as i32);
+                    let arrays = converter.convert_rows(rows)?;
+                    let array = match arrays.get(0) {
+                        Some(array) => array.clone(),
+                        None => {
+                            return internal_err!(
+                                "array_intersect: failed to get array from 
rows"
+                            )
+                        }
+                    };
+                    new_arrays.push(array);
                 }
-            };
-            new_arrays.push(array);
+            }
+
+            let field = Arc::new(Field::new("item", dt, true));
+            let offsets = OffsetBuffer::new(offsets.into());
+            let new_arrays_ref =
+                new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
+            let values = compute::concat(&new_arrays_ref)?;
+            let arr = Arc::new(ListArray::try_new(field, offsets, values, 
None)?);
+            Ok(arr)
         }
     }
-
-    let field = Arc::new(Field::new("item", dt, true));
-    let offsets = OffsetBuffer::new(offsets.into());
-    let new_arrays_ref = new_arrays.iter().map(|v| 
v.as_ref()).collect::<Vec<_>>();
-    let values = compute::concat(&new_arrays_ref)?;
-    let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
-    Ok(arr)
 }
 
 #[cfg(test)]
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index 3a06fdb158..0f086bca68 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -16,20 +16,20 @@
 // under the License.
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
-use arrow::array::new_null_array;
 use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano;
 use arrow::datatypes::DECIMAL128_MAX_PRECISION;
 use arrow_schema::DataType;
 use datafusion_common::{
     not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue,
 };
+use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::expr::{BinaryExpr, Placeholder};
+use datafusion_expr::BuiltinScalarFunction;
 use datafusion_expr::{lit, Expr, Operator};
 use log::debug;
 use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
 use sqlparser::parser::ParserError::ParserError;
 use std::borrow::Cow;
-use std::collections::HashSet;
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     pub(crate) fn parse_value(
@@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 schema,
                 &mut PlannerContext::new(),
             )?;
+
             match value {
-                Expr::Literal(scalar) => {
-                    values.push(scalar);
+                Expr::Literal(_) => {
+                    values.push(value);
+                }
+                Expr::ScalarFunction(ref scalar_function) => {
+                    if scalar_function.fun == BuiltinScalarFunction::MakeArray 
{
+                        values.push(value);
+                    } else {
+                        return not_impl_err!(
+                            "ScalarFunctions without MakeArray are not 
supported: {value}"
+                        );
+                    }
                 }
                 _ => {
                     return not_impl_err!(
@@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
         }
 
-        let data_types: HashSet<DataType> =
-            values.iter().map(|e| e.data_type()).collect();
-
-        if data_types.is_empty() {
-            Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0))))
-        } else if data_types.len() > 1 {
-            not_impl_err!("Arrays with different types are not supported: 
{data_types:?}")
-        } else {
-            let data_type = values[0].data_type();
-            let arr = ScalarValue::new_list(&values, &data_type);
-            Ok(lit(ScalarValue::List(arr)))
-        }
+        Ok(Expr::ScalarFunction(ScalarFunction::new(
+            BuiltinScalarFunction::MakeArray,
+            values,
+        )))
     }
 
     /// Convert a SQL interval expression to a DataFusion logical plan
diff --git a/datafusion/sql/tests/sql_integration.rs 
b/datafusion/sql/tests/sql_integration.rs
index 4c2bad1c71..a56e9a50f0 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -1383,18 +1383,6 @@ fn select_interval_out_of_range() {
     );
 }
 
-#[test]
-fn select_array_no_common_type() {
-    let sql = "SELECT [1, true, null]";
-    let err = logical_plan(sql).expect_err("query should have failed");
-
-    // HashSet doesn't guarantee order
-    assert_contains!(
-        err.strip_backtrace(),
-        "This feature is not implemented: Arrays with different types are not 
supported: "
-    );
-}
-
 #[test]
 fn recursive_ctes() {
     let sql = "
@@ -1411,16 +1399,6 @@ fn recursive_ctes() {
     );
 }
 
-#[test]
-fn select_array_non_literal_type() {
-    let sql = "SELECT [now()]";
-    let err = logical_plan(sql).expect_err("query should have failed");
-    assert_eq!(
-        "This feature is not implemented: Arrays with elements other than 
literal are not supported: now()",
-        err.strip_backtrace()
-    );
-}
-
 #[test]
 fn 
select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
     quick_test(
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt 
b/datafusion/sqllogictest/test_files/aggregate.slt
index faad6feb3f..7157be9489 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1396,7 +1396,7 @@ SELECT COUNT(DISTINCT c1) FROM test
 query ?
 SELECT ARRAY_AGG([])
 ----
-[]
+[[]]
 
 # array_agg_one
 query ?
@@ -1419,7 +1419,7 @@ e 4
 query ?
 SELECT ARRAY_AGG([]);
 ----
-[]
+[[]]
 
 # array_agg_one
 query ?
diff --git a/datafusion/sqllogictest/test_files/array.slt 
b/datafusion/sqllogictest/test_files/array.slt
index 61f190e7ba..d33555509e 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -265,6 +265,14 @@ AS VALUES
   (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 
33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), 
[28, 29, 30], [37, 38, 39], 10)
 ;
 
+query ?
+select [1, true, null]
+----
+[1, 1, ]
+
+query error DataFusion error: This feature is not implemented: ScalarFunctions 
without MakeArray are not supported: now()
+SELECT [now()]
+
 query TTT
 select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) 
from arrays;
 ----
@@ -2014,7 +2022,7 @@ drop table arrays_with_repeating_elements_for_union;
 query ?
 select array_union([], []);
 ----
-NULL
+[]
 
 # array_union scalar function #7
 query ?
@@ -2032,7 +2040,7 @@ select array_union([null], [null]);
 query ?
 select array_union(null, []);
 ----
-NULL
+[]
 
 # array_union scalar function #10
 query ?
@@ -2687,6 +2695,26 @@ SELECT  array_intersect(make_array(1,2,3), 
make_array(2,3,4)),
 ----
 [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]]
 
+query ?
+select array_intersect([], []);
+----
+[]
+
+query ?
+select array_intersect([], null);
+----
+[]
+
+query ?
+select array_intersect(null, []);
+----
+[]
+
+query ?
+select array_intersect(null, null);
+----
+NULL
+
 query ??????
 SELECT  list_intersect(make_array(1,2,3), make_array(2,3,4)),
         list_intersect(make_array(1,3,5), make_array(2,4,6)),
@@ -2842,6 +2870,26 @@ NULL
 statement ok
 drop table array_except_table_bool;
 
+query ?
+select array_except([], null);
+----
+[]
+
+query ?
+select array_except([], []);
+----
+[]
+
+query ?
+select array_except(null, []);
+----
+NULL
+
+query ?
+select array_except(null, null)
+----
+NULL
+
 ### Array operators tests
 
 

Reply via email to