parthchandra commented on code in PR #3766:
URL: https://github.com/apache/datafusion-comet/pull/3766#discussion_r2978374084
##########
spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala:
##########
@@ -1917,6 +1917,43 @@ class CometAggregateSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
+ test("SumDecimal and AvgDecimal nullability depends on ANSI mode and input
nullability") {
+ // Non-nullable input: in ANSI mode the result should be non-nullable
because overflows
+ // throw exceptions instead of producing nulls.
+ val nonNullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+ (new java.math.BigDecimal("10.00"), 1),
+ (new java.math.BigDecimal("20.00"), 1),
+ (new java.math.BigDecimal("30.00"), 2))
+
+ // Nullable input: result should always be nullable regardless of ANSI
mode.
+ val nullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+ (new java.math.BigDecimal("10.00"), 1),
+ (null.asInstanceOf[java.math.BigDecimal], 1),
+ (new java.math.BigDecimal("30.00"), 2))
+
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ // Test SUM with non-nullable input
+ withParquetTable(nonNullableData, "tbl") {
+ val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(sumRes)
+
+ val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(avgRes)
+ }
+
+ // Test SUM/AVG with nullable input
+ withParquetTable(nullableData, "tbl") {
+ val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(sumRes)
+
+ val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(avgRes)
+ }
Review Comment:
We can add an additional check to ensure that the nullability is compatible
with Spark, for example:
`assert(sumRes.schema.fields(1).nullable == true)`
##########
native/spark-expr/src/agg_funcs/avg_decimal.rs:
##########
@@ -207,6 +213,17 @@ impl AggregateUDFImpl for AvgDecimal {
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
avg_return_type(self.name(), &arg_types[0])
}
+
+ fn is_nullable(&self) -> bool {
Review Comment:
In Spark, `Sum.nullable` and `Average.nullable` both return `true`
irrespective of ANSI mode.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]