This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 4fb7fe2a406 [SPARK-39208][SQL] Fix query context bugs in decimal overflow under codegen mode
4fb7fe2a406 is described below

commit 4fb7fe2a40623526ed22311eac16c937450031e5
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Tue May 17 22:31:30 2022 +0800

    [SPARK-39208][SQL] Fix query context bugs in decimal overflow under codegen mode
    
    ### What changes were proposed in this pull request?
    
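    1. Fix logic bugs when adding query contexts as references under codegen mode.
    The condition was inverted: the query context was registered only when overflow
    returns null (where it is never used), while the ANSI path that throws an
    exception received an empty string instead of the context.
    https://github.com/apache/spark/pull/36040/files#diff-4a70d2f3a4b99f58796b87192143f9838f4c4cf469f3313eb30af79c4e07153aR145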
    The code
    ```
        val errorContextCode = if (nullOnOverflow) {
          ctx.addReferenceObj("errCtx", queryContext)
        } else {
          "\"\""
        }
    ```
    should be
    ```
        val errorContextCode = if (nullOnOverflow) {
          "\"\""
        } else {
          ctx.addReferenceObj("errCtx", queryContext)
        }
    ```
    
    2. Similar to https://github.com/apache/spark/pull/36557, make `CheckOverflowInSum` support query context when whole-stage codegen (WSCG) is not available.
    
    ### Why are the changes needed?
    
    Bugfix and enhancement in the query context of decimal expressions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the query context has not been released yet.
    
    ### How was this patch tested?
    
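    New unit test in `DecimalExpressionSuite`, and an extended end-to-end test in `SQLQuerySuite` that also covers `sum` overflow.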
    
    Closes #36577 from gengliangwang/fixDecimalSumOverflow.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
    (cherry picked from commit 191e535b975e5813719d3143797c9fcf86321368)
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../sql/catalyst/expressions/aggregate/Sum.scala    | 21 ++++++++++++++-------
 .../catalyst/expressions/decimalExpressions.scala   | 15 ++++++++-------
 .../expressions/DecimalExpressionSuite.scala        | 19 +++++++++++++++++++
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala  | 19 +++++++++++--------
 4 files changed, 52 insertions(+), 22 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index f2c6925b837..fa43565d807 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -143,10 +143,11 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
    * So now, if ansi is enabled, then throw exception, if not then return null.
    * If sum is not null, then return the sum.
    */
-  protected def getEvaluateExpression: Expression = resultType match {
+  protected def getEvaluateExpression(queryContext: String): Expression = 
resultType match {
     case d: DecimalType =>
-      If(isEmpty, Literal.create(null, resultType),
-        CheckOverflowInSum(sum, d, !useAnsiAdd))
+      val checkOverflowInSum =
+        CheckOverflowInSum(sum, d, !useAnsiAdd, queryContext)
+      If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
     case _ if shouldTrackIsEmpty =>
       If(isEmpty, Literal.create(null, resultType), sum)
     case _ => sum
@@ -172,7 +173,7 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
 case class Sum(
     child: Expression,
     useAnsiAdd: Boolean = SQLConf.get.ansiEnabled)
-  extends SumBase(child) {
+  extends SumBase(child) with SupportQueryContext {
   def this(child: Expression) = this(child, useAnsiAdd = 
SQLConf.get.ansiEnabled)
 
   override def shouldTrackIsEmpty: Boolean = resultType match {
@@ -186,7 +187,13 @@ case class Sum(
 
   override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
 
-  override lazy val evaluateExpression: Expression = getEvaluateExpression
+  override lazy val evaluateExpression: Expression = 
getEvaluateExpression(queryContext)
+
+  override def initQueryContext(): String = if (useAnsiAdd) {
+    origin.context
+  } else {
+    ""
+  }
 }
 
 // scalastyle:off line.size.limit
@@ -243,9 +250,9 @@ case class TrySum(child: Expression) extends SumBase(child) 
{
 
   override lazy val evaluateExpression: Expression =
     if (useAnsiAdd) {
-      TryEval(getEvaluateExpression)
+      TryEval(getEvaluateExpression(""))
     } else {
-      getEvaluateExpression
+      getEvaluateExpression("")
     }
 
   override protected def withNewChildInternal(newChild: Expression): 
Expression =
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 2cdd784ea4d..7d25df5ae9c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -142,9 +142,9 @@ case class CheckOverflow(
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     val errorContextCode = if (nullOnOverflow) {
-      ctx.addReferenceObj("errCtx", queryContext)
-    } else {
       "\"\""
+    } else {
+      ctx.addReferenceObj("errCtx", queryContext)
     }
     nullSafeCodeGen(ctx, ev, eval => {
       // scalastyle:off line.size.limit
@@ -175,7 +175,8 @@ case class CheckOverflow(
 case class CheckOverflowInSum(
     child: Expression,
     dataType: DecimalType,
-    nullOnOverflow: Boolean) extends UnaryExpression {
+    nullOnOverflow: Boolean,
+    queryContext: String = "") extends UnaryExpression {
 
   override def nullable: Boolean = true
 
@@ -183,23 +184,23 @@ case class CheckOverflowInSum(
     val value = child.eval(input)
     if (value == null) {
       if (nullOnOverflow) null
-      else throw 
QueryExecutionErrors.overflowInSumOfDecimalError(origin.context)
+      else throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
     } else {
       value.asInstanceOf[Decimal].toPrecision(
         dataType.precision,
         dataType.scale,
         Decimal.ROUND_HALF_UP,
         nullOnOverflow,
-        origin.context)
+        queryContext)
     }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     val childGen = child.genCode(ctx)
     val errorContextCode = if (nullOnOverflow) {
-      ctx.addReferenceObj("errCtx", origin.context)
-    } else {
       "\"\""
+    } else {
+      ctx.addReferenceObj("errCtx", queryContext)
     }
     val nullHandling = if (nullOnOverflow) {
       ""
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index 36bc3db5804..1a8cd63aed0 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{Decimal, DecimalType, LongType}
 
@@ -83,4 +85,21 @@ class DecimalExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(CheckOverflow(
       Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), false), null)
   }
+
+  test("SPARK-39208: CheckOverflow & CheckOverflowInSum support query context 
in runtime errors") {
+    val d = Decimal(101, 3, 1)
+    val query = "select cast(d as decimal(4, 3)) from t"
+    val origin = Origin(
+      startIndex = Some(7),
+      stopIndex = Some(30),
+      sqlText = Some(query))
+
+    val expr1 = withOrigin(origin) {
+      CheckOverflow(Literal(d), DecimalType(4, 3), false)
+    }
+    checkExceptionInExpression[ArithmeticException](expr1, query)
+
+    val expr2 = CheckOverflowInSum(Literal(d), DecimalType(4, 3), false, 
queryContext = query)
+    checkExceptionInExpression[ArithmeticException](expr2, query)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f6998fe5c1c..422ba7c2a9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4423,18 +4423,21 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
     }
   }
 
-  test("SPARK-39190: Query context of decimal overflow error should be 
serialized to executors" +
-    " when WSCG is off") {
+  test("SPARK-39190, SPARK-39208: Query context of decimal overflow error 
should be serialized " +
+    "to executors when WSCG is off") {
     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
       SQLConf.ANSI_ENABLED.key -> "true") {
       withTable("t") {
         sql("create table t(d decimal(38, 0)) using parquet")
-        sql("insert into t values (2e37BD)")
-        val query = "select d / 0.1 from t"
-        val msg = intercept[SparkException] {
-          sql(query).collect()
-        }.getMessage
-        assert(msg.contains(query))
+        sql("insert into t values (6e37BD),(6e37BD)")
+        Seq(
+          "select d / 0.1 from t",
+          "select sum(d) from t").foreach { query =>
+          val msg = intercept[SparkException] {
+            sql(query).collect()
+          }.getMessage
+          assert(msg.contains(query))
+        }
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to