comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r878835961


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => 
{{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> 
{
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, 
s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, 
&s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, 
&s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, 
"f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, 
"f32");

Review Comment:
   I came up with an idea like the one below: the boilerplate is written only
   once, and external functions are then applied to the values.
   
   ```
   use std::ops::Sub;
   use std::ops::Add;
   use std::fmt::Debug;
   use std::any::Any;
   
   /// Minimal stand-in for a typed scalar value, used only by this example.
   ///
   /// Each variant wraps an `Option` of the corresponding primitive type.
   #[derive(Debug, Copy, Clone)]
   enum DataType {
       /// Optional 32-bit signed integer.
       Int32(Option<i32>),
       /// Optional 64-bit float.
       Float64(Option<f64>)
   }
   
   // Boilerplate lives here: the dispatch on the operand variants and the
   // coercion to a common type are written once, and the operation itself is
   // passed in from outside.
   //
   // `$ARG1` / `$ARG2` are `DataType` operands; `$FUNC` is a block that
   // evaluates to a two-argument callable, e.g. `{sum}`.
   macro_rules! op {
       ($ARG1: expr, $ARG2: expr, $FUNC: block) => {{
           match ($ARG1, $ARG2) {
               // Int32 with Float64: both operands are converted to f64 before
               // `$FUNC` is applied, and the result is a Float64.
               (DataType::Int32(Some(v1)), DataType::Float64(Some(v2))) => {
                   DataType::Float64(Some($FUNC(v1 as f64, v2 as f64)))
               }
               // Every other pairing (including `None` operands) is not
               // handled by this sketch.
               _ => panic!("op!: unsupported operand combination"),
           }
       }};
   }
   
   /// Adds two values of the same type and returns the result.
   ///
   /// Meant to be handed to `op!` as the operation to apply.
   fn sum<T: Add<Output = T> + Copy>(num1: T, num2: T) -> T {
       num1 + num2
   }
   
   /// Subtracts `num2` from `num1` and returns the result.
   ///
   /// Meant to be handed to `op!` as the operation to apply.
   fn minus<T: Sub<Output = T> + Copy>(num1: T, num2: T) -> T {
       num1 - num2
   }
   
   /// Demo entry point: applies two different operations through the same `op!` macro.
   fn main() {
       // One operand of each variant, matching the Int32/Float64 arm of `op!`.
       let i_32 = DataType::Int32(Some(1));
       let f_64 = DataType::Float64(Some(2.0));
   
       // Same macro, same operands: only the function passed in changes.
       let s = op!(i_32, f_64, {sum});
       dbg!(&s);
       
       let m = op!(i_32, f_64, {minus});
       dbg!(&m);
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to