This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 9321be6 fix: Comet should not translate try_sum to native sum expression (#277)
9321be6 is described below
commit 9321be6e24707f5800c51bc215003d30521ccc29
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Apr 17 12:24:18 2024 -0700
fix: Comet should not translate try_sum to native sum expression (#277)
* fix: Comet should not translate try_sum to native sum expression
* For Spark 3.2 and 3.3
* Update spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
Co-authored-by: advancedxy <[email protected]>
* Fix format
---------
Co-authored-by: advancedxy <[email protected]>
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 4 ++--
.../org/apache/comet/shims/ShimQueryPlanSerde.scala | 18 ++++++++++++++++++
.../scala/org/apache/comet/exec/CometExecSuite.scala | 20 ++++++++++++++++++--
3 files changed, 38 insertions(+), 4 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 26fc708..172a5b5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -202,7 +202,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
inputs: Seq[Attribute],
binding: Boolean): Option[AggExpr] = {
aggExpr.aggregateFunction match {
- case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) =>
+ case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && isLegacyMode(s) =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(s.dataType)
@@ -220,7 +220,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
} else {
None
}
- case s @ Average(child, _) if avgDataTypeSupported(s.dataType) =>
+ case s @ Average(child, _) if avgDataTypeSupported(s.dataType) && isLegacyMode(s) =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(s.dataType)
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index 7bdf2c0..b92d3fc 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -45,6 +45,24 @@ trait ShimQueryPlanSerde {
}
}
+ // TODO: delete after drop Spark 3.2/3.3 support
+ // This method is used to check if the aggregate function is in legacy mode.
+ // EvalMode is an enum object in Spark 3.4.
+ def isLegacyMode(aggregate: DeclarativeAggregate): Boolean = {
+ val evalMode = aggregate.getClass.getDeclaredMethods
+ .flatMap(m =>
+ m.getName match {
+ case "evalMode" => Some(m.invoke(aggregate))
+ case _ => None
+ })
+
+ if (evalMode.isEmpty) {
+ true
+ } else {
+ "legacy".equalsIgnoreCase(evalMode.head.toString)
+ }
+ }
+
// TODO: delete after drop Spark 3.2 support
def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
binary.getClass.getName ==
"org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0bb21ab..a8b05cc 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,8 @@
package org.apache.comet.exec
+import java.time.{Duration, Period}
+
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
@@ -38,13 +40,13 @@ import org.apache.spark.sql.execution.{CollectLimitExec,
ProjectExec, SQLExecuti
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr, sum}
+import org.apache.spark.sql.functions.{col, date_add, expr, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String
import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus}
class CometExecSuite extends CometTestBase {
import testImplicits._
@@ -58,6 +60,20 @@ class CometExecSuite extends CometTestBase {
}
}
+ test("try_sum should return null if overflow happens before merging") {
+ assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
+ val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+ val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+ .map(Period.ofMonths)
+ .toDF("v")
+ val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+ .map(Duration.ofDays)
+ .toDF("v")
+ Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"))
+ }
+ }
+
test("Fix corrupted AggregateMode when transforming plan parameters") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))