This is an automated email from the ASF dual-hosted git repository.

parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new bc5fa0084 Test: Add test coverage and documentation for SumDecimal/AvgDecimal nullability behavior (#3766)
bc5fa0084 is described below

commit bc5fa008410c85e2f9d5bec658534ade918169b3
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Wed Mar 25 21:58:35 2026 +0530

    Test: Add test coverage and documentation for SumDecimal/AvgDecimal nullability behavior (#3766)
    
    * fix: Make SumDecimal and AvgDecimal nullability depend on ANSI mode and input nullability
    
    In Spark, Sum.nullable and Average.nullable both return true irrespective
    of ANSI mode.
---
 native/spark-expr/src/agg_funcs/avg_decimal.rs     |  6 ++++
 native/spark-expr/src/agg_funcs/sum_decimal.rs     |  3 +-
 .../apache/comet/exec/CometAggregateSuite.scala    | 39 ++++++++++++++++++++++
 3 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/native/spark-expr/src/agg_funcs/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs
index 773ddea05..08e335f42 100644
--- a/native/spark-expr/src/agg_funcs/avg_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/avg_decimal.rs
@@ -207,6 +207,12 @@ impl AggregateUDFImpl for AvgDecimal {
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         avg_return_type(self.name(), &arg_types[0])
     }
+
+    fn is_nullable(&self) -> bool {
+        // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode.
+        // AvgDecimal is always nullable because overflows can cause null values.
+        true
+    }
 }
 
 /// An accumulator to compute the average for decimals
diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs
index bf5569b00..56a735493 100644
--- a/native/spark-expr/src/agg_funcs/sum_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -164,7 +164,8 @@ impl AggregateUDFImpl for SumDecimal {
     }
 
     fn is_nullable(&self) -> bool {
-        // SumDecimal is always nullable because overflows can cause null values
+        // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode.
+        // SumDecimal is always nullable because overflows can cause null values.
         true
     }
 }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9426d1c84..be60f4aae 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1917,6 +1917,45 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("SumDecimal and AvgDecimal nullable should always be true") {
+    // SumDecimal/AvgDecimal hardcode is_nullable = true, matching Spark's Sum.nullable and
+    // Average.nullable, which are true regardless of ANSI mode or input nullability. `df.schema`
+    // comes from Spark's analyzed plan, so the nullable assertion pins that contract; the native
+    // path is exercised by checkSparkAnswerAndOperator, including an all-null group whose
+    // aggregate result is itself null.
+    val nonNullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+      (new java.math.BigDecimal("10.00"), 1),
+      (new java.math.BigDecimal("20.00"), 1),
+      (new java.math.BigDecimal("30.00"), 2))
+
+    val nullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+      (new java.math.BigDecimal("10.00"), 1),
+      (null.asInstanceOf[java.math.BigDecimal], 1),
+      (new java.math.BigDecimal("30.00"), 2),
+      // Group 3 contains only nulls, so sum/avg must return a null result for it.
+      (null.asInstanceOf[java.math.BigDecimal], 3))
+
+    for {
+      ansiEnabled <- Seq(true, false)
+      (data, hasAllNullGroup) <- Seq((nonNullableData, false), (nullableData, true))
+    } {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(data, "tbl") {
+          Seq("sum", "avg").foreach { aggFn =>
+            val res = sql(s"SELECT _2, $aggFn(_1) FROM tbl GROUP BY _2")
+            checkSparkAnswerAndOperator(res)
+            assert(
+              res.schema.fields(1).nullable,
+              s"$aggFn(_1) should be nullable (ansi=$ansiEnabled)")
+            if (hasAllNullGroup) {
+              assert(res.where("_2 = 3").collect().map(_.isNullAt(1)).toSeq == Seq(true))
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, 
numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to