This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ac5676f74 Minor: add the concise way for matching numerics (#5537)
ac5676f74 is described below

commit ac5676f74bfac89707642f9221d35899b7c2c321
Author: Igor Izvekov <[email protected]>
AuthorDate: Fri Mar 10 15:44:49 2023 +0300

    Minor: add the concise way for matching numerics (#5537)
    
    * Minor: add the concise way for matching numerics
    
    * fix: use values of argument type except links
---
 datafusion/expr/src/type_coercion/aggregates.rs | 172 +++++-------------------
 1 file changed, 34 insertions(+), 138 deletions(-)

diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index fca851ce6..3ad197afb 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -314,77 +314,45 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
 
 /// function return type of variance
 pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Int8
-        | DataType::Int16
-        | DataType::Int32
-        | DataType::Int64
-        | DataType::UInt8
-        | DataType::UInt16
-        | DataType::UInt32
-        | DataType::UInt64
-        | DataType::Float32
-        | DataType::Float64 => Ok(DataType::Float64),
-        other => Err(DataFusionError::Plan(format!(
-            "VAR does not support {other:?}"
-        ))),
+    if NUMERICS.contains(arg_type) {
+        Ok(DataType::Float64)
+    } else {
+        Err(DataFusionError::Plan(format!(
+            "VAR does not support {arg_type:?}"
+        )))
     }
 }
 
 /// function return type of covariance
 pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Int8
-        | DataType::Int16
-        | DataType::Int32
-        | DataType::Int64
-        | DataType::UInt8
-        | DataType::UInt16
-        | DataType::UInt32
-        | DataType::UInt64
-        | DataType::Float32
-        | DataType::Float64 => Ok(DataType::Float64),
-        other => Err(DataFusionError::Plan(format!(
-            "COVAR does not support {other:?}"
-        ))),
+    if NUMERICS.contains(arg_type) {
+        Ok(DataType::Float64)
+    } else {
+        Err(DataFusionError::Plan(format!(
+            "COVAR does not support {arg_type:?}"
+        )))
     }
 }
 
 /// function return type of correlation
 pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Int8
-        | DataType::Int16
-        | DataType::Int32
-        | DataType::Int64
-        | DataType::UInt8
-        | DataType::UInt16
-        | DataType::UInt32
-        | DataType::UInt64
-        | DataType::Float32
-        | DataType::Float64 => Ok(DataType::Float64),
-        other => Err(DataFusionError::Plan(format!(
-            "CORR does not support {other:?}"
-        ))),
+    if NUMERICS.contains(arg_type) {
+        Ok(DataType::Float64)
+    } else {
+        Err(DataFusionError::Plan(format!(
+            "CORR does not support {arg_type:?}"
+        )))
     }
 }
 
 /// function return type of standard deviation
 pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Int8
-        | DataType::Int16
-        | DataType::Int32
-        | DataType::Int64
-        | DataType::UInt8
-        | DataType::UInt16
-        | DataType::UInt32
-        | DataType::UInt64
-        | DataType::Float32
-        | DataType::Float64 => Ok(DataType::Float64),
-        other => Err(DataFusionError::Plan(format!(
-            "STDDEV does not support {other:?}"
-        ))),
+    if NUMERICS.contains(arg_type) {
+        Ok(DataType::Float64)
+    } else {
+        Err(DataFusionError::Plan(format!(
+            "STDDEV does not support {arg_type:?}"
+        )))
     }
 }
 
@@ -398,16 +366,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
             let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
             Ok(DataType::Decimal128(new_precision, new_scale))
         }
-        DataType::Int8
-        | DataType::Int16
-        | DataType::Int32
-        | DataType::Int64
-        | DataType::UInt8
-        | DataType::UInt16
-        | DataType::UInt32
-        | DataType::UInt64
-        | DataType::Float32
-        | DataType::Float64 => Ok(DataType::Float64),
+        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
         other => Err(DataFusionError::Plan(format!(
             "AVG does not support {other:?}"
         ))),
@@ -417,98 +376,44 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
 pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Float32
-            | DataType::Float64
-            | DataType::Decimal128(_, _)
+        arg_type if NUMERICS.contains(arg_type)
+        || matches!(arg_type, DataType::Decimal128(_, _))
     )
 }
 
 pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Float32
-            | DataType::Float64
-            | DataType::Decimal128(_, _)
+        arg_type if NUMERICS.contains(arg_type)
+            || matches!(arg_type, DataType::Decimal128(_, _))
     )
 }
 
 pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Float32
-            | DataType::Float64
+        arg_type if NUMERICS.contains(arg_type)
     )
 }
 
 pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Float32
-            | DataType::Float64
+        arg_type if NUMERICS.contains(arg_type)
     )
 }
 
 pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Float32
-            | DataType::Float64
+        arg_type if NUMERICS.contains(arg_type)
     )
 }
 
 pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Float32
-            | DataType::Float64
+        arg_type if NUMERICS.contains(arg_type)
     )
 }
 
@@ -531,16 +436,7 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
 pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Float32
-            | DataType::Float64
+        arg_type if NUMERICS.contains(arg_type)
     )
 }
 

Reply via email to