This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 54a02470fc make `array_union`/`array_except`/`array_intersect` handle
empty/null arrays rightly (#8269)
54a02470fc is described below
commit 54a02470fc9304110a0995b5e540bc247e0a2c6e
Author: 谭巍 <[email protected]>
AuthorDate: Wed Nov 22 04:28:04 2023 +0800
make `array_union`/`array_except`/`array_intersect` handle empty/null
arrays rightly (#8269)
* make array_union handle empty/null arrays rightly
Signed-off-by: veeupup <[email protected]>
* make array_except handle empty/null arrays rightly
Signed-off-by: veeupup <[email protected]>
* make array_intersect handle empty/null arrays rightly
Signed-off-by: veeupup <[email protected]>
* fix sql_array_literal
Signed-off-by: veeupup <[email protected]>
* fix comments
---------
Signed-off-by: veeupup <[email protected]>
---
datafusion/expr/src/built_in_function.rs | 18 ++-
datafusion/physical-expr/src/array_expressions.rs | 137 +++++++++++++---------
datafusion/sql/src/expr/value.rs | 34 +++---
datafusion/sql/tests/sql_integration.rs | 22 ----
datafusion/sqllogictest/test_files/aggregate.slt | 4 +-
datafusion/sqllogictest/test_files/array.slt | 52 +++++++-
6 files changed, 164 insertions(+), 103 deletions(-)
diff --git a/datafusion/expr/src/built_in_function.rs
b/datafusion/expr/src/built_in_function.rs
index e9030ebcc0..cbf5d400ba 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -599,12 +599,24 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll =>
Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice =>
Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
- BuiltinScalarFunction::ArrayIntersect =>
Ok(input_expr_types[0].clone()),
- BuiltinScalarFunction::ArrayUnion =>
Ok(input_expr_types[0].clone()),
+ BuiltinScalarFunction::ArrayUnion |
BuiltinScalarFunction::ArrayIntersect => {
+ match (input_expr_types[0].clone(),
input_expr_types[1].clone()) {
+ (DataType::Null, dt) => Ok(dt),
+ (dt, DataType::Null) => Ok(dt),
+ (dt, _) => Ok(dt),
+ }
+ }
BuiltinScalarFunction::Range => {
Ok(List(Arc::new(Field::new("item", Int64, true))))
}
- BuiltinScalarFunction::ArrayExcept =>
Ok(input_expr_types[0].clone()),
+ BuiltinScalarFunction::ArrayExcept => {
+ match (input_expr_types[0].clone(),
input_expr_types[1].clone()) {
+ (DataType::Null, _) | (_, DataType::Null) => {
+ Ok(input_expr_types[0].clone())
+ }
+ (dt, _) => Ok(dt),
+ }
+ }
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index c0f6c67263..8968bcf2ea 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option<ArrayRef>) ->
Result<Option<Vec<Option<u64>>>>
fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
let data_type = args[0].data_type();
- if !args
- .iter()
- .all(|arg| arg.data_type().equals_datatype(data_type))
- {
+ if !args.iter().all(|arg| {
+ arg.data_type().equals_datatype(data_type)
+ || arg.data_type().equals_datatype(&DataType::Null)
+ }) {
let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
return plan_err!("{name} received incompatible types: '{types:?}'.");
}
@@ -1512,19 +1512,29 @@ pub fn array_union(args: &[ArrayRef]) ->
Result<ArrayRef> {
match (array1.data_type(), array2.data_type()) {
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
- (DataType::List(field_ref), DataType::List(_)) => {
- check_datatypes("array_union", &[array1, array2])?;
- let list1 = array1.as_list::<i32>();
- let list2 = array2.as_list::<i32>();
- let result = union_generic_lists::<i32>(list1, list2, field_ref)?;
- Ok(Arc::new(result))
+ (DataType::List(l_field_ref), DataType::List(r_field_ref)) => {
+ match (l_field_ref.data_type(), r_field_ref.data_type()) {
+ (DataType::Null, _) => Ok(array2.clone()),
+ (_, DataType::Null) => Ok(array1.clone()),
+ (_, _) => {
+ let list1 = array1.as_list::<i32>();
+ let list2 = array2.as_list::<i32>();
+ let result = union_generic_lists::<i32>(list1, list2,
l_field_ref)?;
+ Ok(Arc::new(result))
+ }
+ }
}
- (DataType::LargeList(field_ref), DataType::LargeList(_)) => {
- check_datatypes("array_union", &[array1, array2])?;
- let list1 = array1.as_list::<i64>();
- let list2 = array2.as_list::<i64>();
- let result = union_generic_lists::<i64>(list1, list2, field_ref)?;
- Ok(Arc::new(result))
+ (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref))
=> {
+ match (l_field_ref.data_type(), r_field_ref.data_type()) {
+ (DataType::Null, _) => Ok(array2.clone()),
+ (_, DataType::Null) => Ok(array1.clone()),
+ (_, _) => {
+ let list1 = array1.as_list::<i64>();
+ let list2 = array2.as_list::<i64>();
+ let result = union_generic_lists::<i64>(list1, list2,
l_field_ref)?;
+ Ok(Arc::new(result))
+ }
+ }
}
_ => {
internal_err!(
@@ -1919,55 +1929,66 @@ pub fn string_to_array<T: OffsetSizeTrait>(args:
&[ArrayRef]) -> Result<ArrayRef
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 2);
- let first_array = as_list_array(&args[0])?;
- let second_array = as_list_array(&args[1])?;
+ let first_array = &args[0];
+ let second_array = &args[1];
- if first_array.value_type() != second_array.value_type() {
- return internal_err!("array_intersect is not implemented for
'{first_array:?}' and '{second_array:?}'");
- }
- let dt = first_array.value_type();
+ match (first_array.data_type(), second_array.data_type()) {
+ (DataType::Null, _) => Ok(second_array.clone()),
+ (_, DataType::Null) => Ok(first_array.clone()),
+ _ => {
+ let first_array = as_list_array(&first_array)?;
+ let second_array = as_list_array(&second_array)?;
- let mut offsets = vec![0];
- let mut new_arrays = vec![];
-
- let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
- for (first_arr, second_arr) in first_array.iter().zip(second_array.iter())
{
- if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
- let l_values = converter.convert_columns(&[first_arr])?;
- let r_values = converter.convert_columns(&[second_arr])?;
-
- let values_set: HashSet<_> = l_values.iter().collect();
- let mut rows = Vec::with_capacity(r_values.num_rows());
- for r_val in r_values.iter().sorted().dedup() {
- if values_set.contains(&r_val) {
- rows.push(r_val);
- }
+ if first_array.value_type() != second_array.value_type() {
+ return internal_err!("array_intersect is not implemented for
'{first_array:?}' and '{second_array:?}'");
}
- let last_offset: i32 = match offsets.last().copied() {
- Some(offset) => offset,
- None => return internal_err!("offsets should not be empty"),
- };
- offsets.push(last_offset + rows.len() as i32);
- let arrays = converter.convert_rows(rows)?;
- let array = match arrays.get(0) {
- Some(array) => array.clone(),
- None => {
- return internal_err!(
- "array_intersect: failed to get array from rows"
- )
+ let dt = first_array.value_type();
+
+ let mut offsets = vec![0];
+ let mut new_arrays = vec![];
+
+ let converter =
RowConverter::new(vec![SortField::new(dt.clone())])?;
+ for (first_arr, second_arr) in
first_array.iter().zip(second_array.iter()) {
+ if let (Some(first_arr), Some(second_arr)) = (first_arr,
second_arr) {
+ let l_values = converter.convert_columns(&[first_arr])?;
+ let r_values = converter.convert_columns(&[second_arr])?;
+
+ let values_set: HashSet<_> = l_values.iter().collect();
+ let mut rows = Vec::with_capacity(r_values.num_rows());
+ for r_val in r_values.iter().sorted().dedup() {
+ if values_set.contains(&r_val) {
+ rows.push(r_val);
+ }
+ }
+
+ let last_offset: i32 = match offsets.last().copied() {
+ Some(offset) => offset,
+ None => return internal_err!("offsets should not be
empty"),
+ };
+ offsets.push(last_offset + rows.len() as i32);
+ let arrays = converter.convert_rows(rows)?;
+ let array = match arrays.get(0) {
+ Some(array) => array.clone(),
+ None => {
+ return internal_err!(
+ "array_intersect: failed to get array from
rows"
+ )
+ }
+ };
+ new_arrays.push(array);
}
- };
- new_arrays.push(array);
+ }
+
+ let field = Arc::new(Field::new("item", dt, true));
+ let offsets = OffsetBuffer::new(offsets.into());
+ let new_arrays_ref =
+ new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
+ let values = compute::concat(&new_arrays_ref)?;
+ let arr = Arc::new(ListArray::try_new(field, offsets, values,
None)?);
+ Ok(arr)
}
}
-
- let field = Arc::new(Field::new("item", dt, true));
- let offsets = OffsetBuffer::new(offsets.into());
- let new_arrays_ref = new_arrays.iter().map(|v|
v.as_ref()).collect::<Vec<_>>();
- let values = compute::concat(&new_arrays_ref)?;
- let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
- Ok(arr)
}
#[cfg(test)]
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index 3a06fdb158..0f086bca68 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -16,20 +16,20 @@
// under the License.
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
-use arrow::array::new_null_array;
use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano;
use arrow::datatypes::DECIMAL128_MAX_PRECISION;
use arrow_schema::DataType;
use datafusion_common::{
not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue,
};
+use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{BinaryExpr, Placeholder};
+use datafusion_expr::BuiltinScalarFunction;
use datafusion_expr::{lit, Expr, Operator};
use log::debug;
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
use sqlparser::parser::ParserError::ParserError;
use std::borrow::Cow;
-use std::collections::HashSet;
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn parse_value(
@@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
&mut PlannerContext::new(),
)?;
+
match value {
- Expr::Literal(scalar) => {
- values.push(scalar);
+ Expr::Literal(_) => {
+ values.push(value);
+ }
+ Expr::ScalarFunction(ref scalar_function) => {
+ if scalar_function.fun == BuiltinScalarFunction::MakeArray
{
+ values.push(value);
+ } else {
+ return not_impl_err!(
+ "ScalarFunctions without MakeArray are not
supported: {value}"
+ );
+ }
}
_ => {
return not_impl_err!(
@@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
- let data_types: HashSet<DataType> =
- values.iter().map(|e| e.data_type()).collect();
-
- if data_types.is_empty() {
- Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0))))
- } else if data_types.len() > 1 {
- not_impl_err!("Arrays with different types are not supported:
{data_types:?}")
- } else {
- let data_type = values[0].data_type();
- let arr = ScalarValue::new_list(&values, &data_type);
- Ok(lit(ScalarValue::List(arr)))
- }
+ Ok(Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::MakeArray,
+ values,
+ )))
}
/// Convert a SQL interval expression to a DataFusion logical plan
diff --git a/datafusion/sql/tests/sql_integration.rs
b/datafusion/sql/tests/sql_integration.rs
index 4c2bad1c71..a56e9a50f0 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -1383,18 +1383,6 @@ fn select_interval_out_of_range() {
);
}
-#[test]
-fn select_array_no_common_type() {
- let sql = "SELECT [1, true, null]";
- let err = logical_plan(sql).expect_err("query should have failed");
-
- // HashSet doesn't guarantee order
- assert_contains!(
- err.strip_backtrace(),
- "This feature is not implemented: Arrays with different types are not
supported: "
- );
-}
-
#[test]
fn recursive_ctes() {
let sql = "
@@ -1411,16 +1399,6 @@ fn recursive_ctes() {
);
}
-#[test]
-fn select_array_non_literal_type() {
- let sql = "SELECT [now()]";
- let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- "This feature is not implemented: Arrays with elements other than
literal are not supported: now()",
- err.strip_backtrace()
- );
-}
-
#[test]
fn
select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
quick_test(
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index faad6feb3f..7157be9489 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1396,7 +1396,7 @@ SELECT COUNT(DISTINCT c1) FROM test
query ?
SELECT ARRAY_AGG([])
----
-[]
+[[]]
# array_agg_one
query ?
@@ -1419,7 +1419,7 @@ e 4
query ?
SELECT ARRAY_AGG([]);
----
-[]
+[[]]
# array_agg_one
query ?
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 61f190e7ba..d33555509e 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -265,6 +265,14 @@ AS VALUES
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32,
33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]),
[28, 29, 30], [37, 38, 39], 10)
;
+query ?
+select [1, true, null]
+----
+[1, 1, ]
+
+query error DataFusion error: This feature is not implemented: ScalarFunctions
without MakeArray are not supported: now()
+SELECT [now()]
+
query TTT
select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3)
from arrays;
----
@@ -2014,7 +2022,7 @@ drop table arrays_with_repeating_elements_for_union;
query ?
select array_union([], []);
----
-NULL
+[]
# array_union scalar function #7
query ?
@@ -2032,7 +2040,7 @@ select array_union([null], [null]);
query ?
select array_union(null, []);
----
-NULL
+[]
# array_union scalar function #10
query ?
@@ -2687,6 +2695,26 @@ SELECT array_intersect(make_array(1,2,3),
make_array(2,3,4)),
----
[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]]
+query ?
+select array_intersect([], []);
+----
+[]
+
+query ?
+select array_intersect([], null);
+----
+[]
+
+query ?
+select array_intersect(null, []);
+----
+[]
+
+query ?
+select array_intersect(null, null);
+----
+NULL
+
query ??????
SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)),
list_intersect(make_array(1,3,5), make_array(2,4,6)),
@@ -2842,6 +2870,26 @@ NULL
statement ok
drop table array_except_table_bool;
+query ?
+select array_except([], null);
+----
+[]
+
+query ?
+select array_except([], []);
+----
+[]
+
+query ?
+select array_except(null, []);
+----
+NULL
+
+query ?
+select array_except(null, null)
+----
+NULL
+
### Array operators tests