martin-g commented on code in PR #2817:
URL: https://github.com/apache/datafusion-comet/pull/2817#discussion_r2560061131
##########
spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala:
##########
@@ -1471,6 +1471,42 @@ class CometAggregateSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
+ test("AVG and try_avg - basic functionality") {
+ withParquetTable(
+ Seq(
+ (10L, 1),
+ (20L, 1),
+ (null.asInstanceOf[Long], 1),
+ (100L, 2),
+ (200L, 2),
   Review Comment:
   Idea: use bigger numbers that lead to overflow during the
   calculation, e.g. assert that INFINITY is the expected result
##########
native/core/src/execution/planner.rs:
##########
@@ -1893,19 +1893,24 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(),
Arc::clone(&schema))?;
let datatype =
to_arrow_datatype(expr.datatype.as_ref().unwrap());
let input_datatype =
to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+ let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+
let builder = match datatype {
DataType::Decimal128(_, _) => {
let func =
AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
   Review Comment:
   Is `eval_mode` intentionally ignored for `AvgDecimal`?
   https://github.com/apache/spark/blob/211dd995b221f135340375159672dcb77ef90ef4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L105-L113
   shows that Spark uses it
##########
spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala:
##########
@@ -1471,6 +1471,42 @@ class CometAggregateSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
+ test("AVG and try_avg - basic functionality") {
+ withParquetTable(
+ Seq(
+ (10L, 1),
+ (20L, 1),
+ (null.asInstanceOf[Long], 1),
+ (100L, 2),
+ (200L, 2),
+ (null.asInstanceOf[Long], 3)),
+ "tbl") {
+
+ Seq(true, false).foreach({ k =>
+ // without GROUP BY
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
+ val res = sql("SELECT avg(_1) FROM tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+
+ // with GROUP BY
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
+ val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(res)
+ }
+
+ })
+
+ // try_avg without GROUP BY
+ val resTry = sql("SELECT try_avg(_1) FROM tbl")
+ checkSparkAnswerAndOperator(resTry)
+
+ // try_avg with GROUP BY
+ val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(resTryGroup)
+ }
+ }
+
   Review Comment:
   Could you add a test with `Decimal128`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]