This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 71f9d0c1d6 Signature::Coercible with user defined implicit casting
(#14440)
71f9d0c1d6 is described below
commit 71f9d0c1d69c2404422f9f19bf78f0062c6b2c5e
Author: Jay Zhan <[email protected]>
AuthorDate: Mon Feb 17 21:18:07 2025 +0800
Signature::Coercible with user defined implicit casting (#14440)
* coerciblev2
Signed-off-by: Jay Zhan <[email protected]>
* repeat
Signed-off-by: Jay Zhan <[email protected]>
* fix possible types
* replace all coerciblev1
* cleanup
* remove specialize logic
* comment
* err msg
* ci escape
* rm coerciblev1
Signed-off-by: Jay Zhan <[email protected]>
* fmt
* rename
* rename
* refactor
* make default_casted_type private
* cleanup
* fmt
* integer
* rm binary for ascii
* rm unused
* conflict
* fmt
* Rename get_example_types, make method on TypeSignatureClass
* Move more logic into TypeSignatureClass
* fix docs
* 46
* enum
* fmt
* fmt
* doc
* upd doc
---------
Signed-off-by: Jay Zhan <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
Cargo.lock | 1 +
datafusion/catalog/src/information_schema.rs | 6 +-
datafusion/common/src/types/native.rs | 9 +
datafusion/expr-common/Cargo.toml | 1 +
datafusion/expr-common/src/signature.rs | 359 +++++++++++++++++++++----
datafusion/expr/src/type_coercion/functions.rs | 107 +++-----
datafusion/functions/src/datetime/date_part.rs | 28 +-
datafusion/functions/src/string/ascii.rs | 11 +-
datafusion/functions/src/string/repeat.rs | 13 +-
datafusion/sqllogictest/test_files/expr.slt | 12 +-
10 files changed, 403 insertions(+), 144 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 29d88d80aa..bc8b2943b2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2009,6 +2009,7 @@ version = "45.0.0"
dependencies = [
"arrow",
"datafusion-common",
+ "indexmap 2.7.1",
"itertools 0.14.0",
"paste",
]
diff --git a/datafusion/catalog/src/information_schema.rs
b/datafusion/catalog/src/information_schema.rs
index e68e636989..7948c0299d 100644
--- a/datafusion/catalog/src/information_schema.rs
+++ b/datafusion/catalog/src/information_schema.rs
@@ -405,7 +405,7 @@ fn get_udf_args_and_return_types(
udf: &Arc<ScalarUDF>,
) -> Result<Vec<(Vec<String>, Option<String>)>> {
let signature = udf.signature();
- let arg_types = signature.type_signature.get_possible_types();
+ let arg_types = signature.type_signature.get_example_types();
if arg_types.is_empty() {
Ok(vec![(vec![], None)])
} else {
@@ -428,7 +428,7 @@ fn get_udaf_args_and_return_types(
udaf: &Arc<AggregateUDF>,
) -> Result<Vec<(Vec<String>, Option<String>)>> {
let signature = udaf.signature();
- let arg_types = signature.type_signature.get_possible_types();
+ let arg_types = signature.type_signature.get_example_types();
if arg_types.is_empty() {
Ok(vec![(vec![], None)])
} else {
@@ -452,7 +452,7 @@ fn get_udwf_args_and_return_types(
udwf: &Arc<WindowUDF>,
) -> Result<Vec<(Vec<String>, Option<String>)>> {
let signature = udwf.signature();
- let arg_types = signature.type_signature.get_possible_types();
+ let arg_types = signature.type_signature.get_example_types();
if arg_types.is_empty() {
Ok(vec![(vec![], None)])
} else {
diff --git a/datafusion/common/src/types/native.rs
b/datafusion/common/src/types/native.rs
index a4c4dfc7b1..39c79b4b99 100644
--- a/datafusion/common/src/types/native.rs
+++ b/datafusion/common/src/types/native.rs
@@ -198,6 +198,11 @@ impl LogicalType for NativeType {
TypeSignature::Native(self)
}
+ /// Returns the default casted type for the given arrow type
+ ///
+ /// For types like String or Date, multiple arrow types map to the same
+ /// logical type. If the given arrow type is one of them, it is returned
+ /// unchanged; otherwise the default arrow type for this logical type is
+ /// returned.
fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
use DataType::*;
@@ -226,6 +231,10 @@ impl LogicalType for NativeType {
(Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s),
(Self::Decimal(p, s), _) => Decimal256(*p, *s),
(Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
+ // If given type is Date, return the same type
+ (Self::Date, origin) if matches!(origin, Date32 | Date64) => {
+ origin.to_owned()
+ }
(Self::Date, _) => Date32,
(Self::Time(tu), _) => match tu {
TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu),
diff --git a/datafusion/expr-common/Cargo.toml
b/datafusion/expr-common/Cargo.toml
index 109d8e0b89..abc78a9f08 100644
--- a/datafusion/expr-common/Cargo.toml
+++ b/datafusion/expr-common/Cargo.toml
@@ -39,5 +39,6 @@ path = "src/lib.rs"
[dependencies]
arrow = { workspace = true }
datafusion-common = { workspace = true }
+indexmap = { workspace = true }
itertools = { workspace = true }
paste = "^1.0"
diff --git a/datafusion/expr-common/src/signature.rs
b/datafusion/expr-common/src/signature.rs
index 4ca4961d7b..ba6fadbf72 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -19,11 +19,14 @@
//! and return types of functions in DataFusion.
use std::fmt::Display;
+use std::hash::Hash;
use crate::type_coercion::aggregates::NUMERICS;
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
-use datafusion_common::types::{LogicalTypeRef, NativeType};
+use datafusion_common::internal_err;
+use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType};
use datafusion_common::utils::ListCoercion;
+use indexmap::IndexSet;
use itertools::Itertools;
/// Constant that is used as a placeholder for any valid timezone.
@@ -127,12 +130,11 @@ pub enum TypeSignature {
Exact(Vec<DataType>),
/// One or more arguments belonging to the [`TypeSignatureClass`], in
order.
///
- /// For example, `Coercible(vec![logical_float64()])` accepts
- /// arguments like `vec![Int32]` or `vec![Float32]`
- /// since i32 and f32 can be cast to f64
+ /// Each [`Coercion`] specifies the desired type and, optionally, the types
+ /// that may be implicitly cast to it. For example, a function may expect a
+ /// string argument but also accept a binary argument, which is then cast to
+ /// a string.
///
/// For functions that take no arguments (e.g. `random()`) see
[`TypeSignature::Nullary`].
- Coercible(Vec<TypeSignatureClass>),
+ Coercible(Vec<Coercion>),
/// One or more arguments coercible to a single, comparable type.
///
/// Each argument will be coerced to a single type using the
@@ -209,14 +211,13 @@ impl TypeSignature {
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)]
pub enum TypeSignatureClass {
Timestamp,
- Date,
Time,
Interval,
Duration,
Native(LogicalTypeRef),
// TODO:
// Numeric
- // Integer
+ Integer,
}
impl Display for TypeSignatureClass {
@@ -225,6 +226,89 @@ impl Display for TypeSignatureClass {
}
}
+impl TypeSignatureClass {
+    /// Get example acceptable types for this `TypeSignatureClass`.
+    ///
+    /// This is used for `information_schema` and can be used to generate
+    /// documentation or error messages. Classes that cover several arrow
+    /// types return one or a few representative types, not an exhaustive list.
+    fn get_example_types(&self) -> Vec<DataType> {
+        match self {
+            TypeSignatureClass::Native(l) => get_data_types(l.native()),
+            TypeSignatureClass::Timestamp => {
+                vec![
+                    DataType::Timestamp(TimeUnit::Nanosecond, None),
+                    DataType::Timestamp(
+                        TimeUnit::Nanosecond,
+                        Some(TIMEZONE_WILDCARD.into()),
+                    ),
+                ]
+            }
+            TypeSignatureClass::Time => {
+                vec![DataType::Time64(TimeUnit::Nanosecond)]
+            }
+            TypeSignatureClass::Interval => {
+                vec![DataType::Interval(IntervalUnit::DayTime)]
+            }
+            TypeSignatureClass::Duration => {
+                vec![DataType::Duration(TimeUnit::Nanosecond)]
+            }
+            TypeSignatureClass::Integer => {
+                vec![DataType::Int64]
+            }
+        }
+    }
+
+    /// Does the specified `NativeType` match this type signature class?
+    ///
+    /// `NativeType::Null` matches every class.
+    pub fn matches_native_type(&self, logical_type: &NativeType) -> bool {
+        if logical_type == &NativeType::Null {
+            return true;
+        }
+
+        // Exhaustive on purpose: adding a new class must force a decision here
+        // (and in `default_casted_type`) instead of silently returning `false`.
+        match self {
+            TypeSignatureClass::Native(t) => t.native() == logical_type,
+            TypeSignatureClass::Timestamp => logical_type.is_timestamp(),
+            TypeSignatureClass::Time => logical_type.is_time(),
+            TypeSignatureClass::Interval => logical_type.is_interval(),
+            TypeSignatureClass::Duration => logical_type.is_duration(),
+            TypeSignatureClass::Integer => logical_type.is_integer(),
+        }
+    }
+
+    /// Returns the `DataType` that an argument of arrow type `origin_type`
+    /// (whose logical type is `native_type`) is cast to in order to satisfy
+    /// this class.
+    ///
+    /// Intended to be called after [`Self::matches_native_type`] returned
+    /// `true`. `Native` classes defer to the native type's `default_cast_for`.
+    /// The other classes keep `origin_type` unchanged; for example, a timestamp
+    /// keeps its unit and timezone.
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error if `native_type` does not belong to this class.
+    ///
+    /// NOTE(review): `matches_native_type` accepts `NativeType::Null` for every
+    /// class, but only the `Native` arm below handles it. A `NULL` argument for
+    /// `Timestamp`, `Time`, `Interval`, `Duration` or `Integer` therefore
+    /// reaches the internal error. Confirm whether that is intended.
+    pub fn default_casted_type(
+        &self,
+        native_type: &NativeType,
+        origin_type: &DataType,
+    ) -> datafusion_common::Result<DataType> {
+        match self {
+            TypeSignatureClass::Native(logical_type) => {
+                logical_type.native().default_cast_for(origin_type)
+            }
+            // If the given type is already a timestamp, we don't change the unit and timezone
+            TypeSignatureClass::Timestamp if native_type.is_timestamp() => {
+                Ok(origin_type.to_owned())
+            }
+            TypeSignatureClass::Time if native_type.is_time() => {
+                Ok(origin_type.to_owned())
+            }
+            TypeSignatureClass::Interval if native_type.is_interval() => {
+                Ok(origin_type.to_owned())
+            }
+            TypeSignatureClass::Duration if native_type.is_duration() => {
+                Ok(origin_type.to_owned())
+            }
+            TypeSignatureClass::Integer if native_type.is_integer() => {
+                Ok(origin_type.to_owned())
+            }
+            _ => internal_err!("May miss the matching logic in `matches_native_type`"),
+        }
+    }
+}
+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ArrayFunctionSignature {
/// A function takes at least one List/LargeList/FixedSizeList argument.
@@ -316,8 +400,8 @@ impl TypeSignature {
TypeSignature::Comparable(num) => {
vec![format!("Comparable({num})")]
}
- TypeSignature::Coercible(types) => {
- vec![Self::join_types(types, ", ")]
+ TypeSignature::Coercible(coercions) => {
+ vec![Self::join_types(coercions, ", ")]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
@@ -371,44 +455,45 @@ impl TypeSignature {
}
}
- /// get all possible types for the given `TypeSignature`
+ #[deprecated(since = "46.0.0", note = "See get_example_types instead")]
pub fn get_possible_types(&self) -> Vec<Vec<DataType>> {
+ self.get_example_types()
+ }
+
+ /// Return example acceptable types for this `TypeSignature`
+ ///
+ /// Returns a `Vec<DataType>` for each argument to the function
+ ///
+ /// This is used for `information_schema` and can be used to generate
+ /// documentation or error messages.
+ pub fn get_example_types(&self) -> Vec<Vec<DataType>> {
match self {
TypeSignature::Exact(types) => vec![types.clone()],
TypeSignature::OneOf(types) => types
.iter()
- .flat_map(|type_sig| type_sig.get_possible_types())
+ .flat_map(|type_sig| type_sig.get_example_types())
.collect(),
TypeSignature::Uniform(arg_count, types) => types
.iter()
.cloned()
.map(|data_type| vec![data_type; *arg_count])
.collect(),
- TypeSignature::Coercible(types) => types
+ TypeSignature::Coercible(coercions) => coercions
.iter()
- .map(|logical_type| match logical_type {
- TypeSignatureClass::Native(l) =>
get_data_types(l.native()),
- TypeSignatureClass::Timestamp => {
- vec![
- DataType::Timestamp(TimeUnit::Nanosecond, None),
- DataType::Timestamp(
- TimeUnit::Nanosecond,
- Some(TIMEZONE_WILDCARD.into()),
- ),
- ]
- }
- TypeSignatureClass::Date => {
- vec![DataType::Date64]
- }
- TypeSignatureClass::Time => {
- vec![DataType::Time64(TimeUnit::Nanosecond)]
- }
- TypeSignatureClass::Interval => {
- vec![DataType::Interval(IntervalUnit::DayTime)]
- }
- TypeSignatureClass::Duration => {
- vec![DataType::Duration(TimeUnit::Nanosecond)]
+ .map(|c| {
+ let mut all_types: IndexSet<DataType> =
+
c.desired_type().get_example_types().into_iter().collect();
+
+ if let Some(implicit_coercion) = c.implicit_coercion() {
+ let allowed_casts: Vec<DataType> = implicit_coercion
+ .allowed_source_types
+ .iter()
+ .flat_map(|t| t.get_example_types())
+ .collect();
+ all_types.extend(allowed_casts);
}
+
+ all_types.into_iter().collect::<Vec<_>>()
})
.multi_cartesian_product()
.collect(),
@@ -466,6 +551,186 @@ fn get_data_types(native_type: &NativeType) ->
Vec<DataType> {
}
}
+/// Represents type coercion rules for function arguments, specifying both the
desired type
+/// and optional implicit coercion rules for source types.
+///
+/// # Examples
+///
+/// ```
+/// use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
+/// use datafusion_common::types::{NativeType, logical_binary, logical_string};
+///
+/// // Exact coercion that only accepts timestamp types
+/// let exact = Coercion::new_exact(TypeSignatureClass::Timestamp);
+///
+/// // Implicit coercion that accepts string types but can coerce from binary
types
+/// let implicit = Coercion::new_implicit(
+/// TypeSignatureClass::Native(logical_string()),
+/// vec![TypeSignatureClass::Native(logical_binary())],
+/// NativeType::String
+/// );
+/// ```
+///
+/// There are two variants:
+///
+/// * `Exact` - Only accepts arguments that exactly match the desired type
+/// * `Implicit` - Accepts the desired type and can coerce from specified
source types
+#[derive(Debug, Clone, Eq, PartialOrd)]
+pub enum Coercion {
+ /// Coercion that only accepts arguments exactly matching the desired type.
+ Exact {
+ /// The required type for the argument
+ desired_type: TypeSignatureClass,
+ },
+
+ /// Coercion that accepts the desired type and can implicitly coerce from
other types.
+ Implicit {
+ /// The primary desired type for the argument
+ desired_type: TypeSignatureClass,
+ /// Rules for implicit coercion from other types
+ implicit_coercion: ImplicitCoercion,
+ },
+}
+
+impl Coercion {
+ pub fn new_exact(desired_type: TypeSignatureClass) -> Self {
+ Self::Exact { desired_type }
+ }
+
+ /// Create a new coercion with implicit coercion rules.
+ ///
+ /// `allowed_source_types` defines the possible types that can be coerced
to `desired_type`.
+ /// `default_casted_type` is the default type to be used for coercion if
we cast from other types via `allowed_source_types`.
+ pub fn new_implicit(
+ desired_type: TypeSignatureClass,
+ allowed_source_types: Vec<TypeSignatureClass>,
+ default_casted_type: NativeType,
+ ) -> Self {
+ Self::Implicit {
+ desired_type,
+ implicit_coercion: ImplicitCoercion {
+ allowed_source_types,
+ default_casted_type,
+ },
+ }
+ }
+
+ pub fn allowed_source_types(&self) -> &[TypeSignatureClass] {
+ match self {
+ Coercion::Exact { .. } => &[],
+ Coercion::Implicit {
+ implicit_coercion, ..
+ } => implicit_coercion.allowed_source_types.as_slice(),
+ }
+ }
+
+ pub fn default_casted_type(&self) -> Option<&NativeType> {
+ match self {
+ Coercion::Exact { .. } => None,
+ Coercion::Implicit {
+ implicit_coercion, ..
+ } => Some(&implicit_coercion.default_casted_type),
+ }
+ }
+
+ pub fn desired_type(&self) -> &TypeSignatureClass {
+ match self {
+ Coercion::Exact { desired_type } => desired_type,
+ Coercion::Implicit { desired_type, .. } => desired_type,
+ }
+ }
+
+ pub fn implicit_coercion(&self) -> Option<&ImplicitCoercion> {
+ match self {
+ Coercion::Exact { .. } => None,
+ Coercion::Implicit {
+ implicit_coercion, ..
+ } => Some(implicit_coercion),
+ }
+ }
+}
+
+impl Display for Coercion {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "Coercion({}", self.desired_type())?;
+ if let Some(implicit_coercion) = self.implicit_coercion() {
+ write!(f, ", implicit_coercion={implicit_coercion}",)
+ } else {
+ write!(f, ")")
+ }
+ }
+}
+
+impl PartialEq for Coercion {
+ fn eq(&self, other: &Self) -> bool {
+ self.desired_type() == other.desired_type()
+ && self.implicit_coercion() == other.implicit_coercion()
+ }
+}
+
+impl Hash for Coercion {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.desired_type().hash(state);
+ self.implicit_coercion().hash(state);
+ }
+}
+
+/// Defines rules for implicit type coercion, specifying which source types
can be
+/// coerced and the default type to use when coercing.
+///
+/// This is used by functions to specify which types they can accept via
implicit
+/// coercion in addition to their primary desired type.
+///
+/// # Examples
+///
+/// ```
+/// use arrow::datatypes::TimeUnit;
+///
+/// use datafusion_expr_common::signature::{Coercion, ImplicitCoercion,
TypeSignatureClass};
+/// use datafusion_common::types::{NativeType, logical_binary};
+///
+/// // Allow coercing from binary types to timestamp, coerce to specific
timestamp unit and timezone
+/// let implicit = Coercion::new_implicit(
+/// TypeSignatureClass::Timestamp,
+/// vec![TypeSignatureClass::Native(logical_binary())],
+/// NativeType::Timestamp(TimeUnit::Second, None),
+/// );
+/// ```
+#[derive(Debug, Clone, Eq, PartialOrd)]
+pub struct ImplicitCoercion {
+ /// The types that can be coerced from via implicit casting
+ allowed_source_types: Vec<TypeSignatureClass>,
+
+ /// The default type to use when coercing from allowed source types.
+ /// This is particularly important for types like Timestamp that have
multiple
+ /// possible configurations (different time units and timezones).
+ default_casted_type: NativeType,
+}
+
+impl Display for ImplicitCoercion {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "ImplicitCoercion({:?}, default_type={:?})",
+ self.allowed_source_types, self.default_casted_type
+ )
+ }
+}
+
+impl PartialEq for ImplicitCoercion {
+ fn eq(&self, other: &Self) -> bool {
+ self.allowed_source_types == other.allowed_source_types
+ && self.default_casted_type == other.default_casted_type
+ }
+}
+
+impl Hash for ImplicitCoercion {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.allowed_source_types.hash(state);
+ self.default_casted_type.hash(state);
+ }
+}
+
/// Defines the supported argument types ([`TypeSignature`]) and
[`Volatility`] for a function.
///
/// DataFusion will automatically coerce (cast) argument types to one of the
supported
@@ -542,11 +807,9 @@ impl Signature {
volatility,
}
}
+
/// Target coerce types in order
- pub fn coercible(
- target_types: Vec<TypeSignatureClass>,
- volatility: Volatility,
- ) -> Self {
+ pub fn coercible(target_types: Vec<Coercion>, volatility: Volatility) ->
Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
@@ -721,14 +984,14 @@ mod tests {
#[test]
fn test_get_possible_types() {
let type_signature = TypeSignature::Exact(vec![DataType::Int32,
DataType::Int64]);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(possible_types, vec![vec![DataType::Int32,
DataType::Int64]]);
let type_signature = TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Int32, DataType::Int64]),
TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]),
]);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
@@ -742,7 +1005,7 @@ mod tests {
TypeSignature::Exact(vec![DataType::Float32, DataType::Float64]),
TypeSignature::Exact(vec![DataType::Utf8]),
]);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
@@ -754,7 +1017,7 @@ mod tests {
let type_signature =
TypeSignature::Uniform(2, vec![DataType::Float32,
DataType::Int64]);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
@@ -764,10 +1027,10 @@ mod tests {
);
let type_signature = TypeSignature::Coercible(vec![
- TypeSignatureClass::Native(logical_string()),
- TypeSignatureClass::Native(logical_int64()),
+ Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+ Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
]);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
@@ -779,14 +1042,14 @@ mod tests {
let type_signature =
TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![vec![DataType::Int32], vec![DataType::Int64]]
);
let type_signature = TypeSignature::Numeric(2);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
@@ -804,7 +1067,7 @@ mod tests {
);
let type_signature = TypeSignature::String(2);
- let possible_types = type_signature.get_possible_types();
+ let possible_types = type_signature.get_example_types();
assert_eq!(
possible_types,
vec![
diff --git a/datafusion/expr/src/type_coercion/functions.rs
b/datafusion/expr/src/type_coercion/functions.rs
index 7fda92862b..b471feca04 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -21,19 +21,15 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
+use datafusion_common::types::LogicalType;
use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
use datafusion_common::{
- exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
- types::{LogicalType, NativeType},
- utils::list_ndims,
- Result,
+ exec_err, internal_datafusion_err, internal_err, plan_err,
types::NativeType,
+ utils::list_ndims, Result,
};
use datafusion_expr_common::signature::ArrayFunctionArgument;
use datafusion_expr_common::{
- signature::{
- ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD,
- TIMEZONE_WILDCARD,
- },
+ signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD,
TIMEZONE_WILDCARD},
type_coercion::binary::comparison_coercion_numeric,
type_coercion::binary::string_coercion,
};
@@ -604,75 +600,36 @@ fn get_valid_types(
vec![vec![target_type; *num]]
}
}
- TypeSignature::Coercible(target_types) => {
- function_length_check(
- function_name,
- current_types.len(),
- target_types.len(),
- )?;
-
- // Aim to keep this logic as SIMPLE as possible!
- // Make sure the corresponding test is covered
- // If this function becomes COMPLEX, create another new signature!
- fn can_coerce_to(
- function_name: &str,
- current_type: &DataType,
- target_type_class: &TypeSignatureClass,
- ) -> Result<DataType> {
- let logical_type: NativeType = current_type.into();
-
- match target_type_class {
- TypeSignatureClass::Native(native_type) => {
- let target_type = native_type.native();
- if &logical_type == target_type {
- return target_type.default_cast_for(current_type);
- }
-
- if logical_type == NativeType::Null {
- return target_type.default_cast_for(current_type);
- }
-
- if target_type.is_integer() &&
logical_type.is_integer() {
- return target_type.default_cast_for(current_type);
- }
-
- internal_err!(
- "Function '{function_name}' expects
{target_type_class} but received {current_type}"
- )
- }
- // Not consistent with Postgres and DuckDB but to avoid
regression we implicit cast string to timestamp
- TypeSignatureClass::Timestamp
- if logical_type == NativeType::String =>
- {
- Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
- }
- TypeSignatureClass::Timestamp if
logical_type.is_timestamp() => {
- Ok(current_type.to_owned())
- }
- TypeSignatureClass::Date if logical_type.is_date() => {
- Ok(current_type.to_owned())
- }
- TypeSignatureClass::Time if logical_type.is_time() => {
- Ok(current_type.to_owned())
- }
- TypeSignatureClass::Interval if logical_type.is_interval()
=> {
- Ok(current_type.to_owned())
- }
- TypeSignatureClass::Duration if logical_type.is_duration()
=> {
- Ok(current_type.to_owned())
- }
- _ => {
- not_impl_err!("Function '{function_name}' got
logical_type: {logical_type} with target_type_class: {target_type_class}")
- }
- }
- }
+ TypeSignature::Coercible(param_types) => {
+ function_length_check(function_name, current_types.len(),
param_types.len())?;
let mut new_types = Vec::with_capacity(current_types.len());
- for (current_type, target_type_class) in
- current_types.iter().zip(target_types.iter())
- {
- let target_type = can_coerce_to(function_name, current_type,
target_type_class)?;
- new_types.push(target_type);
+ for (current_type, param) in
current_types.iter().zip(param_types.iter()) {
+ let current_native_type: NativeType = current_type.into();
+
+ if
param.desired_type().matches_native_type(¤t_native_type) {
+ let casted_type = param.desired_type().default_casted_type(
+ ¤t_native_type,
+ current_type,
+ )?;
+
+ new_types.push(casted_type);
+ } else if param
+ .allowed_source_types()
+ .iter()
+ .any(|t| t.matches_native_type(¤t_native_type)) {
+ // A non-empty `allowed_source_types` only exists for `Coercion::Implicit`, so `default_casted_type()` is `Some` here and the unwrap is safe
+ let default_casted_type =
param.default_casted_type().unwrap();
+ let casted_type =
default_casted_type.default_cast_for(current_type)?;
+ new_types.push(casted_type);
+ } else {
+ return internal_err!(
+ "Expect {} but received {}, DataType: {}",
+ param.desired_type(),
+ current_native_type,
+ current_type
+ );
+ }
}
vec![new_types]
diff --git a/datafusion/functions/src/datetime/date_part.rs
b/datafusion/functions/src/datetime/date_part.rs
index 9df91da67f..49b7a4ec46 100644
--- a/datafusion/functions/src/datetime/date_part.rs
+++ b/datafusion/functions/src/datetime/date_part.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::DataType::{
};
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{DataType, TimeUnit};
+use datafusion_common::types::{logical_date, NativeType};
use datafusion_common::{
cast::{
@@ -44,7 +45,7 @@ use datafusion_expr::{
ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl,
Signature,
TypeSignature, Volatility,
};
-use datafusion_expr_common::signature::TypeSignatureClass;
+use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
use datafusion_macros::user_doc;
#[user_doc(
@@ -95,24 +96,29 @@ impl DatePartFunc {
signature: Signature::one_of(
vec![
TypeSignature::Coercible(vec![
- TypeSignatureClass::Native(logical_string()),
- TypeSignatureClass::Timestamp,
+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+ Coercion::new_implicit(
+ TypeSignatureClass::Timestamp,
+ // Not consistent with Postgres and DuckDB but to
avoid regression we implicit cast string to timestamp
+ vec![TypeSignatureClass::Native(logical_string())],
+ NativeType::Timestamp(Nanosecond, None),
+ ),
]),
TypeSignature::Coercible(vec![
- TypeSignatureClass::Native(logical_string()),
- TypeSignatureClass::Date,
+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+
Coercion::new_exact(TypeSignatureClass::Native(logical_date())),
]),
TypeSignature::Coercible(vec![
- TypeSignatureClass::Native(logical_string()),
- TypeSignatureClass::Time,
+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+ Coercion::new_exact(TypeSignatureClass::Time),
]),
TypeSignature::Coercible(vec![
- TypeSignatureClass::Native(logical_string()),
- TypeSignatureClass::Interval,
+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+ Coercion::new_exact(TypeSignatureClass::Interval),
]),
TypeSignature::Coercible(vec![
- TypeSignatureClass::Native(logical_string()),
- TypeSignatureClass::Duration,
+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+ Coercion::new_exact(TypeSignatureClass::Duration),
]),
],
Volatility::Immutable,
diff --git a/datafusion/functions/src/string/ascii.rs
b/datafusion/functions/src/string/ascii.rs
index 858eddc7c8..3832ad2a34 100644
--- a/datafusion/functions/src/string/ascii.rs
+++ b/datafusion/functions/src/string/ascii.rs
@@ -19,9 +19,11 @@ use crate::utils::make_scalar_function;
use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
+use datafusion_common::types::logical_string;
use datafusion_common::{internal_err, Result};
-use datafusion_expr::{ColumnarValue, Documentation};
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr_common::signature::Coercion;
use datafusion_macros::user_doc;
use std::any::Any;
use std::sync::Arc;
@@ -61,7 +63,12 @@ impl Default for AsciiFunc {
impl AsciiFunc {
pub fn new() -> Self {
Self {
- signature: Signature::string(1, Volatility::Immutable),
+ signature: Signature::coercible(
+ vec![Coercion::new_exact(TypeSignatureClass::Native(
+ logical_string(),
+ ))],
+ Volatility::Immutable,
+ ),
}
}
}
diff --git a/datafusion/functions/src/string/repeat.rs
b/datafusion/functions/src/string/repeat.rs
index 8253754c2b..8fdbc3dd29 100644
--- a/datafusion/functions/src/string/repeat.rs
+++ b/datafusion/functions/src/string/repeat.rs
@@ -26,11 +26,11 @@ use arrow::array::{
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
use datafusion_common::cast::as_int64_array;
-use datafusion_common::types::{logical_int64, logical_string};
+use datafusion_common::types::{logical_int64, logical_string, NativeType};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
-use datafusion_expr_common::signature::TypeSignatureClass;
+use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
use datafusion_macros::user_doc;
#[user_doc(
@@ -67,8 +67,13 @@ impl RepeatFunc {
Self {
signature: Signature::coercible(
vec![
- TypeSignatureClass::Native(logical_string()),
- TypeSignatureClass::Native(logical_int64()),
+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+ // Accept all integer types but cast them to i64
+ Coercion::new_implicit(
+ TypeSignatureClass::Native(logical_int64()),
+ vec![TypeSignatureClass::Integer],
+ NativeType::Int64,
+ ),
],
Volatility::Immutable,
),
diff --git a/datafusion/sqllogictest/test_files/expr.slt
b/datafusion/sqllogictest/test_files/expr.slt
index a0264c4362..7980b180ae 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -324,6 +324,16 @@ SELECT ascii('x')
----
120
+query I
+SELECT ascii('222')
+----
+50
+
+query I
+SELECT ascii('0xa')
+----
+48
+
query I
SELECT ascii(NULL)
----
@@ -571,7 +581,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32'));
----
-1.2-1.2-1.2
-query error DataFusion error: Error during planning: Internal error: Function
'repeat' expects TypeSignatureClass::Native\(LogicalType\(Native\(Int64\),
Int64\)\) but received Float64
+query error DataFusion error: Error during planning: Internal error: Expect
TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but
received NativeType::Float64, DataType: Float64
select repeat('-1.2', 3.2);
query T
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]