This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new fd0ab6441 feat: Support ANSI mode SUM (Decimal types) (#2826)
fd0ab6441 is described below

commit fd0ab6441b078a8875b614ca2d615a102c592606
Author: B Vadlamani <[email protected]>
AuthorDate: Fri Dec 5 12:36:18 2025 -0800

    feat: Support ANSI mode SUM (Decimal types) (#2826)
---
 native/core/src/execution/planner.rs               |   4 +-
 native/proto/src/proto/expr.proto                  |   2 +-
 native/spark-expr/benches/aggregate.rs             |   4 +-
 native/spark-expr/src/agg_funcs/sum_decimal.rs     | 396 +++++++++++++--------
 .../scala/org/apache/comet/serde/aggregates.scala  |  13 +-
 .../apache/comet/exec/CometAggregateSuite.scala    | 165 ++++++++-
 6 files changed, 415 insertions(+), 169 deletions(-)

diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index d9c575fcb..d09393fc9 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -2022,7 +2022,9 @@ impl PhysicalPlanner {
 
                 let builder = match datatype {
                     DataType::Decimal128(_, _) => {
-                        let func = 
AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?);
+                        let eval_mode = 
from_protobuf_eval_mode(expr.eval_mode)?;
+                        let func =
+                            
AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?);
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                     _ => {
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index c9037dcd6..a7736f561 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -120,7 +120,7 @@ message Count {
 message Sum {
   Expr child = 1;
   DataType datatype = 2;
-  bool fail_on_error = 3;
+  EvalMode eval_mode = 3;
 }
 
 message Min {
diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs
index 3aa023371..72628975b 100644
--- a/native/spark-expr/benches/aggregate.rs
+++ b/native/spark-expr/benches/aggregate.rs
@@ -31,8 +31,8 @@ use datafusion::physical_expr::expressions::Column;
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, 
PhysicalGroupBy};
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion_comet_spark_expr::AvgDecimal;
 use datafusion_comet_spark_expr::SumDecimal;
+use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
 use futures::StreamExt;
 use std::hint::black_box;
 use std::sync::Arc;
@@ -97,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) {
 
     group.bench_function("sum_decimal_comet", |b| {
         let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
-            SumDecimal::try_new(DataType::Decimal128(38, 10)).unwrap(),
+            SumDecimal::try_new(DataType::Decimal128(38, 10), 
EvalMode::Legacy).unwrap(),
         ));
         b.to_async(&rt).iter(|| {
             black_box(agg_test(
diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs
index cc2585590..50645391f 100644
--- a/native/spark-expr/src/agg_funcs/sum_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -15,19 +15,19 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::utils::{build_bool_state, is_valid_decimal_precision};
+use crate::utils::is_valid_decimal_precision;
+use crate::{arithmetic_overflow_error, EvalMode};
 use arrow::array::{
     cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, 
Decimal128Array,
 };
 use arrow::datatypes::{DataType, Field, FieldRef};
-use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer};
 use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
 use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion::logical_expr::Volatility::Immutable;
 use datafusion::logical_expr::{
     Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, 
Signature,
 };
-use std::{any::Any, ops::BitAnd, sync::Arc};
+use std::{any::Any, sync::Arc};
 
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct SumDecimal {
@@ -40,11 +40,11 @@ pub struct SumDecimal {
     precision: u8,
     /// Decimal scale
     scale: i8,
+    eval_mode: EvalMode,
 }
 
 impl SumDecimal {
-    pub fn try_new(data_type: DataType) -> DFResult<Self> {
-        // The `data_type` is the SUM result type passed from Spark side
+    pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> 
{
         let (precision, scale) = match data_type {
             DataType::Decimal128(p, s) => (p, s),
             _ => {
@@ -58,6 +58,7 @@ impl SumDecimal {
             result_type: data_type,
             precision,
             scale,
+            eval_mode,
         })
     }
 }
@@ -71,19 +72,18 @@ impl AggregateUDFImpl for SumDecimal {
         Ok(Box::new(SumDecimalAccumulator::new(
             self.precision,
             self.scale,
+            self.eval_mode,
         )))
     }
 
     fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
-        let fields = vec![
-            Arc::new(Field::new(
-                self.name(),
-                self.result_type.clone(),
-                self.is_nullable(),
-            )),
+        // For decimal sum, we always track is_empty regardless of eval_mode
+        // This matches Spark's behavior where DecimalType always uses 
shouldTrackIsEmpty = true
+        let data_type = self.result_type.clone();
+        Ok(vec![
+            Arc::new(Field::new("sum", data_type, true)),
             Arc::new(Field::new("is_empty", DataType::Boolean, false)),
-        ];
-        Ok(fields)
+        ])
     }
 
     fn name(&self) -> &str {
@@ -109,6 +109,7 @@ impl AggregateUDFImpl for SumDecimal {
         Ok(Box::new(SumDecimalGroupsAccumulator::new(
             self.result_type.clone(),
             self.precision,
+            self.eval_mode,
         )))
     }
 
@@ -131,37 +132,48 @@ impl AggregateUDFImpl for SumDecimal {
 
 #[derive(Debug)]
 struct SumDecimalAccumulator {
-    sum: i128,
+    sum: Option<i128>,
     is_empty: bool,
-    is_not_null: bool,
-
     precision: u8,
     scale: i8,
+    eval_mode: EvalMode,
 }
 
 impl SumDecimalAccumulator {
-    fn new(precision: u8, scale: i8) -> Self {
+    fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self {
+        // For decimal sum, always track is_empty regardless of eval_mode
+        // This matches Spark's behavior where DecimalType always uses 
shouldTrackIsEmpty = true
         Self {
-            sum: 0,
+            sum: Some(0),
             is_empty: true,
-            is_not_null: true,
             precision,
             scale,
+            eval_mode,
         }
     }
 
-    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
+    fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> 
DFResult<()> {
+        // If already overflowed (sum is None but not empty), stay in overflow 
state
+        if !self.is_empty && self.sum.is_none() {
+            return Ok(());
+        }
+
         let v = unsafe { values.value_unchecked(idx) };
-        let (new_sum, is_overflow) = self.sum.overflowing_add(v);
+        let running_sum = self.sum.unwrap_or(0);
+        let (new_sum, is_overflow) = running_sum.overflowing_add(v);
 
         if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) 
{
-            // Overflow: set buffer accumulator to null
-            self.is_not_null = false;
-            return;
+            if self.eval_mode == EvalMode::Ansi {
+                return 
Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+            }
+            self.sum = None;
+            self.is_empty = false;
+            return Ok(());
         }
 
-        self.sum = new_sum;
-        self.is_not_null = true;
+        self.sum = Some(new_sum);
+        self.is_empty = false;
+        Ok(())
     }
 }
 
@@ -174,49 +186,46 @@ impl Accumulator for SumDecimalAccumulator {
             values.len()
         );
 
-        if !self.is_empty && !self.is_not_null {
-            // This means there's a overflow in decimal, so we will just skip 
the rest
-            // of the computation
+        // For decimal sum, always check for overflow regardless of eval_mode 
(per Spark's expectation)
+        if !self.is_empty && self.sum.is_none() {
             return Ok(());
         }
 
         let values = &values[0];
         let data = values.as_primitive::<Decimal128Type>();
 
+        // Update is_empty: it remains true only if it was true AND all values 
are null
         self.is_empty = self.is_empty && values.len() == values.null_count();
 
-        if values.null_count() == 0 {
-            for i in 0..data.len() {
-                self.update_single(data, i);
-            }
-        } else {
-            for i in 0..data.len() {
-                if data.is_null(i) {
-                    continue;
-                }
-                self.update_single(data, i);
-            }
+        if self.is_empty {
+            return Ok(());
         }
 
+        for i in 0..data.len() {
+            if data.is_null(i) {
+                continue;
+            }
+            self.update_single(data, i)?;
+        }
         Ok(())
     }
 
     fn evaluate(&mut self) -> DFResult<ScalarValue> {
-        // For each group:
-        //   1. if `is_empty` is true, it means either there is no value or 
all values for the group
-        //      are null, in this case we'll return null
-        //   2. if `is_empty` is false, but `null_state` is true, it means 
there's an overflow. In
-        //      non-ANSI mode Spark returns null.
-        if self.is_empty
-            || !self.is_not_null
-            || !is_valid_decimal_precision(self.sum, self.precision)
-        {
+        if self.is_empty {
             ScalarValue::new_primitive::<Decimal128Type>(
                 None,
                 &DataType::Decimal128(self.precision, self.scale),
             )
         } else {
-            ScalarValue::try_new_decimal128(self.sum, self.precision, 
self.scale)
+            match self.sum {
+                Some(sum_value) if is_valid_decimal_precision(sum_value, 
self.precision) => {
+                    ScalarValue::try_new_decimal128(sum_value, self.precision, 
self.scale)
+                }
+                _ => ScalarValue::new_primitive::<Decimal128Type>(
+                    None,
+                    &DataType::Decimal128(self.precision, self.scale),
+                ),
+            }
         }
     }
 
@@ -225,38 +234,71 @@ impl Accumulator for SumDecimalAccumulator {
     }
 
     fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
-        let sum = if self.is_not_null {
-            ScalarValue::try_new_decimal128(self.sum, self.precision, 
self.scale)?
-        } else {
-            ScalarValue::new_primitive::<Decimal128Type>(
+        let sum = match self.sum {
+            Some(sum_value) => {
+                ScalarValue::try_new_decimal128(sum_value, self.precision, 
self.scale)?
+            }
+            None => ScalarValue::new_primitive::<Decimal128Type>(
                 None,
                 &DataType::Decimal128(self.precision, self.scale),
-            )?
+            )?,
         };
+
+        // For decimal sum, always return 2 state values regardless of 
eval_mode
         Ok(vec![sum, ScalarValue::from(self.is_empty)])
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+        // For decimal sum, always expect 2 state arrays regardless of 
eval_mode
         assert_eq!(
             states.len(),
             2,
-            "Expect two element in 'states' but found {}",
+            "Expect two elements in 'states' but found {}",
             states.len()
         );
         assert_eq!(states[0].len(), 1);
         assert_eq!(states[1].len(), 1);
 
-        let that_sum = states[0].as_primitive::<Decimal128Type>();
-        let that_is_empty = 
states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
+        let that_sum_array = states[0].as_primitive::<Decimal128Type>();
+        let that_sum = if that_sum_array.is_null(0) {
+            None
+        } else {
+            Some(that_sum_array.value(0))
+        };
+
+        let that_is_empty = states[1].as_boolean().value(0);
+        let that_overflowed = !that_is_empty && that_sum.is_none();
+        let this_overflowed = !self.is_empty && self.sum.is_none();
 
-        let this_overflow = !self.is_empty && !self.is_not_null;
-        let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
+        if that_overflowed || this_overflowed {
+            self.sum = None;
+            self.is_empty = false;
+            return Ok(());
+        }
+
+        if that_is_empty {
+            return Ok(());
+        }
+
+        if self.is_empty {
+            self.sum = that_sum;
+            self.is_empty = false;
+            return Ok(());
+        }
 
-        self.is_not_null = !this_overflow && !that_overflow;
-        self.is_empty = self.is_empty && that_is_empty.value(0);
+        let left = self.sum.unwrap();
+        let right = that_sum.unwrap();
+        let (new_sum, is_overflow) = left.overflowing_add(right);
 
-        if self.is_not_null {
-            self.sum += that_sum.value(0);
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) 
{
+            if self.eval_mode == EvalMode::Ansi {
+                return 
Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+            } else {
+                self.sum = None;
+                self.is_empty = false;
+            }
+        } else {
+            self.sum = Some(new_sum);
         }
 
         Ok(())
@@ -264,46 +306,50 @@ impl Accumulator for SumDecimalAccumulator {
 }
 
 struct SumDecimalGroupsAccumulator {
-    // Whether aggregate buffer for a particular group is null. True indicates 
it is not null.
-    is_not_null: BooleanBufferBuilder,
-    is_empty: BooleanBufferBuilder,
-    sum: Vec<i128>,
+    sum: Vec<Option<i128>>,
+    is_empty: Vec<bool>,
     result_type: DataType,
     precision: u8,
+    eval_mode: EvalMode,
 }
 
 impl SumDecimalGroupsAccumulator {
-    fn new(result_type: DataType, precision: u8) -> Self {
+    fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self {
         Self {
-            is_not_null: BooleanBufferBuilder::new(0),
-            is_empty: BooleanBufferBuilder::new(0),
             sum: Vec::new(),
+            is_empty: Vec::new(),
             result_type,
             precision,
+            eval_mode,
         }
     }
 
-    fn is_overflow(&self, index: usize) -> bool {
-        !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
+    fn resize_helper(&mut self, total_num_groups: usize) {
+        // For decimal sum, always initialize properly regardless of eval_mode
+        self.sum.resize(total_num_groups, Some(0));
+        self.is_empty.resize(total_num_groups, true);
     }
 
     #[inline]
-    fn update_single(&mut self, group_index: usize, value: i128) {
-        self.is_empty.set_bit(group_index, false);
-        let (new_sum, is_overflow) = 
self.sum[group_index].overflowing_add(value);
-        self.sum[group_index] = new_sum;
+    fn update_single(&mut self, group_index: usize, value: i128) -> 
DFResult<()> {
+        // For decimal sum, always check for overflow regardless of eval_mode
+        if !self.is_empty[group_index] && self.sum[group_index].is_none() {
+            return Ok(());
+        }
+
+        let running_sum = self.sum[group_index].unwrap_or(0);
+        let (new_sum, is_overflow) = running_sum.overflowing_add(value);
 
         if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) 
{
-            // Overflow: set buffer accumulator to null
-            self.is_not_null.set_bit(group_index, false);
+            if self.eval_mode == EvalMode::Ansi {
+                return 
Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+            }
+            self.sum[group_index] = None;
+        } else {
+            self.sum[group_index] = Some(new_sum);
         }
-    }
-}
-
-fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
-    if builder.len() < capacity {
-        let additional = capacity - builder.len();
-        builder.append_n(additional, true);
+        self.is_empty[group_index] = false;
+        Ok(())
     }
 }
 
@@ -320,22 +366,19 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
         let values = values[0].as_primitive::<Decimal128Type>();
         let data = values.values();
 
-        // Update size for the accumulate states
-        self.sum.resize(total_num_groups, 0);
-        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
-        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
+        self.resize_helper(total_num_groups);
 
         let iter = group_indices.iter().zip(data.iter());
         if values.null_count() == 0 {
             for (&group_index, &value) in iter {
-                self.update_single(group_index, value);
+                self.update_single(group_index, value)?;
             }
         } else {
             for (idx, (&group_index, &value)) in iter.enumerate() {
                 if values.is_null(idx) {
                     continue;
                 }
-                self.update_single(group_index, value);
+                self.update_single(group_index, value)?;
             }
         }
 
@@ -343,42 +386,65 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
     }
 
     fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
-        // For each group:
-        //   1. if `is_empty` is true, it means either there is no value or 
all values for the group
-        //      are null, in this case we'll return null
-        //   2. if `is_empty` is false, but `null_state` is true, it means 
there's an overflow. In
-        //      non-ANSI mode Spark returns null.
-        let result = emit_to.take_needed(&mut self.sum);
-        result.iter().enumerate().for_each(|(i, &v)| {
-            if !is_valid_decimal_precision(v, self.precision) {
-                self.is_not_null.set_bit(i, false);
+        match emit_to {
+            EmitTo::All => {
+                let result =
+                    
Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map(
+                        |(&sum, &empty)| {
+                            if empty {
+                                None
+                            } else {
+                                match sum {
+                                    Some(v) if is_valid_decimal_precision(v, 
self.precision) => {
+                                        Some(v)
+                                    }
+                                    _ => None,
+                                }
+                            }
+                        },
+                    ))
+                    .with_data_type(self.result_type.clone());
+
+                self.sum.clear();
+                self.is_empty.clear();
+                Ok(Arc::new(result))
             }
-        });
-
-        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
-        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
-        let x = (!&is_empty).bitand(&nulls);
-
-        let result = Decimal128Array::new(result.into(), 
Some(NullBuffer::new(x)))
-            .with_data_type(self.result_type.clone());
-
-        Ok(Arc::new(result))
+            EmitTo::First(n) => {
+                let result = Decimal128Array::from_iter(
+                    self.sum
+                        .drain(..n)
+                        .zip(self.is_empty.drain(..n))
+                        .map(|(sum, empty)| {
+                            if empty {
+                                None
+                            } else {
+                                match sum {
+                                    Some(v) if is_valid_decimal_precision(v, 
self.precision) => {
+                                        Some(v)
+                                    }
+                                    _ => None,
+                                }
+                            }
+                        }),
+                )
+                .with_data_type(self.result_type.clone());
+
+                Ok(Arc::new(result))
+            }
+        }
     }
 
     fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
-        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
-        let nulls = Some(NullBuffer::new(nulls));
+        let sums = emit_to.take_needed(&mut self.sum);
 
-        let sum = emit_to.take_needed(&mut self.sum);
-        let sum = Decimal128Array::new(sum.into(), nulls.clone())
+        let sum_array = Decimal128Array::from_iter(sums.iter().copied())
             .with_data_type(self.result_type.clone());
 
-        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
-        let is_empty = BooleanArray::new(is_empty, None);
-
+        // For decimal sum, always return 2 state arrays regardless of 
eval_mode
+        let is_empty = emit_to.take_needed(&mut self.is_empty);
         Ok(vec![
-            Arc::new(sum) as ArrayRef,
-            Arc::new(is_empty) as ArrayRef,
+            Arc::new(sum_array),
+            Arc::new(BooleanArray::from(is_empty)),
         ])
     }
 
@@ -389,57 +455,70 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
         opt_filter: Option<&BooleanArray>,
         total_num_groups: usize,
     ) -> DFResult<()> {
+        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+        self.resize_helper(total_num_groups);
+
+        // For decimal sum, always expect 2 arrays regardless of eval_mode
         assert_eq!(
             values.len(),
             2,
             "Expected two arrays: 'sum' and 'is_empty', but found {}",
             values.len()
         );
-        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
 
-        // Make sure we have enough capacity for the additional groups
-        self.sum.resize(total_num_groups, 0);
-        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
-        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
-
-        let that_sum = &values[0];
-        let that_sum = that_sum.as_primitive::<Decimal128Type>();
-        let that_is_empty = &values[1];
-        let that_is_empty = that_is_empty
-            .as_any()
-            .downcast_ref::<BooleanArray>()
-            .unwrap();
+        let that_sum = values[0].as_primitive::<Decimal128Type>();
+        let that_is_empty = values[1].as_boolean();
+
+        for (idx, &group_index) in group_indices.iter().enumerate() {
+            let that_sum_val = if that_sum.is_null(idx) {
+                None
+            } else {
+                Some(that_sum.value(idx))
+            };
 
-        group_indices
-            .iter()
-            .enumerate()
-            .for_each(|(idx, &group_index)| unsafe {
-                let this_overflow = self.is_overflow(group_index);
-                let that_is_empty = that_is_empty.value_unchecked(idx);
-                let that_overflow = !that_is_empty && that_sum.is_null(idx);
-                let is_overflow = this_overflow || that_overflow;
-
-                // This part follows the logic in Spark:
-                //   `org.apache.spark.sql.catalyst.expressions.aggregate.Sum`
-                self.is_not_null.set_bit(group_index, !is_overflow);
-                self.is_empty.set_bit(
-                    group_index,
-                    self.is_empty.get_bit(group_index) && that_is_empty,
-                );
-                if !is_overflow {
-                    // .. otherwise, the sum value for this particular index 
must not be null,
-                    // and thus we merge both values and update this sum.
-                    self.sum[group_index] += that_sum.value_unchecked(idx);
+            let that_is_empty_val = that_is_empty.value(idx);
+            let that_overflowed = !that_is_empty_val && that_sum_val.is_none();
+            let this_overflowed = !self.is_empty[group_index] && 
self.sum[group_index].is_none();
+
+            if that_overflowed || this_overflowed {
+                self.sum[group_index] = None;
+                self.is_empty[group_index] = false;
+                continue;
+            }
+
+            if that_is_empty_val {
+                continue;
+            }
+
+            if self.is_empty[group_index] {
+                self.sum[group_index] = that_sum_val;
+                self.is_empty[group_index] = false;
+                continue;
+            }
+
+            let left = self.sum[group_index].unwrap();
+            let right = that_sum_val.unwrap();
+            let (new_sum, is_overflow) = left.overflowing_add(right);
+
+            if is_overflow || !is_valid_decimal_precision(new_sum, 
self.precision) {
+                if self.eval_mode == EvalMode::Ansi {
+                    return 
Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+                } else {
+                    self.sum[group_index] = None;
+                    self.is_empty[group_index] = false;
                 }
-            });
+            } else {
+                self.sum[group_index] = Some(new_sum);
+            }
+        }
 
         Ok(())
     }
 
     fn size(&self) -> usize {
-        self.sum.capacity() * std::mem::size_of::<i128>()
-            + self.is_empty.capacity() / 8
-            + self.is_not_null.capacity() / 8
+        self.sum.capacity() * std::mem::size_of::<Option<i128>>()
+            + self.is_empty.capacity() * std::mem::size_of::<bool>()
     }
 }
 
@@ -463,7 +542,7 @@ mod tests {
 
     #[test]
     fn invalid_data_type() {
-        assert!(SumDecimal::try_new(DataType::Int32).is_err());
+        assert!(SumDecimal::try_new(DataType::Int32, 
EvalMode::Legacy).is_err());
     }
 
     #[tokio::test]
@@ -486,6 +565,7 @@ mod tests {
 
         let aggregate_udf = 
Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
             data_type.clone(),
+            EvalMode::Legacy,
         )?));
 
         let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index d00bbf4df..8ab568dc8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
 import org.apache.comet.CometConf
 import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, 
serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil
 
 object CometMin extends CometAggregateExpressionSerde[Min] {
 
@@ -214,10 +215,10 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
 
   override def getSupportLevel(sum: Sum): SupportLevel = {
     sum.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
+      case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] =>
+        Incompatible(Some("ANSI mode for non decimal inputs is not supported"))
+      case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] =>
+        Incompatible(Some("TRY mode for non decimal inputs is not supported"))
       case _ =>
         Compatible()
     }
@@ -242,7 +243,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
       val builder = ExprOuterClass.Sum.newBuilder()
       builder.setChild(childExpr.get)
       builder.setDatatype(dataType.get)
-      builder.setFailOnError(sum.evalMode == EvalMode.ANSI)
+      
builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode)))
 
       Some(
         ExprOuterClass.AggExpr
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 7e577c5fd..060579b2b 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -24,10 +24,11 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
 import org.apache.spark.sql.comet.CometHashAggregateExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.{avg, count_distinct, sum}
+import org.apache.spark.sql.functions.{avg, col, count_distinct, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
 
@@ -1471,6 +1472,168 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  // Every _1 value is NULL; sum(_1) is expected to yield a single NULL row with ANSI mode
+  // both enabled and disabled.
+  test("ANSI support for decimal sum - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          // Typed nulls so the column is written as a decimal column rather than an untyped null.
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "null_tbl") {
+          val res = sql("SELECT sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  // try_sum over an all-NULL decimal column is expected to yield a single NULL row with
+  // ANSI mode both enabled and disabled.
+  test("ANSI support for try_sum decimal - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          // Typed nulls so the column is written as a decimal column rather than an untyped null.
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "null_tbl") {
+          val res = sql("SELECT try_sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  // Every _1 value is NULL, so each group ("a" and "b") is expected to sum to NULL with
+  // ANSI mode both enabled and disabled.
+  test("ANSI support for decimal sum - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "tbl") {
+          // "group by 1" groups by the first select item, _2.
+          val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          // Group output order is not guaranteed; sort before comparing.
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  // try_sum over all-NULL inputs is expected to give NULL for each group ("a" and "b")
+  // with ANSI mode both enabled and disabled.
+  test("ANSI support for try_sum decimal - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "tbl") {
+          // "group by 1" groups by the first select item, _2.
+          val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          // Group output order is not guaranteed; sort before comparing.
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  /**
+   * Builds 50 identical rows of (99999999999999999999, 1), all in group key 1.
+   *
+   * The value has 20 integer digits. Assuming the column is written with Spark's default
+   * DECIMAL(38, 18) for java.math.BigDecimal (20 integer digits), summing 50 of them exceeds
+   * the result's integer range and overflows.
+   */
+  protected def generateOverflowDecimalInputs: Seq[(java.math.BigDecimal, Int)] = {
+    // Largest integral value representable in DECIMAL(38, 18): 20 nines.
+    val maxDec38_18 = new java.math.BigDecimal("99999999999999999999")
+    Seq.fill(50)((maxDec38_18, 1))
+  }
+
+  // SUM over generateOverflowDecimalInputs overflows the decimal result. With ANSI on, both
+  // Spark and Comet must fail with ARITHMETIC_OVERFLOW; with ANSI off, results must match Spark.
+  test("ANSI support for decimal SUM function") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(generateOverflowDecimalInputs, "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              // Report which engine(s) failed to throw, rather than discarding the outcome.
+              case (sparkExc, cometExc) =>
+                fail(
+                  "Exception should be thrown for decimal overflow in ANSI mode " +
+                    s"(spark: $sparkExc, comet: $cometExc)")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+      }
+    }
+  }
+
+  // Grouped variant of the overflow test, repartitioned after aggregation. ANSI on: both
+  // engines must fail with ARITHMETIC_OVERFLOW; ANSI off: results must match Spark.
+  test("ANSI support for decimal SUM - GROUP BY") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(generateOverflowDecimalInputs, "tbl") {
+          val res =
+            sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              // Report which engine(s) failed to throw, rather than discarding the outcome.
+              case (sparkExc, cometExc) =>
+                fail(
+                  "Exception should be thrown for decimal overflow with GROUP BY in ANSI mode " +
+                    s"(spark: $sparkExc, comet: $cometExc)")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+      }
+    }
+  }
+
+  // try_sum over overflowing decimal inputs; the result is compared against Spark.
+  test("try_sum decimal overflow") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withParquetTable(generateOverflowDecimalInputs, "tbl") {
+        val res = sql("SELECT try_sum(_1) FROM tbl")
+        checkSparkAnswerAndOperator(res)
+      }
+    }
+  }
+
+  // Grouped try_sum over overflowing decimal inputs, hash-repartitioned on the group key;
+  // the result is compared against Spark.
+  test("try_sum decimal overflow - with GROUP BY") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withParquetTable(generateOverflowDecimalInputs, "tbl") {
+        val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+        checkSparkAnswerAndOperator(res)
+      }
+    }
+  }
+
+  // One group overflows and one does not; try_sum should null out only the overflowing group.
+  test("try_sum decimal partial overflow - with GROUP BY") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      // Group 1 overflows, Group 2 succeeds
+      val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
+        (new java.math.BigDecimal(300), 2),
+        (new java.math.BigDecimal(200), 2))
+      withParquetTable(data, "tbl") {
+        val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
+        checkSparkAnswerAndOperator(res)
+        // Group 1 should be NULL, Group 2 should be 500. Assert it explicitly so the test
+        // cannot pass when Spark and Comet agree on a wrong answer. compareTo ignores scale.
+        val rows = res.orderBy(col("_2")).collect()
+        assert(rows.length == 2)
+        assert(rows(0).getInt(0) == 1 && rows(0).isNullAt(1))
+        assert(rows(1).getInt(0) == 2)
+        assert(rows(1).getDecimal(1).compareTo(new java.math.BigDecimal(500)) == 0)
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, 
numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to