This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new cd02c40f75 Support array_distinct function. (#8268)
cd02c40f75 is described below
commit cd02c40f7575e331121a94cb217b71905e240f9f
Author: yi wang <[email protected]>
AuthorDate: Sat Dec 9 02:06:52 2023 +0800
Support array_distinct function. (#8268)
* implement distinct func
implement slt & proto
fix null & empty list
* add comment for slt
Co-authored-by: Alex Huang <[email protected]>
* fix largelist
* add largelist for slt
* Use collect for rows & init capacity for offsets.
* fixup: remove useless match
* fix fmt
* fix fmt
---------
Co-authored-by: Alex Huang <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/expr/src/built_in_function.rs | 6 ++
datafusion/expr/src/expr_fn.rs | 6 ++
datafusion/physical-expr/src/array_expressions.rs | 64 ++++++++++++++-
datafusion/physical-expr/src/functions.rs | 3 +
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/logical_plan/from_proto.rs | 22 ++---
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
datafusion/sqllogictest/test_files/array.slt | 99 +++++++++++++++++++++++
docs/source/user-guide/expressions.md | 1 +
11 files changed, 198 insertions(+), 11 deletions(-)
diff --git a/datafusion/expr/src/built_in_function.rs
b/datafusion/expr/src/built_in_function.rs
index 44fbf45525..977b556b26 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -146,6 +146,8 @@ pub enum BuiltinScalarFunction {
ArrayPopBack,
/// array_dims
ArrayDims,
+ /// array_distinct
+ ArrayDistinct,
/// array_element
ArrayElement,
/// array_empty
@@ -407,6 +409,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable,
BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
+ BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
@@ -586,6 +589,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayDims => {
Ok(List(Arc::new(Field::new("item", UInt64, true))))
}
+ BuiltinScalarFunction::ArrayDistinct =>
Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field) => Ok(field.data_type().clone()),
_ => plan_err!(
@@ -933,6 +937,7 @@ impl BuiltinScalarFunction {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayNdims => Signature::any(1,
self.volatility()),
+ BuiltinScalarFunction::ArrayDistinct => Signature::any(1,
self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::variadic_any(self.volatility())
}
@@ -1570,6 +1575,7 @@ impl BuiltinScalarFunction {
&["array_concat", "array_cat", "list_concat", "list_cat"]
}
BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"],
+ BuiltinScalarFunction::ArrayDistinct => &["array_distinct",
"list_distinct"],
BuiltinScalarFunction::ArrayEmpty => &["empty"],
BuiltinScalarFunction::ArrayElement => &[
"array_element",
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 8d25619c07..cedf1d8451 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -660,6 +660,12 @@ scalar_expr!(
array,
"returns the number of dimensions of the array."
);
+scalar_expr!(
+ ArrayDistinct,
+ array_distinct,
+ array,
+ "return distinct values from the array after removing duplicates."
+);
scalar_expr!(
ArrayPosition,
array_position,
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index 08df3ef9f6..ae04869458 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -31,8 +31,8 @@ use arrow_buffer::NullBuffer;
use arrow_schema::{FieldRef, SortOptions};
use datafusion_common::cast::{
- as_generic_list_array, as_generic_string_array, as_int64_array,
as_list_array,
- as_null_array, as_string_array,
+ as_generic_list_array, as_generic_string_array, as_int64_array,
as_large_list_array,
+ as_list_array, as_null_array, as_string_array,
};
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::{
@@ -2111,6 +2111,66 @@ pub fn array_intersect(args: &[ArrayRef]) ->
Result<ArrayRef> {
}
}
+pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
+ array: &GenericListArray<OffsetSize>,
+ field: &FieldRef,
+) -> Result<ArrayRef> {
+ let dt = array.value_type();
+ let mut offsets = Vec::with_capacity(array.len());
+ offsets.push(OffsetSize::usize_as(0));
+ let mut new_arrays = Vec::with_capacity(array.len());
+ let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
+ // distinct for each list in ListArray
+ for arr in array.iter().flatten() {
+ let values = converter.convert_columns(&[arr])?;
+ // sort elements in list and remove duplicates
+ let rows = values.iter().sorted().dedup().collect::<Vec<_>>();
+ let last_offset: OffsetSize = offsets.last().copied().unwrap();
+ offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
+ let arrays = converter.convert_rows(rows)?;
+ let array = match arrays.get(0) {
+ Some(array) => array.clone(),
+ None => {
+ return internal_err!("array_distinct: failed to get array from
rows")
+ }
+ };
+ new_arrays.push(array);
+ }
+ let offsets = OffsetBuffer::new(offsets.into());
+ let new_arrays_ref = new_arrays.iter().map(|v|
v.as_ref()).collect::<Vec<_>>();
+ let values = compute::concat(&new_arrays_ref)?;
+ Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
+ field.clone(),
+ offsets,
+ values,
+ None,
+ )?))
+}
+
+/// array_distinct SQL function
+/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
+pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
+ assert_eq!(args.len(), 1);
+
+ // handle null
+ if args[0].data_type() == &DataType::Null {
+ return Ok(args[0].clone());
+ }
+
+ // handle for list & largelist
+ match args[0].data_type() {
+ DataType::List(field) => {
+ let array = as_list_array(&args[0])?;
+ general_array_distinct(array, field)
+ }
+ DataType::LargeList(field) => {
+ let array = as_large_list_array(&args[0])?;
+ general_array_distinct(array, field)
+ }
+ _ => internal_err!("array_distinct only support list array"),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/physical-expr/src/functions.rs
b/datafusion/physical-expr/src/functions.rs
index 873864a57a..53de858439 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -350,6 +350,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayDims => {
Arc::new(|args|
make_scalar_function(array_expressions::array_dims)(args))
}
+ BuiltinScalarFunction::ArrayDistinct => {
+ Arc::new(|args|
make_scalar_function(array_expressions::array_distinct)(args))
+ }
BuiltinScalarFunction::ArrayElement => {
Arc::new(|args|
make_scalar_function(array_expressions::array_element)(args))
}
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 55fb080423..13a54f2a56 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -645,6 +645,7 @@ enum ScalarFunction {
SubstrIndex = 126;
FindInSet = 127;
ArraySort = 128;
+ ArrayDistinct = 129;
}
message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index dea329cbea..0d013c72d3 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -21049,6 +21049,7 @@ impl serde::Serialize for ScalarFunction {
Self::SubstrIndex => "SubstrIndex",
Self::FindInSet => "FindInSet",
Self::ArraySort => "ArraySort",
+ Self::ArrayDistinct => "ArrayDistinct",
};
serializer.serialize_str(variant)
}
@@ -21189,6 +21190,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"SubstrIndex",
"FindInSet",
"ArraySort",
+ "ArrayDistinct",
];
struct GeneratedVisitor;
@@ -21358,6 +21360,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"SubstrIndex" => Ok(ScalarFunction::SubstrIndex),
"FindInSet" => Ok(ScalarFunction::FindInSet),
"ArraySort" => Ok(ScalarFunction::ArraySort),
+ "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 41b94a2a39..d4b62d4b3f 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2614,6 +2614,7 @@ pub enum ScalarFunction {
SubstrIndex = 126,
FindInSet = 127,
ArraySort = 128,
+ ArrayDistinct = 129,
}
impl ScalarFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2751,6 +2752,7 @@ impl ScalarFunction {
ScalarFunction::SubstrIndex => "SubstrIndex",
ScalarFunction::FindInSet => "FindInSet",
ScalarFunction::ArraySort => "ArraySort",
+ ScalarFunction::ArrayDistinct => "ArrayDistinct",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2885,6 +2887,7 @@ impl ScalarFunction {
"SubstrIndex" => Some(Self::SubstrIndex),
"FindInSet" => Some(Self::FindInSet),
"ArraySort" => Some(Self::ArraySort),
+ "ArrayDistinct" => Some(Self::ArrayDistinct),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 7daab47837..193e0947d6 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -41,15 +41,15 @@ use datafusion_common::{
};
use datafusion_expr::window_frame::{check_window_frame,
regularize_window_order_by};
use datafusion_expr::{
- abs, acos, acosh, array, array_append, array_concat, array_dims,
array_element,
- array_except, array_has, array_has_all, array_has_any, array_intersect,
array_length,
- array_ndims, array_position, array_positions, array_prepend, array_remove,
- array_remove_all, array_remove_n, array_repeat, array_replace,
array_replace_all,
- array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof,
ascii, asin,
- asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil,
- character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh,
cot,
- current_date, current_time, date_bin, date_part, date_trunc, decode,
degrees, digest,
- encode, exp,
+ abs, acos, acosh, array, array_append, array_concat, array_dims,
array_distinct,
+ array_element, array_except, array_has, array_has_all, array_has_any,
+ array_intersect, array_length, array_ndims, array_position,
array_positions,
+ array_prepend, array_remove, array_remove_all, array_remove_n,
array_repeat,
+ array_replace, array_replace_all, array_replace_n, array_slice, array_sort,
+ array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh,
bit_length,
+ btrim, cardinality, cbrt, ceil, character_length, chr, coalesce,
concat_expr,
+ concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin,
date_part,
+ date_trunc, decode, degrees, digest, encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range,
isnan, iszero,
lcm, left, levenshtein, ln, log, log10, log2,
@@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for
BuiltinScalarFunction {
ScalarFunction::ArrayHasAny => Self::ArrayHasAny,
ScalarFunction::ArrayHas => Self::ArrayHas,
ScalarFunction::ArrayDims => Self::ArrayDims,
+ ScalarFunction::ArrayDistinct => Self::ArrayDistinct,
ScalarFunction::ArrayElement => Self::ArrayElement,
ScalarFunction::Flatten => Self::Flatten,
ScalarFunction::ArrayLength => Self::ArrayLength,
@@ -1467,6 +1468,9 @@ pub fn parse_expr(
ScalarFunction::ArrayDims => {
Ok(array_dims(parse_expr(&args[0], registry)?))
}
+ ScalarFunction::ArrayDistinct => {
+ Ok(array_distinct(parse_expr(&args[0], registry)?))
+ }
ScalarFunction::ArrayElement => Ok(array_element(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 4c6fdaa894..2997d14742 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1512,6 +1512,7 @@ impl TryFrom<&BuiltinScalarFunction> for
protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny,
BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
+ BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct,
BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
BuiltinScalarFunction::Flatten => Self::Flatten,
BuiltinScalarFunction::ArrayLength => Self::ArrayLength,
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 3c23dd369a..1202a2b1e9 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -182,6 +182,38 @@ AS VALUES
(make_array([[1], [2]], [[2], [3]]), make_array([1], [2]))
;
+statement ok
+CREATE TABLE array_distinct_table_1D
+AS VALUES
+ (make_array(1, 1, 2, 2, 3)),
+ (make_array(1, 2, 3, 4, 5)),
+ (make_array(3, 5, 3, 3, 3))
+;
+
+statement ok
+CREATE TABLE array_distinct_table_1D_UTF8
+AS VALUES
+ (make_array('a', 'a', 'bc', 'bc', 'def')),
+ (make_array('a', 'bc', 'def', 'defg', 'defg')),
+ (make_array('defg', 'defg', 'defg', 'defg', 'defg'))
+;
+
+statement ok
+CREATE TABLE array_distinct_table_2D
+AS VALUES
+ (make_array([1,2], [1,2], [3,4], [3,4], [5,6])),
+ (make_array([1,2], [3,4], [5,6], [7,8], [9,10])),
+ (make_array([5,6], [5,6], NULL))
+;
+
+statement ok
+CREATE TABLE array_distinct_table_1D_large
+AS VALUES
+ (arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')),
+ (arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')),
+ (arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)'))
+;
+
statement ok
CREATE TABLE array_intersect_table_1D
AS VALUES
@@ -2864,6 +2896,73 @@ select array_has_all(arrow_cast(make_array(1,2,3),
'LargeList(Int64)'), arrow_ca
----
true false true false false false true true false false true false true
+query BBBBBBBBBBBBB
+select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'),
arrow_cast(make_array(1,3), 'LargeList(Int64)')),
+ array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'),
arrow_cast(make_array(1,4), 'LargeList(Int64)')),
+ array_has_all(arrow_cast(make_array([1,2], [3,4]),
'LargeList(List(Int64))'), arrow_cast(make_array([1,2]),
'LargeList(List(Int64))')),
+ array_has_all(arrow_cast(make_array([1,2], [3,4]),
'LargeList(List(Int64))'), arrow_cast(make_array([1,3]),
'LargeList(List(Int64))')),
+ array_has_all(arrow_cast(make_array([1,2], [3,4]),
'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]),
'LargeList(List(Int64))')),
+ array_has_all(arrow_cast(make_array([[1,2,3]]),
'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]),
'LargeList(List(List(Int64)))')),
+ array_has_all(arrow_cast(make_array([[1,2,3]]),
'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]),
'LargeList(List(List(Int64)))')),
+ array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'),
arrow_cast(make_array(1,10,100), 'LargeList(Int64)')),
+ array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'),
arrow_cast(make_array(10,100),'LargeList(Int64)')),
+ array_has_any(arrow_cast(make_array([1,2], [3,4]),
'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]),
'LargeList(List(Int64))')),
+ array_has_any(arrow_cast(make_array([1,2], [3,4]),
'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]),
'LargeList(List(Int64))')),
+ array_has_any(arrow_cast(make_array([[1,2,3]]),
'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]),
'LargeList(List(List(Int64)))')),
+ array_has_any(arrow_cast(make_array([[1,2,3]]),
'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]),
'LargeList(List(List(Int64)))'))
+;
+----
+true false true false false false true true false false true false true
+
+## array_distinct
+
+query ?
+select array_distinct(null);
+----
+NULL
+
+query ?
+select array_distinct([]);
+----
+[]
+
+query ?
+select array_distinct([[], []]);
+----
+[[]]
+
+query ?
+select array_distinct(column1)
+from array_distinct_table_1D;
+----
+[1, 2, 3]
+[1, 2, 3, 4, 5]
+[3, 5]
+
+query ?
+select array_distinct(column1)
+from array_distinct_table_1D_UTF8;
+----
+[a, bc, def]
+[a, bc, def, defg]
+[defg]
+
+query ?
+select array_distinct(column1)
+from array_distinct_table_2D;
+----
+[[1, 2], [3, 4], [5, 6]]
+[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+[, [5, 6]]
+
+query ?
+select array_distinct(column1)
+from array_distinct_table_1D_large;
+----
+[1, 2, 3]
+[1, 2, 3, 4, 5]
+[3, 5]
+
query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
diff --git a/docs/source/user-guide/expressions.md
b/docs/source/user-guide/expressions.md
index 257c50dfa4..b8689e5567 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -215,6 +215,7 @@ Unlike to some databases the math functions in Datafusion
works the same way as
| array_has_all(array, sub-array) | Returns true if all elements of
sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true`
|
| array_has_any(array, sub-array) | Returns true if any elements exist
in both arrays `array_has_any([1,2,3], [1,4]) -> true`
|
| array_dims(array) | Returns an array of the array's
dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]`
|
+| array_distinct(array) | Returns distinct values from the
array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1,
2, 3, 4]` |
| array_element(array, index) | Extracts the element with the index
n from the array `array_element([1, 2, 3, 4], 3) -> 3`
|
| flatten(array) | Converts an array of arrays to a
flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]`
|
| array_length(array, dimension) | Returns the length of the array
dimension. `array_length([1, 2, 3, 4, 5]) -> 5`
|