This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new fd0ab6441 feat: Support ANSI mode SUM (Decimal types) (#2826)
fd0ab6441 is described below
commit fd0ab6441b078a8875b614ca2d615a102c592606
Author: B Vadlamani <[email protected]>
AuthorDate: Fri Dec 5 12:36:18 2025 -0800
feat: Support ANSI mode SUM (Decimal types) (#2826)
---
native/core/src/execution/planner.rs | 4 +-
native/proto/src/proto/expr.proto | 2 +-
native/spark-expr/benches/aggregate.rs | 4 +-
native/spark-expr/src/agg_funcs/sum_decimal.rs | 396 +++++++++++++--------
.../scala/org/apache/comet/serde/aggregates.scala | 13 +-
.../apache/comet/exec/CometAggregateSuite.scala | 165 ++++++++-
6 files changed, 415 insertions(+), 169 deletions(-)
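At the core of this change, SumDecimal::try_new now takes an explicit EvalMode alongside the result type. A minimal sketch of the new call shape, mirroring the benchmark update further down (the main() wrapper and crate setup here are illustrative assumptions, not part of the commit):

// Sketch: constructing the sum UDAF with the new two-argument signature.
// Assumes the datafusion and datafusion-comet-spark-expr crates as deps.
use arrow::datatypes::DataType;
use datafusion::logical_expr::AggregateUDF;
use datafusion_comet_spark_expr::{EvalMode, SumDecimal};

fn main() -> datafusion::common::Result<()> {
    // EvalMode::Legacy: overflow yields NULL; EvalMode::Ansi makes the
    // accumulators return an ARITHMETIC_OVERFLOW error instead.
    let udaf = AggregateUDF::new_from_impl(SumDecimal::try_new(
        DataType::Decimal128(38, 10),
        EvalMode::Legacy,
    )?);
    println!("registered {}", udaf.name());
    Ok(())
}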
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index d9c575fcb..d09393fc9 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -2022,7 +2022,9 @@ impl PhysicalPlanner {
let builder = match datatype {
DataType::Decimal128(_, _) => {
- let func = AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?);
+ let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+ let func =
+ AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?);
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
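Here from_protobuf_eval_mode converts the wire-format eval mode into the native EvalMode before the UDAF is built. A standalone sketch of the kind of mapping this implies, using the Legacy/Try/Ansi variants seen elsewhere in this diff (the integer tags are assumptions for illustration, not the actual proto constants):

// Toy mapping, not Comet source: decode an i32 from the wire into a typed
// eval mode, rejecting values the enum does not define.
#[derive(Debug, Clone, Copy, PartialEq)]
enum EvalMode { Legacy, Try, Ansi }

fn eval_mode_from_i32(v: i32) -> Result<EvalMode, String> {
    match v {
        0 => Ok(EvalMode::Legacy), // tag values are illustrative assumptions
        1 => Ok(EvalMode::Try),
        2 => Ok(EvalMode::Ansi),
        other => Err(format!("unknown EvalMode: {other}")),
    }
}

fn main() {
    assert_eq!(eval_mode_from_i32(2), Ok(EvalMode::Ansi));
}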
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index c9037dcd6..a7736f561 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -120,7 +120,7 @@ message Count {
message Sum {
Expr child = 1;
DataType datatype = 2;
- bool fail_on_error = 3;
+ EvalMode eval_mode = 3;
}
message Min {
diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs
index 3aa023371..72628975b 100644
--- a/native/spark-expr/benches/aggregate.rs
+++ b/native/spark-expr/benches/aggregate.rs
@@ -31,8 +31,8 @@ use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::ExecutionPlan;
-use datafusion_comet_spark_expr::AvgDecimal;
use datafusion_comet_spark_expr::SumDecimal;
+use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
use futures::StreamExt;
use std::hint::black_box;
use std::sync::Arc;
@@ -97,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("sum_decimal_comet", |b| {
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
- SumDecimal::try_new(DataType::Decimal128(38, 10)).unwrap(),
+ SumDecimal::try_new(DataType::Decimal128(38, 10), EvalMode::Legacy).unwrap(),
));
b.to_async(&rt).iter(|| {
black_box(agg_test(
diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs
index cc2585590..50645391f 100644
--- a/native/spark-expr/src/agg_funcs/sum_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -15,19 +15,19 @@
// specific language governing permissions and limitations
// under the License.
-use crate::utils::{build_bool_state, is_valid_decimal_precision};
+use crate::utils::is_valid_decimal_precision;
+use crate::{arithmetic_overflow_error, EvalMode};
use arrow::array::{
cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
};
use arrow::datatypes::{DataType, Field, FieldRef};
-use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer};
use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use datafusion::logical_expr::{
Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
};
-use std::{any::Any, ops::BitAnd, sync::Arc};
+use std::{any::Any, sync::Arc};
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SumDecimal {
@@ -40,11 +40,11 @@ pub struct SumDecimal {
precision: u8,
/// Decimal scale
scale: i8,
+ eval_mode: EvalMode,
}
impl SumDecimal {
- pub fn try_new(data_type: DataType) -> DFResult<Self> {
- // The `data_type` is the SUM result type passed from Spark side
+ pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
let (precision, scale) = match data_type {
DataType::Decimal128(p, s) => (p, s),
_ => {
@@ -58,6 +58,7 @@ impl SumDecimal {
result_type: data_type,
precision,
scale,
+ eval_mode,
})
}
}
@@ -71,19 +72,18 @@ impl AggregateUDFImpl for SumDecimal {
Ok(Box::new(SumDecimalAccumulator::new(
self.precision,
self.scale,
+ self.eval_mode,
)))
}
fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
- let fields = vec![
- Arc::new(Field::new(
- self.name(),
- self.result_type.clone(),
- self.is_nullable(),
- )),
+ // For decimal sum, we always track is_empty regardless of eval_mode
+ // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true
+ let data_type = self.result_type.clone();
+ Ok(vec![
+ Arc::new(Field::new("sum", data_type, true)),
Arc::new(Field::new("is_empty", DataType::Boolean, false)),
- ];
- Ok(fields)
+ ])
}
fn name(&self) -> &str {
@@ -109,6 +109,7 @@ impl AggregateUDFImpl for SumDecimal {
Ok(Box::new(SumDecimalGroupsAccumulator::new(
self.result_type.clone(),
self.precision,
+ self.eval_mode,
)))
}
@@ -131,37 +132,48 @@ impl AggregateUDFImpl for SumDecimal {
#[derive(Debug)]
struct SumDecimalAccumulator {
- sum: i128,
+ sum: Option<i128>,
is_empty: bool,
- is_not_null: bool,
-
precision: u8,
scale: i8,
+ eval_mode: EvalMode,
}
impl SumDecimalAccumulator {
- fn new(precision: u8, scale: i8) -> Self {
+ fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self {
+ // For decimal sum, always track is_empty regardless of eval_mode
+ // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true
Self {
- sum: 0,
+ sum: Some(0),
is_empty: true,
- is_not_null: true,
precision,
scale,
+ eval_mode,
}
}
- fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
+ fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> DFResult<()> {
+ // If already overflowed (sum is None but not empty), stay in overflow state
+ if !self.is_empty && self.sum.is_none() {
+ return Ok(());
+ }
+
let v = unsafe { values.value_unchecked(idx) };
- let (new_sum, is_overflow) = self.sum.overflowing_add(v);
+ let running_sum = self.sum.unwrap_or(0);
+ let (new_sum, is_overflow) = running_sum.overflowing_add(v);
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
- // Overflow: set buffer accumulator to null
- self.is_not_null = false;
- return;
+ if self.eval_mode == EvalMode::Ansi {
+ return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ }
+ self.sum = None;
+ self.is_empty = false;
+ return Ok(());
}
- self.sum = new_sum;
- self.is_not_null = true;
+ self.sum = Some(new_sum);
+ self.is_empty = false;
+ Ok(())
}
}
@@ -174,49 +186,46 @@ impl Accumulator for SumDecimalAccumulator {
values.len()
);
- if !self.is_empty && !self.is_not_null {
- // This means there's a overflow in decimal, so we will just skip the rest
- // of the computation
+ // For decimal sum, always check for overflow regardless of eval_mode (per Spark's expectation)
+ if !self.is_empty && self.sum.is_none() {
return Ok(());
}
let values = &values[0];
let data = values.as_primitive::<Decimal128Type>();
+ // Update is_empty: it remains true only if it was true AND all values are null
self.is_empty = self.is_empty && values.len() == values.null_count();
- if values.null_count() == 0 {
- for i in 0..data.len() {
- self.update_single(data, i);
- }
- } else {
- for i in 0..data.len() {
- if data.is_null(i) {
- continue;
- }
- self.update_single(data, i);
- }
+ if self.is_empty {
+ return Ok(());
}
+ for i in 0..data.len() {
+ if data.is_null(i) {
+ continue;
+ }
+ self.update_single(data, i)?;
+ }
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
- // For each group:
- // 1. if `is_empty` is true, it means either there is no value or all values for the group
- // are null, in this case we'll return null
- // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
- // non-ANSI mode Spark returns null.
- if self.is_empty
- || !self.is_not_null
- || !is_valid_decimal_precision(self.sum, self.precision)
- {
+ if self.is_empty {
ScalarValue::new_primitive::<Decimal128Type>(
None,
&DataType::Decimal128(self.precision, self.scale),
)
} else {
- ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
+ match self.sum {
+ Some(sum_value) if is_valid_decimal_precision(sum_value, self.precision) => {
+ ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)
+ }
+ _ => ScalarValue::new_primitive::<Decimal128Type>(
+ None,
+ &DataType::Decimal128(self.precision, self.scale),
+ ),
+ }
}
}
@@ -225,38 +234,71 @@ impl Accumulator for SumDecimalAccumulator {
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
- let sum = if self.is_not_null {
- ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
- } else {
- ScalarValue::new_primitive::<Decimal128Type>(
+ let sum = match self.sum {
+ Some(sum_value) => {
+ ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)?
+ }
+ None => ScalarValue::new_primitive::<Decimal128Type>(
None,
&DataType::Decimal128(self.precision, self.scale),
- )?
+ )?,
};
+
+ // For decimal sum, always return 2 state values regardless of eval_mode
Ok(vec![sum, ScalarValue::from(self.is_empty)])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+ // For decimal sum, always expect 2 state arrays regardless of eval_mode
assert_eq!(
states.len(),
2,
- "Expect two element in 'states' but found {}",
+ "Expect two elements in 'states' but found {}",
states.len()
);
assert_eq!(states[0].len(), 1);
assert_eq!(states[1].len(), 1);
- let that_sum = states[0].as_primitive::<Decimal128Type>();
- let that_is_empty = states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
+ let that_sum_array = states[0].as_primitive::<Decimal128Type>();
+ let that_sum = if that_sum_array.is_null(0) {
+ None
+ } else {
+ Some(that_sum_array.value(0))
+ };
+
+ let that_is_empty = states[1].as_boolean().value(0);
+ let that_overflowed = !that_is_empty && that_sum.is_none();
+ let this_overflowed = !self.is_empty && self.sum.is_none();
- let this_overflow = !self.is_empty && !self.is_not_null;
- let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
+ if that_overflowed || this_overflowed {
+ self.sum = None;
+ self.is_empty = false;
+ return Ok(());
+ }
+
+ if that_is_empty {
+ return Ok(());
+ }
+
+ if self.is_empty {
+ self.sum = that_sum;
+ self.is_empty = false;
+ return Ok(());
+ }
- self.is_not_null = !this_overflow && !that_overflow;
- self.is_empty = self.is_empty && that_is_empty.value(0);
+ let left = self.sum.unwrap();
+ let right = that_sum.unwrap();
+ let (new_sum, is_overflow) = left.overflowing_add(right);
- if self.is_not_null {
- self.sum += that_sum.value(0);
+ if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
+ if self.eval_mode == EvalMode::Ansi {
+ return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ } else {
+ self.sum = None;
+ self.is_empty = false;
+ }
+ } else {
+ self.sum = Some(new_sum);
}
Ok(())
@@ -264,46 +306,50 @@ impl Accumulator for SumDecimalAccumulator {
}
struct SumDecimalGroupsAccumulator {
- // Whether aggregate buffer for a particular group is null. True indicates it is not null.
- is_not_null: BooleanBufferBuilder,
- is_empty: BooleanBufferBuilder,
- sum: Vec<i128>,
+ sum: Vec<Option<i128>>,
+ is_empty: Vec<bool>,
result_type: DataType,
precision: u8,
+ eval_mode: EvalMode,
}
impl SumDecimalGroupsAccumulator {
- fn new(result_type: DataType, precision: u8) -> Self {
+ fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self {
Self {
- is_not_null: BooleanBufferBuilder::new(0),
- is_empty: BooleanBufferBuilder::new(0),
sum: Vec::new(),
+ is_empty: Vec::new(),
result_type,
precision,
+ eval_mode,
}
}
- fn is_overflow(&self, index: usize) -> bool {
- !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
+ fn resize_helper(&mut self, total_num_groups: usize) {
+ // For decimal sum, always initialize properly regardless of eval_mode
+ self.sum.resize(total_num_groups, Some(0));
+ self.is_empty.resize(total_num_groups, true);
}
#[inline]
- fn update_single(&mut self, group_index: usize, value: i128) {
- self.is_empty.set_bit(group_index, false);
- let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);
- self.sum[group_index] = new_sum;
+ fn update_single(&mut self, group_index: usize, value: i128) -> DFResult<()> {
+ // For decimal sum, always check for overflow regardless of eval_mode
+ if !self.is_empty[group_index] && self.sum[group_index].is_none() {
+ return Ok(());
+ }
+
+ let running_sum = self.sum[group_index].unwrap_or(0);
+ let (new_sum, is_overflow) = running_sum.overflowing_add(value);
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
- // Overflow: set buffer accumulator to null
- self.is_not_null.set_bit(group_index, false);
+ if self.eval_mode == EvalMode::Ansi {
+ return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ }
+ self.sum[group_index] = None;
+ } else {
+ self.sum[group_index] = Some(new_sum);
}
- }
-}
-
-fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
- if builder.len() < capacity {
- let additional = capacity - builder.len();
- builder.append_n(additional, true);
+ self.is_empty[group_index] = false;
+ Ok(())
}
}
@@ -320,22 +366,19 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
let values = values[0].as_primitive::<Decimal128Type>();
let data = values.values();
- // Update size for the accumulate states
- self.sum.resize(total_num_groups, 0);
- ensure_bit_capacity(&mut self.is_empty, total_num_groups);
- ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
+ self.resize_helper(total_num_groups);
let iter = group_indices.iter().zip(data.iter());
if values.null_count() == 0 {
for (&group_index, &value) in iter {
- self.update_single(group_index, value);
+ self.update_single(group_index, value)?;
}
} else {
for (idx, (&group_index, &value)) in iter.enumerate() {
if values.is_null(idx) {
continue;
}
- self.update_single(group_index, value);
+ self.update_single(group_index, value)?;
}
}
@@ -343,42 +386,65 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
}
fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
- // For each group:
- // 1. if `is_empty` is true, it means either there is no value or all values for the group
- // are null, in this case we'll return null
- // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
- // non-ANSI mode Spark returns null.
- let result = emit_to.take_needed(&mut self.sum);
- result.iter().enumerate().for_each(|(i, &v)| {
- if !is_valid_decimal_precision(v, self.precision) {
- self.is_not_null.set_bit(i, false);
+ match emit_to {
+ EmitTo::All => {
+ let result =
+ Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map(
+ |(&sum, &empty)| {
+ if empty {
+ None
+ } else {
+ match sum {
+ Some(v) if is_valid_decimal_precision(v, self.precision) => {
+ Some(v)
+ }
+ _ => None,
+ }
+ }
+ },
+ ))
+ .with_data_type(self.result_type.clone());
+
+ self.sum.clear();
+ self.is_empty.clear();
+ Ok(Arc::new(result))
}
- });
-
- let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
- let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
- let x = (!&is_empty).bitand(&nulls);
-
- let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
- .with_data_type(self.result_type.clone());
-
- Ok(Arc::new(result))
+ EmitTo::First(n) => {
+ let result = Decimal128Array::from_iter(
+ self.sum
+ .drain(..n)
+ .zip(self.is_empty.drain(..n))
+ .map(|(sum, empty)| {
+ if empty {
+ None
+ } else {
+ match sum {
+ Some(v) if is_valid_decimal_precision(v, self.precision) => {
+ Some(v)
+ }
+ _ => None,
+ }
+ }
+ }),
+ )
+ .with_data_type(self.result_type.clone());
+
+ Ok(Arc::new(result))
+ }
+ }
}
fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
- let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
- let nulls = Some(NullBuffer::new(nulls));
+ let sums = emit_to.take_needed(&mut self.sum);
- let sum = emit_to.take_needed(&mut self.sum);
- let sum = Decimal128Array::new(sum.into(), nulls.clone())
+ let sum_array = Decimal128Array::from_iter(sums.iter().copied())
.with_data_type(self.result_type.clone());
- let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
- let is_empty = BooleanArray::new(is_empty, None);
-
+ // For decimal sum, always return 2 state arrays regardless of eval_mode
+ let is_empty = emit_to.take_needed(&mut self.is_empty);
Ok(vec![
- Arc::new(sum) as ArrayRef,
- Arc::new(is_empty) as ArrayRef,
+ Arc::new(sum_array),
+ Arc::new(BooleanArray::from(is_empty)),
])
}
@@ -389,57 +455,70 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> DFResult<()> {
+ assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+ self.resize_helper(total_num_groups);
+
+ // For decimal sum, always expect 2 arrays regardless of eval_mode
assert_eq!(
values.len(),
2,
"Expected two arrays: 'sum' and 'is_empty', but found {}",
values.len()
);
- assert!(opt_filter.is_none(), "opt_filter is not supported yet");
- // Make sure we have enough capacity for the additional groups
- self.sum.resize(total_num_groups, 0);
- ensure_bit_capacity(&mut self.is_empty, total_num_groups);
- ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
-
- let that_sum = &values[0];
- let that_sum = that_sum.as_primitive::<Decimal128Type>();
- let that_is_empty = &values[1];
- let that_is_empty = that_is_empty
- .as_any()
- .downcast_ref::<BooleanArray>()
- .unwrap();
+ let that_sum = values[0].as_primitive::<Decimal128Type>();
+ let that_is_empty = values[1].as_boolean();
+
+ for (idx, &group_index) in group_indices.iter().enumerate() {
+ let that_sum_val = if that_sum.is_null(idx) {
+ None
+ } else {
+ Some(that_sum.value(idx))
+ };
- group_indices
- .iter()
- .enumerate()
- .for_each(|(idx, &group_index)| unsafe {
- let this_overflow = self.is_overflow(group_index);
- let that_is_empty = that_is_empty.value_unchecked(idx);
- let that_overflow = !that_is_empty && that_sum.is_null(idx);
- let is_overflow = this_overflow || that_overflow;
-
- // This part follows the logic in Spark:
- // `org.apache.spark.sql.catalyst.expressions.aggregate.Sum`
- self.is_not_null.set_bit(group_index, !is_overflow);
- self.is_empty.set_bit(
- group_index,
- self.is_empty.get_bit(group_index) && that_is_empty,
- );
- if !is_overflow {
- // .. otherwise, the sum value for this particular index must not be null,
- // and thus we merge both values and update this sum.
- self.sum[group_index] += that_sum.value_unchecked(idx);
+ let that_is_empty_val = that_is_empty.value(idx);
+ let that_overflowed = !that_is_empty_val && that_sum_val.is_none();
+ let this_overflowed = !self.is_empty[group_index] && self.sum[group_index].is_none();
+
+ if that_overflowed || this_overflowed {
+ self.sum[group_index] = None;
+ self.is_empty[group_index] = false;
+ continue;
+ }
+
+ if that_is_empty_val {
+ continue;
+ }
+
+ if self.is_empty[group_index] {
+ self.sum[group_index] = that_sum_val;
+ self.is_empty[group_index] = false;
+ continue;
+ }
+
+ let left = self.sum[group_index].unwrap();
+ let right = that_sum_val.unwrap();
+ let (new_sum, is_overflow) = left.overflowing_add(right);
+
+ if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
+ if self.eval_mode == EvalMode::Ansi {
+ return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+ } else {
+ self.sum[group_index] = None;
+ self.is_empty[group_index] = false;
}
- });
+ } else {
+ self.sum[group_index] = Some(new_sum);
+ }
+ }
Ok(())
}
fn size(&self) -> usize {
- self.sum.capacity() * std::mem::size_of::<i128>()
- + self.is_empty.capacity() / 8
- + self.is_not_null.capacity() / 8
+ self.sum.capacity() * std::mem::size_of::<Option<i128>>()
+ + self.is_empty.capacity() * std::mem::size_of::<bool>()
}
}
@@ -463,7 +542,7 @@ mod tests {
#[test]
fn invalid_data_type() {
- assert!(SumDecimal::try_new(DataType::Int32).is_err());
+ assert!(SumDecimal::try_new(DataType::Int32, EvalMode::Legacy).is_err());
}
#[tokio::test]
@@ -486,6 +565,7 @@ mod tests {
let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
data_type.clone(),
+ EvalMode::Legacy,
)?));
let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
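The accumulator rewrite above replaces the old (sum: i128, is_not_null) pair with (sum: Option<i128>, is_empty), encoding three per-group states: empty (NULL so far), overflowed (permanently NULL in non-ANSI modes), and a live running sum. A self-contained toy model of that state machine and its ANSI vs. legacy overflow behavior (not the Comet accumulator itself; the precision check is reduced to a simple magnitude bound):

// Toy model, not Comet source. Per-group state:
//   is_empty == true                  -> no non-null input yet; SUM is NULL
//   is_empty == false, sum == None    -> overflow happened; SUM stays NULL
//   is_empty == false, sum == Some(v) -> running total v
#[derive(Debug, Clone, Copy, PartialEq)]
enum Mode { Legacy, Ansi }

struct SumState { sum: Option<i128>, is_empty: bool }

impl SumState {
    fn new() -> Self { Self { sum: Some(0), is_empty: true } }

    fn update(&mut self, v: i128, max_abs: i128, mode: Mode) -> Result<(), String> {
        // Once a group has overflowed, it stays overflowed.
        if !self.is_empty && self.sum.is_none() { return Ok(()); }
        let (new_sum, overflowed) = self.sum.unwrap_or(0).overflowing_add(v);
        if overflowed || new_sum.abs() > max_abs {
            if mode == Mode::Ansi { return Err("ARITHMETIC_OVERFLOW".into()); }
            self.sum = None; // legacy/try: poison the group to NULL
        } else {
            self.sum = Some(new_sum);
        }
        self.is_empty = false;
        Ok(())
    }
}

fn main() {
    let max = 99_999_999_999_999_999_999_i128; // stand-in precision bound
    let mut legacy = SumState::new();
    for _ in 0..3 { legacy.update(max, max, Mode::Legacy).unwrap(); }
    assert_eq!(legacy.sum, None); // overflow -> NULL, as try_sum expects
    let mut ansi = SumState::new();
    ansi.update(max, max, Mode::Ansi).unwrap();
    assert!(ansi.update(max, max, Mode::Ansi).is_err()); // ANSI raises
}

The merge path in the diff follows the same table: merging any overflowed partial state poisons the group, an empty partial state is a no-op, and only two live sums are actually added, with the same ANSI/legacy overflow handling.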
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index d00bbf4df..8ab568dc8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
import org.apache.comet.CometConf
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil
object CometMin extends CometAggregateExpressionSerde[Min] {
@@ -214,10 +215,10 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
override def getSupportLevel(sum: Sum): SupportLevel = {
sum.evalMode match {
- case EvalMode.ANSI =>
- Incompatible(Some("ANSI mode is not supported"))
- case EvalMode.TRY =>
- Incompatible(Some("TRY mode is not supported"))
+ case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] =>
+ Incompatible(Some("ANSI mode for non decimal inputs is not supported"))
+ case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] =>
+ Incompatible(Some("TRY mode for non decimal inputs is not supported"))
case _ =>
Compatible()
}
@@ -242,7 +243,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
val builder = ExprOuterClass.Sum.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
- builder.setFailOnError(sum.evalMode == EvalMode.ANSI)
+ builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode)))
Some(
ExprOuterClass.AggExpr
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 7e577c5fd..060579b2b 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -24,10 +24,11 @@ import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.{avg, count_distinct, sum}
+import org.apache.spark.sql.functions.{avg, col, count_distinct, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
@@ -1471,6 +1472,168 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("ANSI support for decimal sum - null test") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+ CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(
+ Seq(
+ (null.asInstanceOf[java.math.BigDecimal], "a"),
+ (null.asInstanceOf[java.math.BigDecimal], "b")),
+ "null_tbl") {
+ val res = sql("SELECT sum(_1) FROM null_tbl")
+ checkSparkAnswerAndOperator(res)
+ assert(res.collect() === Array(Row(null)))
+ }
+ }
+ }
+ }
+
+ test("ANSI support for try_sum decimal - null test") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+ CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(
+ Seq(
+ (null.asInstanceOf[java.math.BigDecimal], "a"),
+ (null.asInstanceOf[java.math.BigDecimal], "b")),
+ "null_tbl") {
+ val res = sql("SELECT try_sum(_1) FROM null_tbl")
+ checkSparkAnswerAndOperator(res)
+ assert(res.collect() === Array(Row(null)))
+ }
+ }
+ }
+ }
+
+ test("ANSI support for decimal sum - null test (group by)") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+ CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(
+ Seq(
+ (null.asInstanceOf[java.math.BigDecimal], "a"),
+ (null.asInstanceOf[java.math.BigDecimal], "a"),
+ (null.asInstanceOf[java.math.BigDecimal], "b"),
+ (null.asInstanceOf[java.math.BigDecimal], "b"),
+ (null.asInstanceOf[java.math.BigDecimal], "b")),
+ "tbl") {
+ val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
+ checkSparkAnswerAndOperator(res)
+ assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+ }
+ }
+ }
+ }
+
+ test("ANSI support for try_sum decimal - null test (group by)") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+ CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(
+ Seq(
+ (null.asInstanceOf[java.math.BigDecimal], "a"),
+ (null.asInstanceOf[java.math.BigDecimal], "a"),
+ (null.asInstanceOf[java.math.BigDecimal], "b"),
+ (null.asInstanceOf[java.math.BigDecimal], "b"),
+ (null.asInstanceOf[java.math.BigDecimal], "b")),
+ "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
+ checkSparkAnswerAndOperator(res)
+ assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+ }
+ }
+ }
+ }
+
+ protected def generateOverflowDecimalInputs: Seq[(java.math.BigDecimal, Int)] = {
+ val maxDec38_0 = new java.math.BigDecimal("99999999999999999999")
+ (1 to 50).flatMap(_ => Seq((maxDec38_0, 1)))
+ }
+
+ test("ANSI support for decimal SUM function") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+ CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(generateOverflowDecimalInputs, "tbl") {
+ val res = sql("SELECT SUM(_1) FROM tbl")
+ if (ansiEnabled) {
+ checkSparkAnswerMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ case _ =>
+ fail("Exception should be thrown for decimal overflow in ANSI
mode")
+ }
+ } else {
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+ }
+ }
+
+ test("ANSI support for decimal SUM - GROUP BY") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+ CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(generateOverflowDecimalInputs, "tbl") {
+ val res =
+ sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
+ if (ansiEnabled) {
+ checkSparkAnswerMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ case _ =>
+ fail("Exception should be thrown for decimal overflow with
GROUP BY in ANSI mode")
+ }
+ } else {
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+ }
+ }
+
+ test("try_sum decimal overflow") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(generateOverflowDecimalInputs, "tbl") {
+ val res = sql("SELECT try_sum(_1) FROM tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+
+ test("try_sum decimal overflow - with GROUP BY") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withParquetTable(generateOverflowDecimalInputs, "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY
_2").repartition(2, col("_2"))
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+
+ test("try_sum decimal partial overflow - with GROUP BY") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ // Group 1 overflows, Group 2 succeeds
+ val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
+ (new java.math.BigDecimal(300), 2),
+ (new java.math.BigDecimal(200), 2))
+ withParquetTable(data, "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
+ // Group 1 should be NULL, Group 2 should be 500
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+
protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]