This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b5b3e95f87 [SPARK-39210][SQL] Provide query context of Decimal overflow in AVG when WSCG is off
8b5b3e95f87 is described below

commit 8b5b3e95f8761af97255cbcba35c3d836a419dba
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Wed May 18 18:52:15 2022 +0800

    [SPARK-39210][SQL] Provide query context of Decimal overflow in AVG when WSCG is off
    
    ### What changes were proposed in this pull request?
    
    Similar to https://github.com/apache/spark/pull/36525, this PR provides the runtime error query context for the Average expression when whole-stage codegen (WSCG) is off.
    
    ### Why are the changes needed?
    
    Enhance the runtime error query context of the Average function. After this change, the query context is reported even when whole-stage codegen is not available.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
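    New unit test case: `select avg(d) from t` was added to the existing decimal-overflow query-context test in SQLQuerySuite.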
    
    Closes #36582 from gengliangwang/fixAvgContext.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../sql/catalyst/expressions/aggregate/Average.scala     | 16 +++++++++++-----
 .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala  |  7 ++++---
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 14914576091..343e27d863b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -81,11 +81,11 @@ abstract class AverageBase
 
   // If all input are nulls, count will be 0 and we will get null after the 
division.
   // We can't directly use `/` as it throws an exception under ansi mode.
-  protected def getEvaluateExpression = child.dataType match {
+  protected def getEvaluateExpression(queryContext: String) = child.dataType 
match {
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
         Divide(
-          CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], 
!useAnsiAdd),
+          CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], 
!useAnsiAdd, queryContext),
           count.cast(DecimalType.LongDecimal), failOnError = 
false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
@@ -123,7 +123,7 @@ abstract class AverageBase
   since = "1.0.0")
 case class Average(
     child: Expression,
-    useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase {
+    useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase with 
SupportQueryContext {
   def this(child: Expression) = this(child, useAnsiAdd = 
SQLConf.get.ansiEnabled)
 
   override protected def withNewChildInternal(newChild: Expression): Average =
@@ -133,7 +133,13 @@ case class Average(
 
   override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
 
-  override lazy val evaluateExpression: Expression = getEvaluateExpression
+  override lazy val evaluateExpression: Expression = 
getEvaluateExpression(queryContext)
+
+  override def initQueryContext(): String = if (useAnsiAdd) {
+    origin.context
+  } else {
+    ""
+  }
 }
 
 // scalastyle:off line.size.limit
@@ -192,7 +198,7 @@ case class TryAverage(child: Expression) extends 
AverageBase {
   }
 
   override lazy val evaluateExpression: Expression = {
-    addTryEvalIfNeeded(getEvaluateExpression)
+    addTryEvalIfNeeded(getEvaluateExpression(""))
   }
 
   override protected def withNewChildInternal(newChild: Expression): 
Expression =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index dde4a6c3110..72897d15302 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4376,8 +4376,8 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
     }
   }
 
-  test("SPARK-39190, SPARK-39208: Query context of decimal overflow error 
should be serialized " +
-    "to executors when WSCG is off") {
+  test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow 
error should " +
+    "be serialized to executors when WSCG is off") {
     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
       SQLConf.ANSI_ENABLED.key -> "true") {
       withTable("t") {
@@ -4385,7 +4385,8 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
         sql("insert into t values (6e37BD),(6e37BD)")
         Seq(
           "select d / 0.1 from t",
-          "select sum(d) from t").foreach { query =>
+          "select sum(d) from t",
+          "select avg(d) from t").foreach { query =>
           val msg = intercept[SparkException] {
             sql(query).collect()
           }.getMessage


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to