This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 76f9e2eb44 Introduce user-defined signature (#10439)
76f9e2eb44 is described below
commit 76f9e2eb44444b1b6adaf97c4601f5bd32d352d1
Author: Jay Zhan <[email protected]>
AuthorDate: Sat May 11 20:56:40 2024 +0800
Introduce user-defined signature (#10439)
* introduce new sig
Signed-off-by: jayzhan211 <[email protected]>
* add udfimpl
Signed-off-by: jayzhan211 <[email protected]>
* replace fun
Signed-off-by: jayzhan211 <[email protected]>
* replace array
Signed-off-by: jayzhan211 <[email protected]>
* coalesce
Signed-off-by: jayzhan211 <[email protected]>
* nvl2
Signed-off-by: jayzhan211 <[email protected]>
* rm variadic equal
Signed-off-by: jayzhan211 <[email protected]>
* fix test
Signed-off-by: jayzhan211 <[email protected]>
* rm err msg to fix ci
Signed-off-by: jayzhan211 <[email protected]>
* user defined sig
Signed-off-by: jayzhan211 <[email protected]>
* add err msg
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* upd comment
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/expr/src/expr_schema.rs | 7 +-
datafusion/expr/src/signature.rs | 23 ++-
datafusion/expr/src/type_coercion/functions.rs | 176 ++++++++++++++++++---
datafusion/expr/src/udaf.rs | 4 +
datafusion/expr/src/udf.rs | 29 ++++
datafusion/functions-array/src/make_array.rs | 31 +++-
datafusion/functions/src/core/coalesce.rs | 29 +++-
datafusion/functions/src/core/nvl2.rs | 44 ++++--
datafusion/optimizer/src/analyzer/type_coercion.rs | 59 +++++--
datafusion/physical-expr/src/scalar_function.rs | 4 +-
datafusion/sqllogictest/test_files/array.slt | 4 +-
.../sqllogictest/test_files/arrow_typeof.slt | 3 +-
datafusion/sqllogictest/test_files/coalesce.slt | 16 +-
datafusion/sqllogictest/test_files/encoding.slt | 2 +-
datafusion/sqllogictest/test_files/errors.slt | 12 +-
datafusion/sqllogictest/test_files/expr.slt | 15 +-
datafusion/sqllogictest/test_files/math.slt | 4 +-
datafusion/sqllogictest/test_files/scalar.slt | 17 +-
datafusion/sqllogictest/test_files/struct.slt | 2 +-
datafusion/sqllogictest/test_files/timestamps.slt | 2 +-
20 files changed, 359 insertions(+), 124 deletions(-)
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index 4aca52d67c..ce79f9da64 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -23,7 +23,7 @@ use crate::expr::{
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
-use crate::type_coercion::functions::data_types;
+use crate::type_coercion::functions::data_types_with_scalar_udf;
use crate::{utils, LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
@@ -139,9 +139,10 @@ impl ExprSchemable for Expr {
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
// verify that function is invoked with correct number
and type of arguments as defined in `TypeSignature`
- data_types(&arg_data_types,
func.signature()).map_err(|_| {
+ data_types_with_scalar_udf(&arg_data_types,
func).map_err(|err| {
plan_datafusion_err!(
- "{}",
+ "{} and {}",
+ err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs
index e2505d6fd6..5d925c8605 100644
--- a/datafusion/expr/src/signature.rs
+++ b/datafusion/expr/src/signature.rs
@@ -91,15 +91,12 @@ pub enum TypeSignature {
/// # Examples
/// A function such as `concat` is `Variadic(vec![DataType::Utf8,
DataType::LargeUtf8])`
Variadic(Vec<DataType>),
- /// One or more arguments of an arbitrary but equal type.
- /// DataFusion attempts to coerce all argument types to match the first
argument's type
+ /// The acceptable signature and coercions rules to coerce arguments to
this
+ /// signature are special for this function. If this signature is
specified,
+ /// Datafusion will call [`ScalarUDFImpl::coerce_types`] to prepare
argument types.
///
- /// # Examples
- /// Given types in signature should be coercible to the same final type.
- /// A function such as `make_array` is `VariadicEqual`.
- ///
- /// `make_array(i32, i64) -> make_array(i64, i64)`
- VariadicEqual,
+ /// [`ScalarUDFImpl::coerce_types`]:
crate::udf::ScalarUDFImpl::coerce_types
+ UserDefined,
/// One or more arguments with arbitrary types
VariadicAny,
/// Fixed number of arguments of an arbitrary but equal type out of a list
of valid types.
@@ -190,8 +187,8 @@ impl TypeSignature {
.collect::<Vec<&str>>()
.join(", ")]
}
- TypeSignature::VariadicEqual => {
- vec!["CoercibleT, .., CoercibleT".to_string()]
+ TypeSignature::UserDefined => {
+ vec!["UserDefined".to_string()]
}
TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()],
TypeSignature::OneOf(sigs) => {
@@ -255,10 +252,10 @@ impl Signature {
volatility,
}
}
- /// An arbitrary number of arguments of the same type.
- pub fn variadic_equal(volatility: Volatility) -> Self {
+ /// User-defined coercion rules for the function.
+ pub fn user_defined(volatility: Volatility) -> Self {
Self {
- type_signature: TypeSignature::VariadicEqual,
+ type_signature: TypeSignature::UserDefined,
volatility,
}
}
diff --git a/datafusion/expr/src/type_coercion/functions.rs
b/datafusion/expr/src/type_coercion/functions.rs
index eb4f325ff8..583d75e1cc 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -20,16 +20,114 @@ use std::sync::Arc;
use crate::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
};
-use crate::{Signature, TypeSignature};
+use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
-use datafusion_common::{internal_datafusion_err, internal_err, plan_err,
Result};
+use datafusion_common::{
+ exec_err, internal_datafusion_err, internal_err, plan_err, Result,
+};
use super::binary::{comparison_binary_numeric_coercion, comparison_coercion};
+/// Performs type coercion for scalar function arguments.
+///
+/// Returns the data types to which each argument must be coerced to
+/// match `signature`.
+///
+/// For more details on coercion in general, please see the
+/// [`type_coercion`](crate::type_coercion) module.
+pub fn data_types_with_scalar_udf(
+ current_types: &[DataType],
+ func: &ScalarUDF,
+) -> Result<Vec<DataType>> {
+ let signature = func.signature();
+
+ if current_types.is_empty() {
+ if signature.type_signature.supports_zero_argument() {
+ return Ok(vec![]);
+ } else {
+ return plan_err!(
+ "[data_types_with_scalar_udf] signature {:?} does not support
zero arguments.",
+ &signature.type_signature
+ );
+ }
+ }
+
+ let valid_types =
+ get_valid_types_with_scalar_udf(&signature.type_signature,
current_types, func)?;
+
+ if valid_types
+ .iter()
+ .any(|data_type| data_type == current_types)
+ {
+ return Ok(current_types.to_vec());
+ }
+
+ // Try and coerce the argument types to match the signature, returning the
+ // coerced types from the first matching signature.
+ for valid_types in valid_types {
+ if let Some(types) = maybe_data_types(&valid_types, current_types) {
+ return Ok(types);
+ }
+ }
+
+ // none possible -> Error
+ plan_err!(
+ "[data_types_with_scalar_udf] Coercion from {:?} to the signature {:?}
failed.",
+ current_types,
+ &signature.type_signature
+ )
+}
+
+pub fn data_types_with_aggregate_udf(
+ current_types: &[DataType],
+ func: &AggregateUDF,
+) -> Result<Vec<DataType>> {
+ let signature = func.signature();
+
+ if current_types.is_empty() {
+ if signature.type_signature.supports_zero_argument() {
+ return Ok(vec![]);
+ } else {
+ return plan_err!(
+ "[data_types_with_aggregate_udf] Coercion from {:?} to the
signature {:?} failed.",
+ current_types,
+ &signature.type_signature
+ );
+ }
+ }
+
+ let valid_types = get_valid_types_with_aggregate_udf(
+ &signature.type_signature,
+ current_types,
+ func,
+ )?;
+ if valid_types
+ .iter()
+ .any(|data_type| data_type == current_types)
+ {
+ return Ok(current_types.to_vec());
+ }
+
+ // Try and coerce the argument types to match the signature, returning the
+ // coerced types from the first matching signature.
+ for valid_types in valid_types {
+ if let Some(types) = maybe_data_types(&valid_types, current_types) {
+ return Ok(types);
+ }
+ }
+
+ // none possible -> Error
+ plan_err!(
+ "[data_types_with_aggregate_udf] Coercion from {:?} to the signature
{:?} failed.",
+ current_types,
+ &signature.type_signature
+ )
+}
+
/// Performs type coercion for function arguments.
///
/// Returns the data types to which each argument must be coerced to
@@ -46,7 +144,7 @@ pub fn data_types(
return Ok(vec![]);
} else {
return plan_err!(
- "Coercion from {:?} to the signature {:?} failed.",
+ "[data_types] Coercion from {:?} to the signature {:?}
failed.",
current_types,
&signature.type_signature
);
@@ -72,12 +170,56 @@ pub fn data_types(
// none possible -> Error
plan_err!(
- "Coercion from {:?} to the signature {:?} failed.",
+ "[data_types] Coercion from {:?} to the signature {:?} failed.",
current_types,
&signature.type_signature
)
}
+fn get_valid_types_with_scalar_udf(
+ signature: &TypeSignature,
+ current_types: &[DataType],
+ func: &ScalarUDF,
+) -> Result<Vec<Vec<DataType>>> {
+ let valid_types = match signature {
+ TypeSignature::UserDefined => match func.coerce_types(current_types) {
+ Ok(coerced_types) => vec![coerced_types],
+ Err(e) => return exec_err!("User-defined coercion failed with
{:?}", e),
+ },
+ TypeSignature::OneOf(signatures) => signatures
+ .iter()
+ .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types,
func).ok())
+ .flatten()
+ .collect::<Vec<_>>(),
+ _ => get_valid_types(signature, current_types)?,
+ };
+
+ Ok(valid_types)
+}
+
+fn get_valid_types_with_aggregate_udf(
+ signature: &TypeSignature,
+ current_types: &[DataType],
+ func: &AggregateUDF,
+) -> Result<Vec<Vec<DataType>>> {
+ let valid_types = match signature {
+ TypeSignature::UserDefined => match func.coerce_types(current_types) {
+ Ok(coerced_types) => vec![coerced_types],
+ Err(e) => return exec_err!("User-defined coercion failed with
{:?}", e),
+ },
+ TypeSignature::OneOf(signatures) => signatures
+ .iter()
+ .filter_map(|t| {
+ get_valid_types_with_aggregate_udf(t, current_types, func).ok()
+ })
+ .flatten()
+ .collect::<Vec<_>>(),
+ _ => get_valid_types(signature, current_types)?,
+ };
+
+ Ok(valid_types)
+}
+
/// Returns a Vec of all possible valid argument types for the given signature.
fn get_valid_types(
signature: &TypeSignature,
@@ -184,32 +326,14 @@ fn get_valid_types(
.iter()
.map(|valid_type| (0..*number).map(|_|
valid_type.clone()).collect())
.collect(),
- TypeSignature::VariadicEqual => {
- let new_type = current_types.iter().skip(1).try_fold(
- current_types.first().unwrap().clone(),
- |acc, x| {
- // The coerced types found by `comparison_coercion` are
not guaranteed to be
- // coercible for the arguments. `comparison_coercion`
returns more loose
- // types that can be coerced to both `acc` and `x` for
comparison purpose.
- // See `maybe_data_types` for the actual coercion.
- let coerced_type = comparison_coercion(&acc, x);
- if let Some(coerced_type) = coerced_type {
- Ok(coerced_type)
- } else {
- internal_err!("Coercion from {acc:?} to {x:?} failed.")
- }
- },
- );
-
- match new_type {
- Ok(new_type) => vec![vec![new_type; current_types.len()]],
- Err(e) => return Err(e),
- }
+ TypeSignature::UserDefined => {
+ return internal_err!(
+ "User-defined signature should be handled by function-specific
coerce_types."
+ )
}
TypeSignature::VariadicAny => {
vec![current_types.to_vec()]
}
-
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArraySignature(ref function_signature) => match
function_signature
{
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 67c3b51ca3..e5a47ddcd8 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -195,6 +195,10 @@ impl AggregateUDF {
pub fn create_groups_accumulator(&self) -> Result<Box<dyn
GroupsAccumulator>> {
self.inner.create_groups_accumulator()
}
+
+ pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
+ not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
+ }
}
impl<F> From<F> for AggregateUDF
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 29ee4a86e5..fadea26e7f 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -213,6 +213,11 @@ impl ScalarUDF {
pub fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}
+
+ /// See [`ScalarUDFImpl::coerce_types`] for more details.
+ pub fn coerce_types(&self, arg_types: &[DataType]) ->
Result<Vec<DataType>> {
+ self.inner.coerce_types(arg_types)
+ }
}
impl<F> From<F> for ScalarUDF
@@ -420,6 +425,29 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
fn short_circuits(&self) -> bool {
false
}
+
+ /// Coerce arguments of a function call to types that the function can
evaluate.
+ ///
+ /// This function is only called if [`ScalarUDFImpl::signature`] returns
[`crate::TypeSignature::UserDefined`]. Most
+ /// UDFs should return one of the other variants of `TypeSignature` which
handle common
+ /// cases
+ ///
+ /// See the [type coercion module](crate::type_coercion)
+ /// documentation for more details on type coercion
+ ///
+ /// For example, if your function requires a floating point arguments, but
the user calls
+ /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types
could return `[DataType::Float64]`
+ /// to ensure the argument was cast to `1::double`
+ ///
+ /// # Parameters
+ /// * `arg_types`: The argument types of the arguments this function with
+ ///
+ /// # Return value
+ /// A Vec the same length as `arg_types`. DataFusion will `CAST` the
function call
+ /// arguments to these specific types.
+ fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ not_impl_err!("Function {} does not implement coerce_types",
self.name())
+ }
}
/// ScalarUDF that adds an alias to the underlying function. It is better to
@@ -446,6 +474,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
+
fn name(&self) -> &str {
self.inner.name()
}
diff --git a/datafusion/functions-array/src/make_array.rs
b/datafusion/functions-array/src/make_array.rs
index 770276938f..4f7dda933f 100644
--- a/datafusion/functions-array/src/make_array.rs
+++ b/datafusion/functions-array/src/make_array.rs
@@ -26,12 +26,12 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
+use datafusion_common::internal_err;
use datafusion_common::{plan_err, utils::array_into_list_array, Result};
use datafusion_expr::expr::ScalarFunction;
-use datafusion_expr::Expr;
-use datafusion_expr::{
- ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility,
-};
+use datafusion_expr::type_coercion::binary::comparison_coercion;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::{Expr, TypeSignature};
use crate::utils::make_scalar_function;
@@ -58,10 +58,10 @@ impl MakeArray {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
- vec![TypeSignature::VariadicEqual, TypeSignature::Any(0)],
+ vec![TypeSignature::UserDefined, TypeSignature::Any(0)],
Volatility::Immutable,
),
- aliases: vec![String::from("make_array"),
String::from("make_list")],
+ aliases: vec![String::from("make_list")],
}
}
}
@@ -111,6 +111,25 @@ impl ScalarUDFImpl for MakeArray {
fn aliases(&self) -> &[String] {
&self.aliases
}
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ let new_type = arg_types.iter().skip(1).try_fold(
+ arg_types.first().unwrap().clone(),
+ |acc, x| {
+ // The coerced types found by `comparison_coercion` are not
guaranteed to be
+ // coercible for the arguments. `comparison_coercion` returns
more loose
+ // types that can be coerced to both `acc` and `x` for
comparison purpose.
+ // See `maybe_data_types` for the actual coercion.
+ let coerced_type = comparison_coercion(&acc, x);
+ if let Some(coerced_type) = coerced_type {
+ Ok(coerced_type)
+ } else {
+ internal_err!("Coercion from {acc:?} to {x:?} failed.")
+ }
+ },
+ )?;
+ Ok(vec![new_type; arg_types.len()])
+ }
}
/// `make_array_inner` is the implementation of the `make_array` function.
diff --git a/datafusion/functions/src/core/coalesce.rs
b/datafusion/functions/src/core/coalesce.rs
index 76f2a3ed74..63778eb773 100644
--- a/datafusion/functions/src/core/coalesce.rs
+++ b/datafusion/functions/src/core/coalesce.rs
@@ -22,8 +22,8 @@ use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, is_not_null, is_null};
use arrow::datatypes::DataType;
-use datafusion_common::{exec_err, Result};
-use datafusion_expr::type_coercion::functions::data_types;
+use datafusion_common::{exec_err, internal_err, Result};
+use datafusion_expr::type_coercion::binary::comparison_coercion;
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
@@ -41,7 +41,7 @@ impl Default for CoalesceFunc {
impl CoalesceFunc {
pub fn new() -> Self {
Self {
- signature: Signature::variadic_equal(Volatility::Immutable),
+ signature: Signature::user_defined(Volatility::Immutable),
}
}
}
@@ -60,9 +60,7 @@ impl ScalarUDFImpl for CoalesceFunc {
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- // COALESCE has multiple args and they might get coerced, get a
preview of this
- let coerced_types = data_types(arg_types, self.signature());
- coerced_types.map(|types| types[0].clone())
+ Ok(arg_types[0].clone())
}
/// coalesce evaluates to the first value which is not NULL
@@ -124,6 +122,25 @@ impl ScalarUDFImpl for CoalesceFunc {
fn short_circuits(&self) -> bool {
true
}
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ let new_type = arg_types.iter().skip(1).try_fold(
+ arg_types.first().unwrap().clone(),
+ |acc, x| {
+ // The coerced types found by `comparison_coercion` are not
guaranteed to be
+ // coercible for the arguments. `comparison_coercion` returns
more loose
+ // types that can be coerced to both `acc` and `x` for
comparison purpose.
+ // See `maybe_data_types` for the actual coercion.
+ let coerced_type = comparison_coercion(&acc, x);
+ if let Some(coerced_type) = coerced_type {
+ Ok(coerced_type)
+ } else {
+ internal_err!("Coercion from {acc:?} to {x:?} failed.")
+ }
+ },
+ )?;
+ Ok(vec![new_type; arg_types.len()])
+ }
}
#[cfg(test)]
diff --git a/datafusion/functions/src/core/nvl2.rs
b/datafusion/functions/src/core/nvl2.rs
index 66b9ef566a..573ac72425 100644
--- a/datafusion/functions/src/core/nvl2.rs
+++ b/datafusion/functions/src/core/nvl2.rs
@@ -19,8 +19,11 @@ use arrow::array::Array;
use arrow::compute::is_not_null;
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;
-use datafusion_common::{internal_err, plan_datafusion_err, Result};
-use datafusion_expr::{utils, ColumnarValue, ScalarUDFImpl, Signature,
Volatility};
+use datafusion_common::{exec_err, internal_err, Result};
+use datafusion_expr::{
+ type_coercion::binary::comparison_coercion, ColumnarValue, ScalarUDFImpl,
Signature,
+ Volatility,
+};
#[derive(Debug)]
pub struct NVL2Func {
@@ -36,7 +39,7 @@ impl Default for NVL2Func {
impl NVL2Func {
pub fn new() -> Self {
Self {
- signature: Signature::variadic_equal(Volatility::Immutable),
+ signature: Signature::user_defined(Volatility::Immutable),
}
}
}
@@ -55,22 +58,37 @@ impl ScalarUDFImpl for NVL2Func {
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- if arg_types.len() != 3 {
- return Err(plan_datafusion_err!(
- "{}",
- utils::generate_signature_error_msg(
- self.name(),
- self.signature().clone(),
- arg_types,
- )
- ));
- }
Ok(arg_types[1].clone())
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
nvl2_func(args)
}
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ if arg_types.len() != 3 {
+ return exec_err!(
+ "NVL2 takes exactly three arguments, but got {}",
+ arg_types.len()
+ );
+ }
+ let new_type = arg_types.iter().skip(1).try_fold(
+ arg_types.first().unwrap().clone(),
+ |acc, x| {
+ // The coerced types found by `comparison_coercion` are not
guaranteed to be
+ // coercible for the arguments. `comparison_coercion` returns
more loose
+ // types that can be coerced to both `acc` and `x` for
comparison purpose.
+ // See `maybe_data_types` for the actual coercion.
+ let coerced_type = comparison_coercion(&acc, x);
+ if let Some(coerced_type) = coerced_type {
+ Ok(coerced_type)
+ } else {
+ internal_err!("Coercion from {acc:?} to {x:?} failed.")
+ }
+ },
+ )?;
+ Ok(vec![new_type; arg_types.len()])
+ }
}
fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 61b1d1d77b..e5c7afa10e 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -37,7 +37,9 @@ use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
comparison_coercion, get_input_types, like_coercion,
};
-use datafusion_expr::type_coercion::functions::data_types;
+use datafusion_expr::type_coercion::functions::{
+ data_types_with_aggregate_udf, data_types_with_scalar_udf,
+};
use datafusion_expr::type_coercion::other::{
get_coerce_type_for_case_expression, get_coerce_type_for_list,
};
@@ -45,8 +47,8 @@ use datafusion_expr::type_coercion::{is_datetime,
is_utf8_or_large_utf8};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
not,
- type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan,
Operator,
- ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits,
+ type_coercion, AggregateFunction, AggregateUDF, Expr, ExprSchemable,
LogicalPlan,
+ Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound,
WindowFrameUnits,
};
use crate::analyzer::AnalyzerRule;
@@ -303,8 +305,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
Ok(Transformed::yes(Expr::Case(case)))
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
- let new_expr =
- coerce_arguments_for_signature(args, self.schema,
func.signature())?;
+ let new_expr = coerce_arguments_for_signature_with_scalar_udf(
+ args,
+ self.schema,
+ &func,
+ )?;
let new_expr = coerce_arguments_for_fun(new_expr, self.schema,
&func)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(func, new_expr),
@@ -337,10 +342,10 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
)))
}
AggregateFunctionDefinition::UDF(fun) => {
- let new_expr = coerce_arguments_for_signature(
+ let new_expr =
coerce_arguments_for_signature_with_aggregate_udf(
args,
self.schema,
- fun.signature(),
+ &fun,
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(
@@ -532,10 +537,37 @@ fn get_casted_expr_for_bool_op(expr: Expr, schema:
&DFSchema) -> Result<Expr> {
/// `signature`, if possible.
///
/// See the module level documentation for more detail on coercion.
-fn coerce_arguments_for_signature(
+fn coerce_arguments_for_signature_with_scalar_udf(
expressions: Vec<Expr>,
schema: &DFSchema,
- signature: &Signature,
+ func: &ScalarUDF,
+) -> Result<Vec<Expr>> {
+ if expressions.is_empty() {
+ return Ok(expressions);
+ }
+
+ let current_types = expressions
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ let new_types = data_types_with_scalar_udf(&current_types, func)?;
+
+ expressions
+ .into_iter()
+ .enumerate()
+ .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
+ .collect()
+}
+
+/// Returns `expressions` coerced to types compatible with
+/// `signature`, if possible.
+///
+/// See the module level documentation for more detail on coercion.
+fn coerce_arguments_for_signature_with_aggregate_udf(
+ expressions: Vec<Expr>,
+ schema: &DFSchema,
+ func: &AggregateUDF,
) -> Result<Vec<Expr>> {
if expressions.is_empty() {
return Ok(expressions);
@@ -546,7 +578,7 @@ fn coerce_arguments_for_signature(
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
- let new_types = data_types(&current_types, signature)?;
+ let new_types = data_types_with_aggregate_udf(&current_types, func)?;
expressions
.into_iter()
@@ -833,12 +865,9 @@ mod test {
signature: Signature::uniform(1, vec![DataType::Float32],
Volatility::Stable),
})
.call(vec![lit("Apple")]);
- let plan_err = Projection::try_new(vec![udf], empty)
+ Projection::try_new(vec![udf], empty)
.expect_err("Expected an error due to incorrect function input");
- let expected_error = "Error during planning: No function matches the
given name and argument types 'TestScalarUDF(Utf8)'. You might need to add
explicit type casts.";
-
- assert!(plan_err.to_string().starts_with(expected_error));
Ok(())
}
@@ -914,7 +943,7 @@ mod test {
.err()
.unwrap();
assert_eq!(
- "type_coercion\ncaused by\nError during planning: Coercion from
[Utf8] to the signature Uniform(1, [Float64]) failed.",
+ "type_coercion\ncaused by\nError during planning:
[data_types_with_aggregate_udf] Coercion from [Utf8] to the signature
Uniform(1, [Float64]) failed.",
err.strip_backtrace()
);
Ok(())
diff --git a/datafusion/physical-expr/src/scalar_function.rs
b/datafusion/physical-expr/src/scalar_function.rs
index 180f2a7946..1244a9b4db 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -39,7 +39,7 @@ use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, DFSchema, Result};
-use datafusion_expr::type_coercion::functions::data_types;
+use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity,
ScalarUDF};
use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal};
@@ -220,7 +220,7 @@ pub fn create_physical_expr(
.collect::<Result<Vec<_>>>()?;
// verify that input data types is consistent with function's
`TypeSignature`
- data_types(&input_expr_types, fun.signature())?;
+ data_types_with_scalar_udf(&input_expr_types, fun)?;
// Since we have arg_types, we dont need args and schema.
let return_type =
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index eaec0f4d8d..eeb5dc01b6 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -1137,7 +1137,7 @@ from arrays_values_without_nulls;
## array_element (aliases: array_extract, list_extract, list_element)
# array_element error
-query error DataFusion error: Error during planning: No function matches the
given name and argument types 'array_element\(Int64, Int64\)'. You might need
to add explicit type casts.\n\tCandidate functions:\n\tarray_element\(array,
index\)
+query error
select array_element(1, 2);
# array_element with null
@@ -4625,7 +4625,7 @@ NULL 10
## array_dims (aliases: `list_dims`)
# array dims error
-query error DataFusion error: Error during planning: No function matches the
given name and argument types 'array_dims\(Int64\)'. You might need to add
explicit type casts.\n\tCandidate functions:\n\tarray_dims\(array\)
+query error
select array_dims(1);
# array_dims scalar function
diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt
b/datafusion/sqllogictest/test_files/arrow_typeof.slt
index 3e8694f3b2..94cce61245 100644
--- a/datafusion/sqllogictest/test_files/arrow_typeof.slt
+++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt
@@ -92,10 +92,9 @@ SELECT arrow_cast('1', 'Int16')
1
# Basic error test
-query error DataFusion error: Error during planning: No function matches the
given name and argument types 'arrow_cast\(Utf8\)'. You might need to add
explicit type casts.
+query error
SELECT arrow_cast('1')
-
query error DataFusion error: Error during planning: arrow_cast requires its
second argument to be a constant string, got Literal\(Int64\(43\)\)
SELECT arrow_cast('1', 43)
diff --git a/datafusion/sqllogictest/test_files/coalesce.slt
b/datafusion/sqllogictest/test_files/coalesce.slt
index 527d4fe9c4..a0317ac4a5 100644
--- a/datafusion/sqllogictest/test_files/coalesce.slt
+++ b/datafusion/sqllogictest/test_files/coalesce.slt
@@ -23,7 +23,7 @@ select coalesce(1, 2, 3);
1
# test with first null
-query IT
+query ?T
select coalesce(null, 3, 2, 1), arrow_typeof(coalesce(null, 3, 2, 1));
----
3 Int64
@@ -35,7 +35,7 @@ select coalesce(null, null);
NULL
# cast to float
-query RT
+query IT
select
coalesce(1, 2.0),
arrow_typeof(coalesce(1, 2.0))
@@ -51,7 +51,7 @@ select
----
2 Float64
-query RT
+query IT
select
coalesce(1, arrow_cast(2.0, 'Float32')),
arrow_typeof(coalesce(1, arrow_cast(2.0, 'Float32')))
@@ -177,7 +177,7 @@ select
2 Decimal256(22, 2)
# coalesce string
-query TT
+query T?
select
coalesce('', 'test'),
coalesce(null, 'test');
@@ -226,7 +226,7 @@ select coalesce(column1, 'none_set') from test1;
foo
none_set
-query T
+query ?
select coalesce(null, column1, 'none_set') from test1;
----
foo
@@ -248,12 +248,12 @@ select coalesce(34, arrow_cast(123, 'Dictionary(Int32,
Int8)'));
----
34
-query I
+query ?
select coalesce(arrow_cast(123, 'Dictionary(Int32, Int8)'), 34);
----
123
-query I
+query ?
select coalesce(null, 34, arrow_cast(123, 'Dictionary(Int32, Int8)'));
----
34
@@ -288,7 +288,7 @@ SELECT COALESCE(c1, c2) FROM test
NULL
# numeric string is coerced to numeric in both Postgres and DuckDB
-query T
+query I
SELECT COALESCE(c1, c2, '-1') FROM test;
----
0
diff --git a/datafusion/sqllogictest/test_files/encoding.slt
b/datafusion/sqllogictest/test_files/encoding.slt
index 9f4f508e23..626af88aa9 100644
--- a/datafusion/sqllogictest/test_files/encoding.slt
+++ b/datafusion/sqllogictest/test_files/encoding.slt
@@ -40,7 +40,7 @@ select decode(12, 'hex')
query error DataFusion error: Error during planning: There is no built\-in
encoding named 'non_encoding', currently supported encodings are: base64, hex
select decode(hex_field, 'non_encoding') from test;
-query error DataFusion error: Error during planning: No function matches the
given name and argument types 'to_hex\(Utf8\)'. You might need to add explicit
type casts.\n\tCandidate functions:\n\tto_hex\(Int64\)
+query error
select to_hex(hex_field) from test;
# Arrays tests
diff --git a/datafusion/sqllogictest/test_files/errors.slt
b/datafusion/sqllogictest/test_files/errors.slt
index ab281eac31..b5464e2a27 100644
--- a/datafusion/sqllogictest/test_files/errors.slt
+++ b/datafusion/sqllogictest/test_files/errors.slt
@@ -38,7 +38,7 @@ WITH HEADER ROW
LOCATION '../../testing/data/csv/aggregate_test_100.csv'
# csv_query_error
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'sin\(Utf8\)'. You might need to add explicit
type casts.\n\tCandidate functions:\n\tsin\(Float64/Float32\)
+statement error
SELECT sin(c1) FROM aggregate_test_100
# cast_expressions_error
@@ -80,23 +80,23 @@ SELECT COUNT(*) FROM
way.too.many.namespaces.as.ident.prefixes.aggregate_test_10
#
# error message for wrong function signature (Variadic: arbitrary number of
args all from some common types)
-statement error Error during planning: No function matches the given name and
argument types 'concat\(\)'. You might need to add explicit type
casts.\n\tCandidate functions:\n\tconcat\(Utf8, ..\)
+statement error
SELECT concat();
# error message for wrong function signature (Uniform: t args all from some
common types)
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'nullif\(Int64\)'. You might need to add
explicit type casts.
+statement error
SELECT nullif(1);
# error message for wrong function signature (Exact: exact number of args of
an exact type)
-statement error Error during planning: No function matches the given name and
argument types 'pi\(Float64\)'. You might need to add explicit type
casts.\n\tCandidate functions:\n\tpi\(\)
+statement error
SELECT pi(3.14);
# error message for wrong function signature (Any: fixed number of args of
arbitrary types)
-statement error Error during planning: No function matches the given name and
argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit
type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\)
+statement error
SELECT arrow_typeof(1, 1);
# error message for wrong function signature (OneOf: fixed number of args of
arbitrary types)
-statement error Error during planning: No function matches the given name and
argument types 'power\(Int64, Int64, Int64\)'. You might need to add explicit
type casts.\n\tCandidate functions:\n\tpower\(Int64, Int64\)\n\tpower\(Float64,
Float64\)
+statement error
SELECT power(1, 2, 3);
#
diff --git a/datafusion/sqllogictest/test_files/expr.slt
b/datafusion/sqllogictest/test_files/expr.slt
index 7e7ebd8529..129a672083 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -1899,22 +1899,21 @@ a
# The 'from' and 'for' parameters don't support string types, because they
should be treated as
# regular expressions, which we have not implemented yet.
-query error DataFusion error: Error during planning: No function matches the
given name and argument types
+query error
SELECT substring('alphabet' FROM '3')
-query error DataFusion error: Error during planning: No function matches the
given name and argument types
+query error
SELECT substring('alphabet' FROM '3' FOR '2')
-query error DataFusion error: Error during planning: No function matches the
given name and argument types
+query error
SELECT substring('alphabet' FROM '3' FOR 2)
-query error DataFusion error: Error during planning: No function matches the
given name and argument types
+query error
SELECT substring('alphabet' FROM 3 FOR '2')
-query error DataFusion error: Error during planning: No function matches the
given name and argument types
+query error
SELECT substring('alphabet' FOR '2')
-
##### csv_query_nullif_divide_by_0
@@ -2275,13 +2274,13 @@ select f64, round(1.0 / f64) as i64_1, acos(round(1.0 /
f64)) from doubles;
10.1 0 1.570796326795
# common subexpr with coalesce (short-circuited)
-query RRR rowsort
+query RRR
select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from
doubles;
----
10.1 0.09900990099 1.471623942989
# common subexpr with coalesce (short-circuited) and alias
-query RRR rowsort
+query RRR
select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0))
from doubles;
----
10.1 0.09900990099 1.471623942989
diff --git a/datafusion/sqllogictest/test_files/math.slt
b/datafusion/sqllogictest/test_files/math.slt
index 802323ca45..3315ff4549 100644
--- a/datafusion/sqllogictest/test_files/math.slt
+++ b/datafusion/sqllogictest/test_files/math.slt
@@ -113,11 +113,11 @@ SELECT iszero(1.0), iszero(0.0), iszero(-0.0),
iszero(NULL)
false true true NULL
# abs: empty argumnet
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'abs\(\)'. You might need to add explicit
type casts.\n\tCandidate functions:\n\tabs\(Any\)
+statement error
SELECT abs();
# abs: wrong number of arguments
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'abs\(Int64, Int64\)'. You might need to add
explicit type casts.\n\tCandidate functions:\n\tabs\(Any\)
+statement error
SELECT abs(1, 2);
# abs: unsupported argument type
diff --git a/datafusion/sqllogictest/test_files/scalar.slt
b/datafusion/sqllogictest/test_files/scalar.slt
index 7fb2d55ff8..c52881b7b0 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1799,34 +1799,33 @@ statement ok
drop table test
# error message for wrong function signature (Variadic: arbitrary number of
args all from some common types)
-statement error Error during planning: No function matches the given name and
argument types 'concat\(\)'. You might need to add explicit type
casts.\n\tCandidate functions:\n\tconcat\(Utf8, ..\)
+statement error
SELECT concat();
# error message for wrong function signature (Uniform: t args all from some
common types)
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'nullif\(Int64\)'. You might need to add
explicit type casts.
+statement error
SELECT nullif(1);
-
# error message for wrong function signature (Exact: exact number of args of
an exact type)
-statement error Error during planning: No function matches the given name and
argument types 'pi\(Float64\)'. You might need to add explicit type
casts.\n\tCandidate functions:\n\tpi\(\)
+statement error
SELECT pi(3.14);
# error message for wrong function signature (Any: fixed number of args of
arbitrary types)
-statement error Error during planning: No function matches the given name and
argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit
type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\)
+statement error
SELECT arrow_typeof(1, 1);
# error message for wrong function signature (OneOf: fixed number of args of
arbitrary types)
-statement error Error during planning: No function matches the given name and
argument types 'power\(Int64, Int64, Int64\)'. You might need to add explicit
type casts.\n\tCandidate functions:\n\tpower\(Int64, Int64\)\n\tpower\(Float64,
Float64\)
+statement error
SELECT power(1, 2, 3);
# The following functions need 1 argument
-statement error Error during planning: No function matches the given name and
argument types 'abs\(\)'. You might need to add explicit type
casts.\n\tCandidate functions:\n\tabs\(Any\)
+statement error
SELECT abs();
-statement error Error during planning: No function matches the given name and
argument types 'acos\(\)'. You might need to add explicit type
casts.\n\tCandidate functions:\n\tacos\(Float64/Float32\)
+statement error
SELECT acos();
-statement error Error during planning: No function matches the given name and
argument types 'isnan\(\)'. You might need to add explicit type
casts.\n\tCandidate functions:\n\tisnan\(Float32\)\n\tisnan\(Float64\)
+statement error
SELECT isnan();
# turn off enable_ident_normalization
diff --git a/datafusion/sqllogictest/test_files/struct.slt
b/datafusion/sqllogictest/test_files/struct.slt
index 3e685cbb45..46a08709c3 100644
--- a/datafusion/sqllogictest/test_files/struct.slt
+++ b/datafusion/sqllogictest/test_files/struct.slt
@@ -92,7 +92,7 @@ physical_plan
02)--MemoryExec: partitions=1, partition_sizes=[1]
# error on 0 arguments
-query error DataFusion error: Error during planning: No function matches the
given name and argument types 'named_struct\(\)'. You might need to add
explicit type casts.
+query error
select named_struct();
# error on odd number of arguments #1
diff --git a/datafusion/sqllogictest/test_files/timestamps.slt
b/datafusion/sqllogictest/test_files/timestamps.slt
index 32a28231d0..13fb8fba0d 100644
--- a/datafusion/sqllogictest/test_files/timestamps.slt
+++ b/datafusion/sqllogictest/test_files/timestamps.slt
@@ -538,7 +538,7 @@ select to_timestamp_seconds(cast (1 as int));
##########
# invalid second arg type
-query error DataFusion error: Error during planning: No function matches the
given name and argument types 'date_bin\(Interval\(MonthDayNano\), Int64,
Timestamp\(Nanosecond, None\)\)'\.
+query error
SELECT DATE_BIN(INTERVAL '0 second', 25, TIMESTAMP '1970-01-01T00:00:00Z')
# not support interval 0
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]