Jefffrey commented on code in PR #19593:
URL: https://github.com/apache/datafusion/pull/19593#discussion_r2656420842


##########
datafusion/functions-aggregate/src/sum.rs:
##########
@@ -233,56 +248,126 @@ impl AggregateUDFImpl for Sum {
                 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 
10);
                 Ok(DataType::Decimal256(new_precision, *scale))
             }
-            DataType::Duration(time_unit) => 
Ok(DataType::Duration(*time_unit)),
-            other => {
-                exec_err!("[return_type] SUM not supported for {}", other)
-            }
+            dt => Ok(dt.clone()),
         }
     }
 
     fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
-        if args.is_distinct {
-            macro_rules! helper {
-                ($t:ty, $dt:expr) => {
-                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
-                };
+        if args.expr_fields[0].data_type() == &DataType::Null {
+            return 
Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None))));
+        }

Review Comment:
   Ditto



##########
datafusion/functions-aggregate/src/sum.rs:
##########
@@ -396,6 +488,155 @@ impl<T: ArrowNumericType> Accumulator for 
SumAccumulator<T> {
     }
 }
 
+#[derive(Debug, Eq, PartialEq)]
+enum TrySumState<T: ArrowNativeTypeOp> {
+    Initial,
+    ValidSum(T),
+    Overflow,
+}
+
+/// Will return `NULL` if at any point the sum overflows.
+#[derive(Debug)]
+struct TrySumAccumulator<T: ArrowNumericType + std::fmt::Debug> {
+    state: TrySumState<T::Native>,
+    data_type: DataType,
+}
+
+impl<T: ArrowNumericType + std::fmt::Debug> TrySumAccumulator<T> {
+    fn new(data_type: DataType) -> Self {
+        Self {
+            state: TrySumState::Initial,
+            data_type,
+        }
+    }
+}
+
+impl<T: ArrowNumericType + std::fmt::Debug> Accumulator for 
TrySumAccumulator<T> {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        match self.state {
+            TrySumState::Initial => Ok(vec![
+                ScalarValue::try_new_null(&self.data_type)?,
+                ScalarValue::from(false),
+            ]),
+            TrySumState::ValidSum(sum) => Ok(vec![
+                ScalarValue::new_primitive::<T>(Some(sum), &self.data_type)?,
+                ScalarValue::from(false),
+            ]),
+            TrySumState::Overflow => Ok(vec![
+                ScalarValue::try_new_null(&self.data_type)?,
+                ScalarValue::from(true),
+            ]),
+        }
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let v = match self.state {
+            TrySumState::Initial => T::Native::ZERO,
+            TrySumState::ValidSum(sum) => sum,
+            TrySumState::Overflow => return Ok(()),
+        };
+        let values = values[0].as_primitive::<T>();
+        match arrow::compute::sum_checked(values) {
+            Ok(Some(x)) => match v.add_checked(x) {
+                Ok(sum) => {
+                    self.state = TrySumState::ValidSum(sum);
+                }
+                Err(ArrowError::ArithmeticOverflow(_)) => {
+                    self.state = TrySumState::Overflow;
+                }
+                Err(e) => {
+                    return Err(e.into());
+                }

Review Comment:
   Technically these `Err` arms are unreachable, as I think only the 
`ArithmeticOverflow` error variant is returned by `add_checked`, but I'm keeping them just in case.



##########
datafusion/functions-aggregate/src/sum.rs:
##########
@@ -212,9 +229,7 @@ impl AggregateUDFImpl for Sum {
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         match &arg_types[0] {
-            DataType::Int64 => Ok(DataType::Int64),
-            DataType::UInt64 => Ok(DataType::UInt64),
-            DataType::Float64 => Ok(DataType::Float64),
+            DataType::Null => Ok(DataType::Float64),

Review Comment:
   Spark's `try_sum` had a test case where null inputs return a null of type double, 
so I'm adding that here.



##########
datafusion/functions-aggregate/src/sum.rs:
##########
@@ -233,56 +248,126 @@ impl AggregateUDFImpl for Sum {
                 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 
10);
                 Ok(DataType::Decimal256(new_precision, *scale))
             }
-            DataType::Duration(time_unit) => 
Ok(DataType::Duration(*time_unit)),
-            other => {
-                exec_err!("[return_type] SUM not supported for {}", other)
-            }
+            dt => Ok(dt.clone()),
         }
     }
 
     fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
-        if args.is_distinct {
-            macro_rules! helper {
-                ($t:ty, $dt:expr) => {
-                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
-                };
+        if args.expr_fields[0].data_type() == &DataType::Null {
+            return 
Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None))));
+        }
+        match (args.is_distinct, self.try_sum_mode) {
+            (true, false) => {
+                macro_rules! helper {
+                    ($t:ty, $dt:expr) => {
+                        Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
+                    };
+                }
+                downcast_sum!(args, helper)
             }
-            downcast_sum!(args, helper)
-        } else {
-            macro_rules! helper {
-                ($t:ty, $dt:expr) => {
-                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
-                };
+            (false, false) => {
+                macro_rules! helper {
+                    ($t:ty, $dt:expr) => {
+                        Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
+                    };
+                }
+                downcast_sum!(args, helper)
+            }
+            (false, true) => {
+                match args.return_type() {
+                    DataType::UInt64 => Ok(Box::new(
+                        TrySumAccumulator::<UInt64Type>::new(DataType::UInt64),
+                    )),
+                    DataType::Int64 => 
Ok(Box::new(TrySumAccumulator::<Int64Type>::new(
+                        DataType::Int64,
+                    ))),
+                    DataType::Float64 => Ok(Box::new(
+                        
TrySumAccumulator::<Float64Type>::new(DataType::Float64),
+                    )),
+                    DataType::Duration(TimeUnit::Second) => {
+                        
Ok(Box::new(TrySumAccumulator::<DurationSecondType>::new(
+                            DataType::Duration(TimeUnit::Second),
+                        )))
+                    }
+                    DataType::Duration(TimeUnit::Millisecond) => {
+                        
Ok(Box::new(TrySumAccumulator::<DurationMillisecondType>::new(
+                            DataType::Duration(TimeUnit::Millisecond),
+                        )))
+                    }
+                    DataType::Duration(TimeUnit::Microsecond) => {
+                        
Ok(Box::new(TrySumAccumulator::<DurationMicrosecondType>::new(
+                            DataType::Duration(TimeUnit::Microsecond),
+                        )))
+                    }
+                    DataType::Duration(TimeUnit::Nanosecond) => {
+                        
Ok(Box::new(TrySumAccumulator::<DurationNanosecondType>::new(
+                            DataType::Duration(TimeUnit::Nanosecond),
+                        )))
+                    }
+                    dt @ DataType::Decimal32(..) => Ok(Box::new(
+                        
TrySumDecimalAccumulator::<Decimal32Type>::new(dt.clone()),
+                    )),
+                    dt @ DataType::Decimal64(..) => Ok(Box::new(
+                        
TrySumDecimalAccumulator::<Decimal64Type>::new(dt.clone()),
+                    )),
+                    dt @ DataType::Decimal128(..) => Ok(Box::new(
+                        
TrySumDecimalAccumulator::<Decimal128Type>::new(dt.clone()),
+                    )),
+                    dt @ DataType::Decimal256(..) => Ok(Box::new(
+                        
TrySumDecimalAccumulator::<Decimal256Type>::new(dt.clone()),
+                    )),
+                    dt => internal_err!("Unsupported datatype for sum: {dt}"),
+                }
+            }
+            (true, true) => {
+                not_impl_err!("Try sum mode not supported for distinct sum 
accumulators")
             }
-            downcast_sum!(args, helper)
         }
     }
 
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
-        if args.is_distinct {
-            Ok(vec![
-                Field::new_list(
-                    format_state_name(args.name, "sum distinct"),
-                    // See COMMENTS.md to understand why nullable is set to 
true
-                    Field::new_list_field(args.return_type().clone(), true),
-                    false,
+        match (args.is_distinct, self.try_sum_mode) {

Review Comment:
   It's a bit ugly here; it would be nice if we managed to make accumulators own 
their state fields. See: 
https://github.com/apache/datafusion/issues/14701#issuecomment-3419742933



##########
datafusion/functions-aggregate/src/sum.rs:
##########
@@ -336,6 +424,10 @@ impl AggregateUDFImpl for Sum {
     }
 
     fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity {
+        // Can overflow into null
+        if self.try_sum_mode {
+            return SetMonotonicity::NotMonotonic;
+        }

Review Comment:
   I'm not certain about this 🤔 



##########
datafusion/functions-aggregate/src/sum.rs:
##########
@@ -396,6 +488,155 @@ impl<T: ArrowNumericType> Accumulator for 
SumAccumulator<T> {
     }
 }
 
+#[derive(Debug, Eq, PartialEq)]
+enum TrySumState<T: ArrowNativeTypeOp> {
+    Initial,
+    ValidSum(T),
+    Overflow,
+}
+
+/// Will return `NULL` if at any point the sum overflows.
+#[derive(Debug)]
+struct TrySumAccumulator<T: ArrowNumericType + std::fmt::Debug> {
+    state: TrySumState<T::Native>,
+    data_type: DataType,
+}
+
+impl<T: ArrowNumericType + std::fmt::Debug> TrySumAccumulator<T> {
+    fn new(data_type: DataType) -> Self {
+        Self {
+            state: TrySumState::Initial,
+            data_type,
+        }
+    }
+}
+
+impl<T: ArrowNumericType + std::fmt::Debug> Accumulator for 
TrySumAccumulator<T> {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        match self.state {
+            TrySumState::Initial => Ok(vec![
+                ScalarValue::try_new_null(&self.data_type)?,
+                ScalarValue::from(false),
+            ]),
+            TrySumState::ValidSum(sum) => Ok(vec![
+                ScalarValue::new_primitive::<T>(Some(sum), &self.data_type)?,
+                ScalarValue::from(false),
+            ]),
+            TrySumState::Overflow => Ok(vec![
+                ScalarValue::try_new_null(&self.data_type)?,
+                ScalarValue::from(true),
+            ]),
+        }
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let v = match self.state {
+            TrySumState::Initial => T::Native::ZERO,
+            TrySumState::ValidSum(sum) => sum,
+            TrySumState::Overflow => return Ok(()),
+        };
+        let values = values[0].as_primitive::<T>();
+        match arrow::compute::sum_checked(values) {
+            Ok(Some(x)) => match v.add_checked(x) {
+                Ok(sum) => {
+                    self.state = TrySumState::ValidSum(sum);
+                }
+                Err(ArrowError::ArithmeticOverflow(_)) => {
+                    self.state = TrySumState::Overflow;
+                }
+                Err(e) => {
+                    return Err(e.into());
+                }
+            },
+            Ok(None) => (),
+            Err(ArrowError::ArithmeticOverflow(_)) => {
+                self.state = TrySumState::Overflow;
+            }
+            Err(e) => {
+                return Err(e.into());
+            }
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let other_batch_failed = states[1].as_boolean().value(0);
+        if other_batch_failed {
+            self.state = TrySumState::Overflow;
+            return Ok(());
+        }
+        self.update_batch(states)
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        match self.state {
+            TrySumState::Initial | TrySumState::Overflow => {
+                ScalarValue::try_new_null(&self.data_type)
+            }
+            TrySumState::ValidSum(sum) => {
+                ScalarValue::new_primitive::<T>(Some(sum), &self.data_type)
+            }
+        }
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self)
+    }
+}
+
+// Only difference from TrySumAccumulator is that it verifies the resulting sum
+// can fit within the decimals precision; if Rust had specialization we could 
unify
+// the two types (╥﹏╥)
+#[derive(Debug)]
+struct TrySumDecimalAccumulator<T: DecimalType + std::fmt::Debug> {
+    inner: TrySumAccumulator<T>,
+}

Review Comment:
   Open to any ideas for unifying these structs.



##########
datafusion/spark/src/function/aggregate/try_sum.rs:
##########
@@ -15,22 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::{ArrayRef, ArrowNumericType, AsArray, BooleanArray, 
PrimitiveArray};
-use arrow::datatypes::{
-    DECIMAL128_MAX_PRECISION, DataType, Decimal128Type, Field, FieldRef, 
Float64Type,
-    Int64Type,
-};
-use datafusion_common::{Result, ScalarValue, downcast_value, exec_err, 
not_impl_err};
+use arrow::datatypes::{DataType, FieldRef};
+use datafusion_common::Result;
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::utils::format_state_name;
-use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature};
+use datafusion_functions_aggregate::sum::Sum;
 use std::any::Any;
-use std::fmt::{Debug, Formatter};
-use std::mem::size_of_val;
+use std::fmt::Debug;
 
-#[derive(PartialEq, Eq, Hash)]
+/// Thin wrapper over DataFusion native [`Sum`] which is configurable into a 
try
+/// sum mode to return `null` on overflows. We need this thin wrapper to 
provide
+/// the `try_sum` named function for use in Spark.
+#[derive(PartialEq, Eq, Hash, Debug)]

Review Comment:
   It would be nice if we had a simpler way to write these kinds of thin wrappers 🤔 



##########
datafusion/spark/src/function/aggregate/try_sum.rs:
##########
@@ -257,404 +55,18 @@ impl AggregateUDFImpl for SparkTrySum {
     }
 
     fn signature(&self) -> &Signature {
-        &self.signature
+        self.inner.signature()
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        use DataType::*;
-
-        let dt = &arg_types[0];
-        let result_type = match dt {
-            Null => Float64,
-            Decimal128(p, s) => {
-                let new_precision = DECIMAL128_MAX_PRECISION.min(p + 10);
-                Decimal128(new_precision, *s)
-            }
-            Int8 | Int16 | Int32 | Int64 => Int64,
-            Float16 | Float32 | Float64 => Float64,
-
-            other => return exec_err!("try_sum: unsupported type: {other:?}"),
-        };
-
-        Ok(result_type)
+        self.inner.return_type(arg_types)
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
-        macro_rules! helper {
-            ($t:ty, $dt:expr) => {
-                Ok(Box::new(TrySumAccumulator::<$t>::new($dt.clone())))
-            };
-        }
-
-        match acc_args.return_field.data_type() {
-            DataType::Int64 => helper!(Int64Type, 
acc_args.return_field.data_type()),
-            DataType::Float64 => helper!(Float64Type, 
acc_args.return_field.data_type()),
-            DataType::Decimal128(_, _) => {
-                helper!(Decimal128Type, acc_args.return_field.data_type())
-            }
-            _ => not_impl_err!(
-                "try_sum: unsupported type for accumulator: {}",
-                acc_args.return_field.data_type()
-            ),
-        }
+        self.inner.accumulator(acc_args)
     }
 
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
-        let sum_dt = args.return_field.data_type().clone();
-        Ok(vec![
-            Field::new(format_state_name(args.name, "sum"), sum_dt, 
true).into(),
-            Field::new(
-                format_state_name(args.name, "failed"),
-                DataType::Boolean,
-                false,
-            )
-            .into(),
-        ])
-    }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        use DataType::*;
-        if arg_types.len() != 1 {
-            return exec_err!(
-                "try_sum: exactly 1 argument expected, got {}",
-                arg_types.len()
-            );
-        }
-
-        let dt = &arg_types[0];
-        let coerced = match dt {
-            Null => Float64,
-            Decimal128(p, s) => Decimal128(*p, *s),
-            Int8 | Int16 | Int32 | Int64 => Int64,
-            Float16 | Float32 | Float64 => Float64,
-            other => return exec_err!("try_sum: unsupported type: {other:?}"),
-        };
-        Ok(vec![coerced])
-    }
-
-    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
-        Ok(ScalarValue::Null)
-    }
-}
-
-#[cfg(test)]
-mod tests {

Review Comment:
   These tests already seemed to be covered by SLTs, so I removed them.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to