This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9321be6  fix: Comet should not translate try_sum to native sum expression (#277)
9321be6 is described below

commit 9321be6e24707f5800c51bc215003d30521ccc29
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Wed Apr 17 12:24:18 2024 -0700

    fix: Comet should not translate try_sum to native sum expression (#277)
    
    * fix: Comet should not translate try_sum to native sum expression
    
    * For Spark 3.2 and 3.3
    
    * Update spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
    
    Co-authored-by: advancedxy <xian...@apache.org>
    
    * Fix format
    
    ---------
    
    Co-authored-by: advancedxy <xian...@apache.org>
---
 .../org/apache/comet/serde/QueryPlanSerde.scala      |  4 ++--
 .../org/apache/comet/shims/ShimQueryPlanSerde.scala  | 18 ++++++++++++++++++
 .../scala/org/apache/comet/exec/CometExecSuite.scala | 20 ++++++++++++++++++--
 3 files changed, 38 insertions(+), 4 deletions(-)
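
For context: Spark's try_sum and try_avg follow TRY evaluation semantics and
return null when the aggregation overflows, while Comet's native sum and avg
implement legacy semantics, so translating try_sum to the native expression
would silently change query results. A minimal illustration of the difference,
assuming a local SparkSession on Spark 3.3+ (where try_sum is available):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    import spark.implicits._

    val df = Seq(Long.MaxValue, Long.MaxValue, 2L).toDF("v")
    df.selectExpr("try_sum(v)").show() // TRY semantics: null on overflow
    df.selectExpr("sum(v)").show()     // legacy semantics: wraps around silently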

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 26fc708..172a5b5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -202,7 +202,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       inputs: Seq[Attribute],
       binding: Boolean): Option[AggExpr] = {
     aggExpr.aggregateFunction match {
-      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) =>
+      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && isLegacyMode(s) =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)
 
@@ -220,7 +220,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
-      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) =>
+      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) && isLegacyMode(s) =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)
 
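
The isLegacyMode guard above keeps these two cases from matching try_sum and
try_avg. On Spark 3.4+, where Sum carries an explicit EvalMode, the same check
could be written directly; a sketch of that direct form (isLegacySum is an
illustrative helper, not part of Comet):

    import org.apache.spark.sql.catalyst.expressions.EvalMode
    import org.apache.spark.sql.catalyst.expressions.aggregate.Sum

    // try_sum parses to Sum(child, EvalMode.TRY); only LEGACY matches the
    // overflow behavior of Comet's native sum.
    def isLegacySum(s: Sum): Boolean = s.evalMode == EvalMode.LEGACY
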
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index 7bdf2c0..b92d3fc 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -45,6 +45,24 @@ trait ShimQueryPlanSerde {
     }
   }
 
+  // TODO: delete after dropping Spark 3.2/3.3 support
+  // Checks whether the aggregate function runs in legacy evaluation mode.
+  // EvalMode is an enum object introduced in Spark 3.4.
+  def isLegacyMode(aggregate: DeclarativeAggregate): Boolean = {
+    val evalMode = aggregate.getClass.getDeclaredMethods
+      .flatMap(m =>
+        m.getName match {
+          case "evalMode" => Some(m.invoke(aggregate))
+          case _ => None
+        })
+
+    if (evalMode.isEmpty) {
+      true
+    } else {
+      "legacy".equalsIgnoreCase(evalMode.head.toString)
+    }
+  }
+
   // TODO: delete after dropping Spark 3.2 support
   def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
     binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0bb21ab..a8b05cc 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,8 @@
 
 package org.apache.comet.exec
 
+import java.time.{Duration, Period}
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
@@ -38,13 +40,13 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr, sum}
+import org.apache.spark.sql.functions.{col, date_add, expr, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus}
 
 class CometExecSuite extends CometTestBase {
   import testImplicits._
@@ -58,6 +60,20 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("try_sum should return null if overflow happens before merging") {
+    assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
+    val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+    val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+      .map(Period.ofMonths)
+      .toDF("v")
+    val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+      .map(Duration.ofDays)
+      .toDF("v")
+    Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+      checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"))
+    }
+  }
+
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
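
The new test forces the overflow to happen in the partial aggregation, before
merging, by range-partitioning both Long.MaxValue rows into the same
partition. Why that distinction matters, as plain illustrative arithmetic:

    val partition = Seq(Long.MaxValue, Long.MaxValue) // rows in one range partition
    val wrapped = partition.reduce(_ + _)             // legacy sum: wraps to -2
    val trySum =
      try Some(Math.addExact(partition.head, partition.last))
      catch { case _: ArithmeticException => None }   // TRY sum: None, i.e. null
    println((wrapped, trySum))                        // prints (-2, None)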
