This is an automated email from the ASF dual-hosted git repository.
parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new bc5fa0084 Test: Add test coverage and documentation for SumDecimal/AvgDecimal nullability behavior (#3766)
bc5fa0084 is described below
commit bc5fa008410c85e2f9d5bec658534ade918169b3
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Wed Mar 25 21:58:35 2026 +0530
Test: Add test coverage and documentation for SumDecimal/AvgDecimal
nullability behavior (#3766)
* fix: Make SumDecimal and AvgDecimal nullability depend on ANSI mode and
input nullability
In Spark, Sum.nullable and Average.nullable both return true irrespective
of ANSI mode.
---
native/spark-expr/src/agg_funcs/avg_decimal.rs | 6 ++++
native/spark-expr/src/agg_funcs/sum_decimal.rs | 3 +-
.../apache/comet/exec/CometAggregateSuite.scala | 39 ++++++++++++++++++++++
3 files changed, 47 insertions(+), 1 deletion(-)
diff --git a/native/spark-expr/src/agg_funcs/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs
index 773ddea05..08e335f42 100644
--- a/native/spark-expr/src/agg_funcs/avg_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/avg_decimal.rs
@@ -207,6 +207,12 @@ impl AggregateUDFImpl for AvgDecimal {
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
avg_return_type(self.name(), &arg_types[0])
}
+
+    fn is_nullable(&self) -> bool {
+        // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode.
+        // AvgDecimal is always nullable because overflows can cause null values.
+        true
+    }
}
/// An accumulator to compute the average for decimals
diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs
index bf5569b00..56a735493 100644
--- a/native/spark-expr/src/agg_funcs/sum_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -164,7 +164,8 @@ impl AggregateUDFImpl for SumDecimal {
}
     fn is_nullable(&self) -> bool {
-        // SumDecimal is always nullable because overflows can cause null values
+        // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode.
+        // SumDecimal is always nullable because overflows can cause null values.
         true
     }
}
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9426d1c84..be60f4aae 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1917,6 +1917,45 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+  test("SumDecimal and AvgDecimal nullable should always be true") {
+    // SumDecimal and AvgDecimal currently hardcode nullable=true.
+    // This matches Spark's Sum.nullable and Average.nullable which always return true,
+    // regardless of ANSI mode or input nullability.
+ val nonNullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+ (new java.math.BigDecimal("10.00"), 1),
+ (new java.math.BigDecimal("20.00"), 1),
+ (new java.math.BigDecimal("30.00"), 2))
+
+ val nullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+ (new java.math.BigDecimal("10.00"), 1),
+ (null.asInstanceOf[java.math.BigDecimal], 1),
+ (new java.math.BigDecimal("30.00"), 2))
+
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ withParquetTable(nonNullableData, "tbl") {
+ val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(sumRes)
+ assert(sumRes.schema.fields(1).nullable == true)
+
+ val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(avgRes)
+ assert(avgRes.schema.fields(1).nullable == true)
+ }
+
+ withParquetTable(nullableData, "tbl") {
+ val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(sumRes)
+ assert(sumRes.schema.fields(1).nullable == true)
+
+ val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(avgRes)
+ assert(avgRes.schema.fields(1).nullable == true)
+ }
+ }
+ }
+ }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]