This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1cd72c6b4a0e [SPARK-49993][SQL] Improve error messages for Sum and 
Average
1cd72c6b4a0e is described below

commit 1cd72c6b4a0e4a1086047248a007361d176439a0
Author: Mihailo Milosevic <mihailo.milose...@databricks.com>
AuthorDate: Mon Oct 28 16:13:39 2024 +0100

    [SPARK-49993][SQL] Improve error messages for Sum and Average
    
    ### What changes were proposed in this pull request?
    This PR improves the ANSI-mode decimal overflow error messages for Sum and Average so that they suggest the corresponding `try_` function (`try_sum` / `try_avg`).
    
    ### Why are the changes needed?
    The [PR](https://github.com/apache/spark/pull/48206) for removing the ANSI suggestion from ARITHMETIC_OVERFLOW was getting too big, so the work is being split into multiple smaller PRs; this is one of them.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The decimal overflow errors for Sum and Average now suggest using `try_sum` and `try_avg`, respectively.
    
    ### How was this patch tested?
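    Updated the existing decimal-overflow tests in DataFrameAggregateSuite to check that the error message suggests the matching `try_` function.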
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48499 from mihailom-db/fixSuggestions.
    
    Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../apache/spark/sql/errors/ExecutionErrors.scala  |  6 ++---
 .../catalyst/expressions/decimalExpressions.scala  | 11 +++++----
 .../spark/sql/errors/QueryExecutionErrors.scala    |  7 ++++--
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 27 +++++++++++-----------
 4 files changed, 29 insertions(+), 22 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 907c46f583cf..0ee1d7037d43 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -116,10 +116,10 @@ private[sql] trait ExecutionErrors extends 
DataTypeErrorsBase {
 
   def arithmeticOverflowError(
       message: String,
-      hint: String = "",
+      suggestedFunc: String = "",
       context: QueryContext = null): ArithmeticException = {
-    val alternative = if (hint.nonEmpty) {
-      s" Use '$hint' to tolerate overflow and return NULL instead."
+    val alternative = if (suggestedFunc.nonEmpty) {
+      s" Use '$suggestedFunc' to tolerate overflow and return NULL instead."
     } else ""
     new SparkArithmeticException(
       errorClass = "ARITHMETIC_OVERFLOW",
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 5f13d397d1bf..f7509f124ab5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -166,7 +166,9 @@ case class CheckOverflowInSum(
     val value = child.eval(input)
     if (value == null) {
       if (nullOnOverflow) null
-      else throw QueryExecutionErrors.overflowInSumOfDecimalError(context)
+      else {
+        throw QueryExecutionErrors.overflowInSumOfDecimalError(context, 
suggestedFunc = "try_sum")
+      }
     } else {
       value.asInstanceOf[Decimal].toPrecision(
         dataType.precision,
@@ -183,7 +185,7 @@ case class CheckOverflowInSum(
     val nullHandling = if (nullOnOverflow) {
       ""
     } else {
-      s"throw 
QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+      s"""throw 
QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, 
"try_sum");"""
     }
     // scalastyle:off line.size.limit
     val code = code"""
@@ -270,7 +272,8 @@ case class DecimalDivideWithOverflowCheck(
       if (nullOnOverflow)  {
         null
       } else {
-        throw 
QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull())
+        throw 
QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull(),
+          suggestedFunc = "try_avg")
       }
     } else {
       val value2 = right.eval(input)
@@ -286,7 +289,7 @@ case class DecimalDivideWithOverflowCheck(
     val nullHandling = if (nullOnOverflow) {
       ""
     } else {
-      s"throw 
QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+      s"""throw 
QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, 
"try_avg");"""
     }
 
     val eval1 = left.genCode(ctx)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index efdc06d4cbd8..0aed8e604bd9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -295,8 +295,11 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE
         "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)))
   }
 
-  def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException 
= {
-    arithmeticOverflowError("Overflow in sum of decimals", context = context)
+  def overflowInSumOfDecimalError(
+      context: QueryContext,
+      suggestedFunc: String): ArithmeticException = {
+    arithmeticOverflowError("Overflow in sum of decimals", suggestedFunc = 
suggestedFunc,
+      context = context)
   }
 
   def overflowInIntegralDivideError(context: QueryContext): 
ArithmeticException = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 25f4d9f62354..7ebcb280def6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2270,7 +2270,7 @@ class DataFrameAggregateSuite extends QueryTest
   }
 
   private def assertDecimalSumOverflow(
-      df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
+      df: DataFrame, ansiEnabled: Boolean, fnName: String, expectedAnswer: 
Row): Unit = {
     if (!ansiEnabled) {
       checkAnswer(df, expectedAnswer)
     } else {
@@ -2278,11 +2278,12 @@ class DataFrameAggregateSuite extends QueryTest
         df.collect()
       }
       assert(e.getMessage.contains("cannot be represented as Decimal") ||
-        e.getMessage.contains("Overflow in sum of decimals"))
+        e.getMessage.contains(s"Overflow in sum of decimals. Use 'try_$fnName' 
to tolerate " +
+          s"overflow and return NULL instead."))
     }
   }
 
-  def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
+  def checkAggResultsForDecimalOverflow(aggFn: Column => Column, fnName: 
String): Unit = {
     Seq("true", "false").foreach { wholeStageEnabled =>
       withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) 
{
         Seq(true, false).foreach { ansiEnabled =>
@@ -2306,27 +2307,27 @@ class DataFrameAggregateSuite extends QueryTest
               join(df, "intNum").agg(aggFn($"decNum"))
 
             val expectedAnswer = Row(null)
-            assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
+            assertDecimalSumOverflow(df2, ansiEnabled, fnName, expectedAnswer)
 
             val decStr = "1" + "0" * 19
             val d1 = spark.range(0, 12, 1, 1)
             val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as 
d")).agg(aggFn($"d"))
-            assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
+            assertDecimalSumOverflow(d2, ansiEnabled, fnName, expectedAnswer)
 
             val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
             val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as 
d")).agg(aggFn($"d"))
-            assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
+            assertDecimalSumOverflow(d4, ansiEnabled, fnName, expectedAnswer)
 
             val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as 
d"),
               
lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
-            assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
+            assertDecimalSumOverflow(d5, ansiEnabled, fnName, expectedAnswer)
 
             val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as 
decimal(38,18)) as d"))
 
             val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), 
BigDecimal("9"* 20 + ".123")).
               toDF("d")
             assertDecimalSumOverflow(
-              nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, 
expectedAnswer)
+              nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, 
fnName, expectedAnswer)
 
             val df3 = Seq(
               (BigDecimal("10000000000000000000"), 1),
@@ -2344,9 +2345,9 @@ class DataFrameAggregateSuite extends QueryTest
               (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
 
             val df6 = df3.union(df4).union(df5)
-            val df7 = df6.groupBy("intNum").agg(sum("decNum"), 
countDistinct("decNum")).
+            val df7 = df6.groupBy("intNum").agg(aggFn($"decNum"), 
countDistinct("decNum")).
               filter("intNum == 1")
-            assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
+            assertDecimalSumOverflow(df7, ansiEnabled, fnName, Row(1, null, 2))
           }
         }
       }
@@ -2354,11 +2355,11 @@ class DataFrameAggregateSuite extends QueryTest
   }
 
   test("SPARK-28067: Aggregate sum should not return wrong results for decimal 
overflow") {
-    checkAggResultsForDecimalOverflow(c => sum(c))
+    checkAggResultsForDecimalOverflow(c => sum(c), "sum")
   }
 
   test("SPARK-35955: Aggregate avg should not return wrong results for decimal 
overflow") {
-    checkAggResultsForDecimalOverflow(c => avg(c))
+    checkAggResultsForDecimalOverflow(c => avg(c), "avg")
   }
 
   test("SPARK-28224: Aggregate sum big decimal overflow") {
@@ -2369,7 +2370,7 @@ class DataFrameAggregateSuite extends QueryTest
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
         val structDf = largeDecimals.select("a").agg(sum("a"))
-        assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
+        assertDecimalSumOverflow(structDf, ansiEnabled, "sum", Row(null))
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to