This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ac5676f74 Minor: add the concise way for matching numerics (#5537)
ac5676f74 is described below
commit ac5676f74bfac89707642f9221d35899b7c2c321
Author: Igor Izvekov <[email protected]>
AuthorDate: Fri Mar 10 15:44:49 2023 +0300
Minor: add the concise way for matching numerics (#5537)
* Minor: add the concise way for matching numerics
* fix: use values of argument type instead of references
---
datafusion/expr/src/type_coercion/aggregates.rs | 172 +++++-------------------
1 file changed, 34 insertions(+), 138 deletions(-)
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index fca851ce6..3ad197afb 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -314,77 +314,45 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
/// function return type of variance
pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "VAR does not support {other:?}"
- ))),
+ if NUMERICS.contains(arg_type) {
+ Ok(DataType::Float64)
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "VAR does not support {arg_type:?}"
+ )))
}
}
/// function return type of covariance
pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "COVAR does not support {other:?}"
- ))),
+ if NUMERICS.contains(arg_type) {
+ Ok(DataType::Float64)
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "COVAR does not support {arg_type:?}"
+ )))
}
}
/// function return type of correlation
pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "CORR does not support {other:?}"
- ))),
+ if NUMERICS.contains(arg_type) {
+ Ok(DataType::Float64)
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "CORR does not support {arg_type:?}"
+ )))
}
}
/// function return type of standard deviation
pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "STDDEV does not support {other:?}"
- ))),
+ if NUMERICS.contains(arg_type) {
+ Ok(DataType::Float64)
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "STDDEV does not support {arg_type:?}"
+ )))
}
}
@@ -398,16 +366,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal128(new_precision, new_scale))
}
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
+ arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
other => Err(DataFusionError::Plan(format!(
"AVG does not support {other:?}"
))),
@@ -417,98 +376,44 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- | DataType::Decimal128(_, _)
+ arg_type if NUMERICS.contains(arg_type)
+ || matches!(arg_type, DataType::Decimal128(_, _))
)
}
pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- | DataType::Decimal128(_, _)
+ arg_type if NUMERICS.contains(arg_type)
+ || matches!(arg_type, DataType::Decimal128(_, _))
)
}
pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
+ arg_type if NUMERICS.contains(arg_type)
)
}
pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
+ arg_type if NUMERICS.contains(arg_type)
)
}
pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
+ arg_type if NUMERICS.contains(arg_type)
)
}
pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
+ arg_type if NUMERICS.contains(arg_type)
)
}
@@ -531,16 +436,7 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
+ arg_type if NUMERICS.contains(arg_type)
)
}