This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 469f18be1c function: Allow more expressive array signatures (#14532)
469f18be1c is described below
commit 469f18be1c594b07e4b235f3404419792ed3c24f
Author: Joseph Koshakow <[email protected]>
AuthorDate: Fri Feb 14 08:20:34 2025 -0500
function: Allow more expressive array signatures (#14532)
* function: Allow more expressive array signatures
This commit allows for more expressive array function signatures.
Previously, `ArrayFunctionSignature` was an enum of potential argument
combinations and orders. For many array functions, none of the
`ArrayFunctionSignature` variants worked, so they used
`TypeSignature::VariadicAny` instead. This commit will allow those
functions to use more descriptive signatures which will prevent them
from having to perform manual type checking in the function
implementation.
As an example, this commit also updates the signature of the
`array_replace` family of functions to use a new expressive signature,
which removes a panic that existed previously.
There are still a couple of limitations with this approach. First of
all, there's no way to describe a function that takes multiple arrays
of different types or dimensions. Additionally, the map-array and
recursive-array signatures still only support functions that take a
single argument.
Works towards resolving #14451
* Add mutability
* Move mutability enum
* fmt
* Fix doctest
* Add validation to array args
* Remove mutability and update return types
* fmt
* Fix clippy
* Fix imports
* Add list coercion flag
* Some formatting fixes
* Some formatting fixes
* Remove ArrayFunctionArguments struct
* Simplify helper functions
* Update array_and_element behavior
---
datafusion/common/src/utils/mod.rs | 44 +++++--
datafusion/expr-common/src/signature.rs | 135 +++++++++++--------
datafusion/expr/src/lib.rs | 4 +-
datafusion/expr/src/type_coercion/functions.rs | 173 +++++++++++--------------
datafusion/functions-nested/src/cardinality.rs | 9 +-
datafusion/functions-nested/src/concat.rs | 17 ++-
datafusion/functions-nested/src/extract.rs | 53 ++++++--
datafusion/functions-nested/src/replace.rs | 47 ++++++-
datafusion/sqllogictest/test_files/array.slt | 56 ++++++++
9 files changed, 352 insertions(+), 186 deletions(-)
diff --git a/datafusion/common/src/utils/mod.rs
b/datafusion/common/src/utils/mod.rs
index cb77cc8e79..ff9cdedab8 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -590,6 +590,13 @@ pub fn base_type(data_type: &DataType) -> DataType {
}
}
+/// Information about how to coerce lists.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub enum ListCoercion {
+ /// [`DataType::FixedSizeList`] should be coerced to [`DataType::List`].
+ FixedSizedListToList,
+}
+
/// A helper function to coerce base type in List.
///
/// Example
@@ -600,16 +607,22 @@ pub fn base_type(data_type: &DataType) -> DataType {
///
/// let data_type =
DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
/// let base_type = DataType::Float64;
-/// let coerced_type = coerced_type_with_base_type_only(&data_type,
&base_type);
+/// let coerced_type = coerced_type_with_base_type_only(&data_type,
&base_type, None);
/// assert_eq!(coerced_type,
DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))));
pub fn coerced_type_with_base_type_only(
data_type: &DataType,
base_type: &DataType,
+ array_coercion: Option<&ListCoercion>,
) -> DataType {
- match data_type {
- DataType::List(field) | DataType::FixedSizeList(field, _) => {
- let field_type =
- coerced_type_with_base_type_only(field.data_type(), base_type);
+ match (data_type, array_coercion) {
+ (DataType::List(field), _)
+ | (DataType::FixedSizeList(field, _),
Some(ListCoercion::FixedSizedListToList)) =>
+ {
+ let field_type = coerced_type_with_base_type_only(
+ field.data_type(),
+ base_type,
+ array_coercion,
+ );
DataType::List(Arc::new(Field::new(
field.name(),
@@ -617,9 +630,24 @@ pub fn coerced_type_with_base_type_only(
field.is_nullable(),
)))
}
- DataType::LargeList(field) => {
- let field_type =
- coerced_type_with_base_type_only(field.data_type(), base_type);
+ (DataType::FixedSizeList(field, len), _) => {
+ let field_type = coerced_type_with_base_type_only(
+ field.data_type(),
+ base_type,
+ array_coercion,
+ );
+
+ DataType::FixedSizeList(
+ Arc::new(Field::new(field.name(), field_type,
field.is_nullable())),
+ *len,
+ )
+ }
+ (DataType::LargeList(field), _) => {
+ let field_type = coerced_type_with_base_type_only(
+ field.data_type(),
+ base_type,
+ array_coercion,
+ );
DataType::LargeList(Arc::new(Field::new(
field.name(),
diff --git a/datafusion/expr-common/src/signature.rs
b/datafusion/expr-common/src/signature.rs
index 1bfae28af8..4ca4961d7b 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -19,11 +19,11 @@
//! and return types of functions in DataFusion.
use std::fmt::Display;
-use std::num::NonZeroUsize;
use crate::type_coercion::aggregates::NUMERICS;
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datafusion_common::types::{LogicalTypeRef, NativeType};
+use datafusion_common::utils::ListCoercion;
use itertools::Itertools;
/// Constant that is used as a placeholder for any valid timezone.
@@ -227,25 +227,13 @@ impl Display for TypeSignatureClass {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ArrayFunctionSignature {
- /// Specialized Signature for ArrayAppend and similar functions
- /// The first argument should be List/LargeList/FixedSizedList, and the
second argument should be non-list or list.
- /// The second argument's list dimension should be one dimension less than
the first argument's list dimension.
- /// List dimension of the List/LargeList is equivalent to the number of
List.
- /// List dimension of the non-list is 0.
- ArrayAndElement,
- /// Specialized Signature for ArrayPrepend and similar functions
- /// The first argument should be non-list or list, and the second argument
should be List/LargeList.
- /// The first argument's list dimension should be one dimension less than
the second argument's list dimension.
- ElementAndArray,
- /// Specialized Signature for Array functions of the form (List/LargeList,
Index+)
- /// The first argument should be List/LargeList/FixedSizedList, and the
next n arguments should be Int64.
- ArrayAndIndexes(NonZeroUsize),
- /// Specialized Signature for Array functions of the form (List/LargeList,
Element, Optional Index)
- ArrayAndElementAndOptionalIndex,
- /// Specialized Signature for ArrayEmpty and similar functions
- /// The function takes a single argument that must be a
List/LargeList/FixedSizeList
- /// or something that can be coerced to one of those types.
- Array,
+ /// A function takes at least one List/LargeList/FixedSizeList argument.
+ Array {
+ /// A full list of the arguments accepted by this function.
+ arguments: Vec<ArrayFunctionArgument>,
+ /// Additional information about how array arguments should be coerced.
+ array_coercion: Option<ListCoercion>,
+ },
/// A function takes a single argument that must be a
List/LargeList/FixedSizeList
/// which gets coerced to List, with element type recursively coerced to
List too if it is list-like.
RecursiveArray,
@@ -257,25 +245,15 @@ pub enum ArrayFunctionSignature {
impl Display for ArrayFunctionSignature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
- ArrayFunctionSignature::ArrayAndElement => {
- write!(f, "array, element")
- }
- ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
- write!(f, "array, element, [index]")
- }
- ArrayFunctionSignature::ElementAndArray => {
- write!(f, "element, array")
- }
- ArrayFunctionSignature::ArrayAndIndexes(count) => {
- write!(f, "array")?;
- for _ in 0..count.get() {
- write!(f, ", index")?;
+ ArrayFunctionSignature::Array { arguments, .. } => {
+ for (idx, argument) in arguments.iter().enumerate() {
+ write!(f, "{argument}")?;
+ if idx != arguments.len() - 1 {
+ write!(f, ", ")?;
+ }
}
Ok(())
}
- ArrayFunctionSignature::Array => {
- write!(f, "array")
- }
ArrayFunctionSignature::RecursiveArray => {
write!(f, "recursive_array")
}
@@ -286,6 +264,34 @@ impl Display for ArrayFunctionSignature {
}
}
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub enum ArrayFunctionArgument {
+ /// A non-list or list argument. The list dimensions should be one less
than the Array's list
+ /// dimensions.
+ Element,
+ /// An Int64 index argument.
+ Index,
+ /// An argument of type List/LargeList/FixedSizeList. All Array arguments
must be coercible
+ /// to the same type.
+ Array,
+}
+
+impl Display for ArrayFunctionArgument {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ArrayFunctionArgument::Element => {
+ write!(f, "element")
+ }
+ ArrayFunctionArgument::Index => {
+ write!(f, "index")
+ }
+ ArrayFunctionArgument::Array => {
+ write!(f, "array")
+ }
+ }
+ }
+}
+
impl TypeSignature {
pub fn to_string_repr(&self) -> Vec<String> {
match self {
@@ -580,7 +586,13 @@ impl Signature {
pub fn array_and_element(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
- ArrayFunctionSignature::ArrayAndElement,
+ ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Element,
+ ],
+ array_coercion: Some(ListCoercion::FixedSizedListToList),
+ },
),
volatility,
}
@@ -588,30 +600,38 @@ impl Signature {
/// Specialized Signature for Array functions with an optional index
pub fn array_and_element_and_optional_index(volatility: Volatility) ->
Self {
Signature {
- type_signature: TypeSignature::ArraySignature(
- ArrayFunctionSignature::ArrayAndElementAndOptionalIndex,
- ),
- volatility,
- }
- }
- /// Specialized Signature for ArrayPrepend and similar functions
- pub fn element_and_array(volatility: Volatility) -> Self {
- Signature {
- type_signature: TypeSignature::ArraySignature(
- ArrayFunctionSignature::ElementAndArray,
- ),
+ type_signature: TypeSignature::OneOf(vec![
+ TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Element,
+ ],
+ array_coercion: None,
+ }),
+ TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Element,
+ ArrayFunctionArgument::Index,
+ ],
+ array_coercion: None,
+ }),
+ ]),
volatility,
}
}
+
/// Specialized Signature for ArrayElement and similar functions
pub fn array_and_index(volatility: Volatility) -> Self {
- Self::array_and_indexes(volatility, NonZeroUsize::new(1).expect("1 is
non-zero"))
- }
- /// Specialized Signature for ArraySlice and similar functions
- pub fn array_and_indexes(volatility: Volatility, count: NonZeroUsize) ->
Self {
Signature {
type_signature: TypeSignature::ArraySignature(
- ArrayFunctionSignature::ArrayAndIndexes(count),
+ ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Index,
+ ],
+ array_coercion: None,
+ },
),
volatility,
}
@@ -619,7 +639,12 @@ impl Signature {
/// Specialized Signature for ArrayEmpty and similar functions
pub fn array(volatility: Volatility) -> Self {
Signature {
- type_signature:
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
+ type_signature: TypeSignature::ArraySignature(
+ ArrayFunctionSignature::Array {
+ arguments: vec![ArrayFunctionArgument::Array],
+ array_coercion: None,
+ },
+ ),
volatility,
}
}
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index aaa65c676a..2f04f234eb 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -71,8 +71,8 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue;
pub use datafusion_expr_common::groups_accumulator::{EmitTo,
GroupsAccumulator};
pub use datafusion_expr_common::operator::Operator;
pub use datafusion_expr_common::signature::{
- ArrayFunctionSignature, Signature, TypeSignature, TypeSignatureClass,
Volatility,
- TIMEZONE_WILDCARD,
+ ArrayFunctionArgument, ArrayFunctionSignature, Signature, TypeSignature,
+ TypeSignatureClass, Volatility, TIMEZONE_WILDCARD,
};
pub use datafusion_expr_common::type_coercion::binary;
pub use expr::{
diff --git a/datafusion/expr/src/type_coercion/functions.rs
b/datafusion/expr/src/type_coercion/functions.rs
index 7ac836ef3a..7fda92862b 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -21,13 +21,14 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
-use datafusion_common::utils::coerced_fixed_size_list_to_list;
+use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
types::{LogicalType, NativeType},
utils::list_ndims,
Result,
};
+use datafusion_expr_common::signature::ArrayFunctionArgument;
use datafusion_expr_common::{
signature::{
ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD,
@@ -357,88 +358,81 @@ fn get_valid_types(
signature: &TypeSignature,
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
- fn array_element_and_optional_index(
+ fn array_valid_types(
function_name: &str,
current_types: &[DataType],
+ arguments: &[ArrayFunctionArgument],
+ array_coercion: Option<&ListCoercion>,
) -> Result<Vec<Vec<DataType>>> {
- // make sure there's 2 or 3 arguments
- if !(current_types.len() == 2 || current_types.len() == 3) {
+ if current_types.len() != arguments.len() {
return Ok(vec![vec![]]);
}
- let first_two_types = ¤t_types[0..2];
- let mut valid_types =
- array_append_or_prepend_valid_types(function_name,
first_two_types, true)?;
-
- // Early return if there are only 2 arguments
- if current_types.len() == 2 {
- return Ok(valid_types);
- }
-
- let valid_types_with_index = valid_types
- .iter()
- .map(|t| {
- let mut t = t.clone();
- t.push(DataType::Int64);
- t
- })
- .collect::<Vec<_>>();
-
- valid_types.extend(valid_types_with_index);
-
- Ok(valid_types)
- }
-
- fn array_append_or_prepend_valid_types(
- function_name: &str,
- current_types: &[DataType],
- is_append: bool,
- ) -> Result<Vec<Vec<DataType>>> {
- if current_types.len() != 2 {
- return Ok(vec![vec![]]);
- }
-
- let (array_type, elem_type) = if is_append {
- (¤t_types[0], ¤t_types[1])
- } else {
- (¤t_types[1], ¤t_types[0])
+ let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| {
+ if *arg == ArrayFunctionArgument::Array {
+ Some(idx)
+ } else {
+ None
+ }
+ });
+ let Some(array_idx) = array_idx else {
+ return Err(internal_datafusion_err!("Function '{function_name}'
expected at least one argument array argument"));
};
-
- // We follow Postgres on `array_append(Null, T)`, which is not valid.
- if array_type.eq(&DataType::Null) {
+ let Some(array_type) = array(¤t_types[array_idx]) else {
return Ok(vec![vec![]]);
- }
+ };
// We need to find the coerced base type, mainly for cases like:
// `array_append(List(null), i64)` -> `List(i64)`
- let array_base_type = datafusion_common::utils::base_type(array_type);
- let elem_base_type = datafusion_common::utils::base_type(elem_type);
- let new_base_type = comparison_coercion(&array_base_type,
&elem_base_type);
-
- let new_base_type = new_base_type.ok_or_else(|| {
- internal_datafusion_err!(
- "Function '{function_name}' does not support coercion from
{array_base_type:?} to {elem_base_type:?}"
- )
- })?;
-
+ let mut new_base_type =
datafusion_common::utils::base_type(&array_type);
+ for (current_type, argument_type) in
current_types.iter().zip(arguments.iter()) {
+ match argument_type {
+ ArrayFunctionArgument::Element | ArrayFunctionArgument::Array
=> {
+ new_base_type =
+ coerce_array_types(function_name, current_type,
&new_base_type)?;
+ }
+ ArrayFunctionArgument::Index => {}
+ }
+ }
let new_array_type =
datafusion_common::utils::coerced_type_with_base_type_only(
- array_type,
+ &array_type,
&new_base_type,
+ array_coercion,
);
- match new_array_type {
+ let new_elem_type = match new_array_type {
DataType::List(ref field)
| DataType::LargeList(ref field)
- | DataType::FixedSizeList(ref field, _) => {
- let new_elem_type = field.data_type();
- if is_append {
- Ok(vec![vec![new_array_type.clone(),
new_elem_type.clone()]])
- } else {
- Ok(vec![vec![new_elem_type.to_owned(),
new_array_type.clone()]])
+ | DataType::FixedSizeList(ref field, _) => field.data_type(),
+ _ => return Ok(vec![vec![]]),
+ };
+
+ let mut valid_types = Vec::with_capacity(arguments.len());
+ for (current_type, argument_type) in
current_types.iter().zip(arguments.iter()) {
+ let valid_type = match argument_type {
+ ArrayFunctionArgument::Element => new_elem_type.clone(),
+ ArrayFunctionArgument::Index => DataType::Int64,
+ ArrayFunctionArgument::Array => {
+ let Some(current_type) = array(current_type) else {
+ return Ok(vec![vec![]]);
+ };
+ let new_type =
+
datafusion_common::utils::coerced_type_with_base_type_only(
+ ¤t_type,
+ &new_base_type,
+ array_coercion,
+ );
+ // All array arguments must be coercible to the same type
+ if new_type != new_array_type {
+ return Ok(vec![vec![]]);
+ }
+ new_type
}
- }
- _ => Ok(vec![vec![]]),
+ };
+ valid_types.push(valid_type);
}
+
+ Ok(vec![valid_types])
}
fn array(array_type: &DataType) -> Option<DataType> {
@@ -449,6 +443,20 @@ fn get_valid_types(
}
}
+ fn coerce_array_types(
+ function_name: &str,
+ current_type: &DataType,
+ base_type: &DataType,
+ ) -> Result<DataType> {
+ let current_base_type =
datafusion_common::utils::base_type(current_type);
+ let new_base_type = comparison_coercion(base_type, ¤t_base_type);
+ new_base_type.ok_or_else(|| {
+ internal_datafusion_err!(
+ "Function '{function_name}' does not support coercion from
{base_type:?} to {current_base_type:?}"
+ )
+ })
+ }
+
fn recursive_array(array_type: &DataType) -> Option<DataType> {
match array_type {
DataType::List(_)
@@ -693,40 +701,9 @@ fn get_valid_types(
vec![current_types.to_vec()]
}
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
- TypeSignature::ArraySignature(ref function_signature) => match
function_signature
- {
- ArrayFunctionSignature::ArrayAndElement => {
- array_append_or_prepend_valid_types(function_name,
current_types, true)?
- }
- ArrayFunctionSignature::ElementAndArray => {
- array_append_or_prepend_valid_types(function_name,
current_types, false)?
- }
- ArrayFunctionSignature::ArrayAndIndexes(count) => {
- if current_types.len() != count.get() + 1 {
- return Ok(vec![vec![]]);
- }
- array(¤t_types[0]).map_or_else(
- || vec![vec![]],
- |array_type| {
- let mut inner = Vec::with_capacity(count.get() + 1);
- inner.push(array_type);
- for _ in 0..count.get() {
- inner.push(DataType::Int64);
- }
- vec![inner]
- },
- )
- }
- ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
- array_element_and_optional_index(function_name, current_types)?
- }
- ArrayFunctionSignature::Array => {
- if current_types.len() != 1 {
- return Ok(vec![vec![]]);
- }
-
- array(¤t_types[0])
- .map_or_else(|| vec![vec![]], |array_type|
vec![vec![array_type]])
+ TypeSignature::ArraySignature(ref function_signature) => match
function_signature {
+ ArrayFunctionSignature::Array { arguments, array_coercion, } => {
+ array_valid_types(function_name, current_types, arguments,
array_coercion.as_ref())?
}
ArrayFunctionSignature::RecursiveArray => {
if current_types.len() != 1 {
diff --git a/datafusion/functions-nested/src/cardinality.rs
b/datafusion/functions-nested/src/cardinality.rs
index ad30c0b540..8867097799 100644
--- a/datafusion/functions-nested/src/cardinality.rs
+++ b/datafusion/functions-nested/src/cardinality.rs
@@ -30,8 +30,8 @@ use datafusion_common::utils::take_function_args;
use datafusion_common::Result;
use datafusion_common::{exec_err, plan_err};
use datafusion_expr::{
- ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl,
Signature,
- TypeSignature, Volatility,
+ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue,
Documentation,
+ ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
@@ -50,7 +50,10 @@ impl Cardinality {
Self {
signature: Signature::one_of(
vec![
-
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+ arguments: vec![ArrayFunctionArgument::Array],
+ array_coercion: None,
+ }),
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
],
Volatility::Immutable,
diff --git a/datafusion/functions-nested/src/concat.rs
b/datafusion/functions-nested/src/concat.rs
index 14d4b95886..f404173869 100644
--- a/datafusion/functions-nested/src/concat.rs
+++ b/datafusion/functions-nested/src/concat.rs
@@ -26,6 +26,7 @@ use arrow::array::{
};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};
+use datafusion_common::utils::ListCoercion;
use datafusion_common::Result;
use datafusion_common::{
cast::as_generic_list_array,
@@ -33,7 +34,8 @@ use datafusion_common::{
utils::{list_ndims, take_function_args},
};
use datafusion_expr::{
- ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue,
Documentation,
+ ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
@@ -165,7 +167,18 @@ impl Default for ArrayPrepend {
impl ArrayPrepend {
pub fn new() -> Self {
Self {
- signature: Signature::element_and_array(Volatility::Immutable),
+ signature: Signature {
+ type_signature: TypeSignature::ArraySignature(
+ ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Element,
+ ArrayFunctionArgument::Array,
+ ],
+ array_coercion:
Some(ListCoercion::FixedSizedListToList),
+ },
+ ),
+ volatility: Volatility::Immutable,
+ },
aliases: vec![
String::from("list_prepend"),
String::from("array_push_front"),
diff --git a/datafusion/functions-nested/src/extract.rs
b/datafusion/functions-nested/src/extract.rs
index 697c868fde..6bf4d16db6 100644
--- a/datafusion/functions-nested/src/extract.rs
+++ b/datafusion/functions-nested/src/extract.rs
@@ -30,17 +30,19 @@ use arrow::datatypes::{
use datafusion_common::cast::as_int64_array;
use datafusion_common::cast::as_large_list_array;
use datafusion_common::cast::as_list_array;
+use datafusion_common::utils::ListCoercion;
use datafusion_common::{
exec_err, internal_datafusion_err, plan_err, utils::take_function_args,
DataFusionError, Result,
};
-use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature};
+use datafusion_expr::{
+ ArrayFunctionArgument, ArrayFunctionSignature, Expr, TypeSignature,
+};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
-use std::num::NonZeroUsize;
use std::sync::Arc;
use crate::utils::make_scalar_function;
@@ -330,16 +332,23 @@ impl ArraySlice {
Self {
signature: Signature::one_of(
vec![
- TypeSignature::ArraySignature(
- ArrayFunctionSignature::ArrayAndIndexes(
- NonZeroUsize::new(2).expect("2 is non-zero"),
- ),
- ),
- TypeSignature::ArraySignature(
- ArrayFunctionSignature::ArrayAndIndexes(
- NonZeroUsize::new(3).expect("3 is non-zero"),
- ),
- ),
+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Index,
+ ArrayFunctionArgument::Index,
+ ],
+ array_coercion: None,
+ }),
+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Index,
+ ArrayFunctionArgument::Index,
+ ArrayFunctionArgument::Index,
+ ],
+ array_coercion: None,
+ }),
],
Volatility::Immutable,
),
@@ -665,7 +674,15 @@ pub(super) struct ArrayPopFront {
impl ArrayPopFront {
pub fn new() -> Self {
Self {
- signature: Signature::array(Volatility::Immutable),
+ signature: Signature {
+ type_signature: TypeSignature::ArraySignature(
+ ArrayFunctionSignature::Array {
+ arguments: vec![ArrayFunctionArgument::Array],
+ array_coercion:
Some(ListCoercion::FixedSizedListToList),
+ },
+ ),
+ volatility: Volatility::Immutable,
+ },
aliases: vec![String::from("list_pop_front")],
}
}
@@ -765,7 +782,15 @@ pub(super) struct ArrayPopBack {
impl ArrayPopBack {
pub fn new() -> Self {
Self {
- signature: Signature::array(Volatility::Immutable),
+ signature: Signature {
+ type_signature: TypeSignature::ArraySignature(
+ ArrayFunctionSignature::Array {
+ arguments: vec![ArrayFunctionArgument::Array],
+ array_coercion:
Some(ListCoercion::FixedSizedListToList),
+ },
+ ),
+ volatility: Volatility::Immutable,
+ },
aliases: vec![String::from("list_pop_back")],
}
}
diff --git a/datafusion/functions-nested/src/replace.rs
b/datafusion/functions-nested/src/replace.rs
index 53f43de410..6d84e64cba 100644
--- a/datafusion/functions-nested/src/replace.rs
+++ b/datafusion/functions-nested/src/replace.rs
@@ -25,9 +25,11 @@ use arrow::datatypes::{DataType, Field};
use arrow::buffer::OffsetBuffer;
use datafusion_common::cast::as_int64_array;
+use datafusion_common::utils::ListCoercion;
use datafusion_common::{exec_err, utils::take_function_args, Result};
use datafusion_expr::{
- ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue,
Documentation,
+ ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
@@ -91,7 +93,19 @@ impl Default for ArrayReplace {
impl ArrayReplace {
pub fn new() -> Self {
Self {
- signature: Signature::any(3, Volatility::Immutable),
+ signature: Signature {
+ type_signature: TypeSignature::ArraySignature(
+ ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Element,
+ ArrayFunctionArgument::Element,
+ ],
+ array_coercion:
Some(ListCoercion::FixedSizedListToList),
+ },
+ ),
+ volatility: Volatility::Immutable,
+ },
aliases: vec![String::from("list_replace")],
}
}
@@ -160,7 +174,20 @@ pub(super) struct ArrayReplaceN {
impl ArrayReplaceN {
pub fn new() -> Self {
Self {
- signature: Signature::any(4, Volatility::Immutable),
+ signature: Signature {
+ type_signature: TypeSignature::ArraySignature(
+ ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Element,
+ ArrayFunctionArgument::Element,
+ ArrayFunctionArgument::Index,
+ ],
+ array_coercion:
Some(ListCoercion::FixedSizedListToList),
+ },
+ ),
+ volatility: Volatility::Immutable,
+ },
aliases: vec![String::from("list_replace_n")],
}
}
@@ -228,7 +255,19 @@ pub(super) struct ArrayReplaceAll {
impl ArrayReplaceAll {
pub fn new() -> Self {
Self {
- signature: Signature::any(3, Volatility::Immutable),
+ signature: Signature {
+ type_signature: TypeSignature::ArraySignature(
+ ArrayFunctionSignature::Array {
+ arguments: vec![
+ ArrayFunctionArgument::Array,
+ ArrayFunctionArgument::Element,
+ ArrayFunctionArgument::Element,
+ ],
+ array_coercion:
Some(ListCoercion::FixedSizedListToList),
+ },
+ ),
+ volatility: Volatility::Immutable,
+ },
aliases: vec![String::from("list_replace_all")],
}
}
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 8f23bfe5ea..4418d426cc 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -2656,6 +2656,29 @@ select list_push_front(1, arrow_cast(make_array(2, 3,
4), 'LargeList(Int64)')),
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+# array_prepend scalar function #7 (element is fixed size list)
+query ???
+select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'),
make_array(arrow_cast(make_array(2), 'FixedSizeList(1, Int64)'),
arrow_cast(make_array(3), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(4),
'FixedSizeList(1, Int64)'))),
+ array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1, Float64)'),
make_array(arrow_cast([2.0], 'FixedSizeList(1, Float64)'), arrow_cast([3.0],
'FixedSizeList(1, Float64)'), arrow_cast([4.0], 'FixedSizeList(1, Float64)'))),
+ array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'),
make_array(arrow_cast(['e'], 'FixedSizeList(1, Utf8)'), arrow_cast(['l'],
'FixedSizeList(1, Utf8)'), arrow_cast(['l'], 'FixedSizeList(1, Utf8)'),
arrow_cast(['o'], 'FixedSizeList(1, Utf8)')));
+----
+[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
+# TODO: https://github.com/apache/datafusion/issues/14613
+#query ???
+#select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'),
arrow_cast(make_array(make_array(2), make_array(3), make_array(4)),
'LargeList(FixedSizeList(1, Int64))')),
+# array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1,
Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]),
'LargeList(FixedSizeList(1, Float64))')),
+# array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'),
arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(FixedSizeList(1,
Utf8))'));
+#----
+#[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
+query ???
+select array_prepend(arrow_cast([1], 'FixedSizeList(1, Int64)'),
arrow_cast([[1], [2], [3]], 'FixedSizeList(3, FixedSizeList(1, Int64))')),
+ array_prepend(arrow_cast([1.0], 'FixedSizeList(1, Float64)'),
arrow_cast([[2.0], [3.0], [4.0]], 'FixedSizeList(3, FixedSizeList(1,
Float64))')),
+ array_prepend(arrow_cast(['h'], 'FixedSizeList(1, Utf8)'),
arrow_cast([['e'], ['l'], ['l'], ['o']], 'FixedSizeList(4, FixedSizeList(1,
Utf8))'));
+----
+[[1], [1], [2], [3]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]]
+
# array_prepend with columns #1
query ?
select array_prepend(column2, column1) from arrays_values;
@@ -3563,6 +3586,17 @@ select list_replace(
----
[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3]
+# array_replace scalar function #4 (null input)
+query ?
+select array_replace(make_array(1, 2, 3, 4, 5), NULL, NULL);
+----
+[1, 2, 3, 4, 5]
+
+query ?
+select array_replace(arrow_cast(make_array(1, 2, 3, 4, 5),
'LargeList(Int64)'), NULL, NULL);
+----
+[1, 2, 3, 4, 5]
+
# array_replace scalar function with columns #1
query ?
select array_replace(column1, column2, column3) from
arrays_with_repeating_elements;
@@ -3728,6 +3762,17 @@ select
----
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3]
+# array_replace_n scalar function #4 (null input)
+query ?
+select array_replace_n(make_array(1, 2, 3, 4, 5), NULL, NULL, NULL);
+----
+[1, 2, 3, 4, 5]
+
+query ?
+select array_replace_n(arrow_cast(make_array(1, 2, 3, 4, 5),
'LargeList(Int64)'), NULL, NULL, NULL);
+----
+[1, 2, 3, 4, 5]
+
# array_replace_n scalar function with columns #1
query ?
select
@@ -3904,6 +3949,17 @@ select
----
[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3]
+# array_replace_all scalar function #4 (null input)
+query ?
+select array_replace_all(make_array(1, 2, 3, 4, 5), NULL, NULL);
+----
+[1, 2, 3, 4, 5]
+
+query ?
+select array_replace_all(arrow_cast(make_array(1, 2, 3, 4, 5),
'LargeList(Int64)'), NULL, NULL);
+----
+[1, 2, 3, 4, 5]
+
# array_replace_all scalar function with columns #1
query ?
select
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]