This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 63f20c526be [SPARK-39166][SQL] Provide runtime error query context for 
binary arithmetic when WSCG is off
63f20c526be is described below

commit 63f20c526bed8346fe3399aff6c0b2f7a78b441e
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri May 13 13:50:27 2022 +0800

    [SPARK-39166][SQL] Provide runtime error query context for binary 
arithmetic when WSCG is off
    
    ### What changes were proposed in this pull request?
    
    Currently, in most cases, the project 
https://issues.apache.org/jira/browse/SPARK-38615 is able to show where 
runtime errors happen within the original query.
    However, after trying it in production, I found that the following queries 
don't show where the divide-by-zero error happens:
    ```
    create table aggTest(i int, j int, k int, d date) using parquet
    insert into aggTest values(1, 2, 0, date'2022-01-01')
    select sum(j)/sum(k),percentile(i, 0.9) from aggTest group by d
    ```
    With the `percentile` function in the query, the plan can't be executed with 
whole-stage codegen. Thus the `Project` expressions (its project list) are serialized 
to executors for execution, as shown in ProjectExec:
    ```
      protected override def doExecute(): RDD[InternalRow] = {
        child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
          val project = UnsafeProjection.create(projectList, child.output)
          project.initialize(index)
          iter.map(project)
        }
      }
    ```
    
    Note that `TreeNode.origin` is not serialized to executors since 
`TreeNode` doesn't extend the trait `Serializable`, which results in an empty 
query context on errors. For more details, please read 
https://issues.apache.org/jira/browse/SPARK-39140
    
    A naive fix is to make `TreeNode` extend the trait `Serializable`. However, 
this can cause a performance regression if the query text is long (every `TreeNode` 
carries it for serialization).
    A better fix is to introduce a new trait `SupportQueryContext` and 
materialize the truncated query context for specific expressions. This PR 
targets binary arithmetic expressions only. I will create follow-ups for the 
remaining expressions that support runtime error query context.
    
    ### Why are the changes needed?
    
    Improve the error context framework and make sure it works when whole stage 
codegen is not available.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    Closes #36525 from gengliangwang/serializeContext.
    
    Lead-authored-by: Gengliang Wang <[email protected]>
    Co-authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit e336567c8a9704b500efecd276abaf5bd3988679)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../sql/catalyst/expressions/Expression.scala      | 22 +++++++++++++
 .../sql/catalyst/expressions/arithmetic.scala      | 38 +++++++++++++---------
 .../expressions/ArithmeticExpressionSuite.scala    | 25 +++++++-------
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 20 ++++++++++++
 4 files changed, 77 insertions(+), 28 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index e1d8c2e43e2..d620c5d7392 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -588,6 +588,28 @@ abstract class UnaryExpression extends Expression with 
UnaryLike[Expression] {
   }
 }
 
+/**
+ * An expression with SQL query context. The context string can be serialized 
from the Driver
+ * to executors. It will also be kept after rule transforms.
+ */
+trait SupportQueryContext extends Expression with Serializable {
+  protected var queryContext: String = initQueryContext()
+
+  def initQueryContext(): String
+
+  // Note: Even though query contexts are serialized to executors, it will be 
regenerated from an
+  //       empty "Origin" during rule transforms since "Origin"s are not 
serialized to executors
+  //       for better performance. Thus, we need to copy the original query 
context during
+  //       transforms. The query context string is considered as a "tag" on 
the expression here.
+  override def copyTagsFrom(other: Expression): Unit = {
+    other match {
+      case s: SupportQueryContext =>
+        queryContext = s.queryContext
+      case _ =>
+    }
+    super.copyTagsFrom(other)
+  }
+}
 
 object UnaryExpression {
   def unapply(e: UnaryExpression): Option[Expression] = Some(e.child)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index c6d66d8e607..153187f9e30 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, 
TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, 
TreePattern,
-  UNARY_POSITIVE}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, 
TreePattern, UNARY_POSITIVE}
 import org.apache.spark.sql.catalyst.util.{IntervalUtils, MathUtils, TypeUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -209,7 +208,8 @@ case class Abs(child: Expression, failOnError: Boolean = 
SQLConf.get.ansiEnabled
   override protected def withNewChildInternal(newChild: Expression): Abs = 
copy(child = newChild)
 }
 
-abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
+abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant
+    with SupportQueryContext {
 
   protected val failOnError: Boolean
 
@@ -219,6 +219,14 @@ abstract class BinaryArithmetic extends BinaryOperator 
with NullIntolerant {
 
   override lazy val resolved: Boolean = childrenResolved && 
checkInputDataTypes().isSuccess
 
+  override def initQueryContext(): String = {
+    if (failOnError) {
+      origin.context
+    } else {
+      ""
+    }
+  }
+
   /** Name of the function for this expression on a [[Decimal]] type. */
   def decimalMethod: String =
     throw 
QueryExecutionErrors.notOverrideExpectedMethodsError("BinaryArithmetics",
@@ -270,7 +278,7 @@ abstract class BinaryArithmetic extends BinaryOperator with 
NullIntolerant {
       })
     case IntegerType | LongType if failOnError && exactMathMethod.isDefined =>
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
-        val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+        val errorContext = ctx.addReferenceObj("errCtx", queryContext)
         val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
         s"""
            |${ev.value} = $mathUtils.${exactMathMethod.get}($eval1, $eval2, 
$errorContext);
@@ -331,9 +339,9 @@ case class Add(
     case _: YearMonthIntervalType =>
       MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
     case _: IntegerType if failOnError =>
-      MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], 
origin.context)
+      MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], 
queryContext)
     case _: LongType if failOnError =>
-      MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], 
origin.context)
+      MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], 
queryContext)
     case _ => numeric.plus(input1, input2)
   }
 
@@ -381,9 +389,9 @@ case class Subtract(
     case _: YearMonthIntervalType =>
       MathUtils.subtractExact(input1.asInstanceOf[Int], 
input2.asInstanceOf[Int])
     case _: IntegerType if failOnError =>
-      MathUtils.subtractExact(input1.asInstanceOf[Int], 
input2.asInstanceOf[Int], origin.context)
+      MathUtils.subtractExact(input1.asInstanceOf[Int], 
input2.asInstanceOf[Int], queryContext)
     case _: LongType if failOnError =>
-      MathUtils.subtractExact(input1.asInstanceOf[Long], 
input2.asInstanceOf[Long], origin.context)
+      MathUtils.subtractExact(input1.asInstanceOf[Long], 
input2.asInstanceOf[Long], queryContext)
     case _ => numeric.minus(input1, input2)
   }
 
@@ -418,9 +426,9 @@ case class Multiply(
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = 
dataType match {
     case _: IntegerType if failOnError =>
-      MathUtils.multiplyExact(input1.asInstanceOf[Int], 
input2.asInstanceOf[Int], origin.context)
+      MathUtils.multiplyExact(input1.asInstanceOf[Int], 
input2.asInstanceOf[Int], queryContext)
     case _: LongType if failOnError =>
-      MathUtils.multiplyExact(input1.asInstanceOf[Long], 
input2.asInstanceOf[Long], origin.context)
+      MathUtils.multiplyExact(input1.asInstanceOf[Long], 
input2.asInstanceOf[Long], queryContext)
     case _ => numeric.times(input1, input2)
   }
 
@@ -457,10 +465,10 @@ trait DivModLike extends BinaryArithmetic {
       } else {
         if (isZero(input2)) {
           // when we reach here, failOnError must be true.
-          throw QueryExecutionErrors.divideByZeroError(origin.context)
+          throw QueryExecutionErrors.divideByZeroError(queryContext)
         }
         if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
-          throw 
QueryExecutionErrors.overflowInIntegralDivideError(origin.context)
+          throw 
QueryExecutionErrors.overflowInIntegralDivideError(queryContext)
         }
         evalOperation(input1, input2)
       }
@@ -487,7 +495,7 @@ trait DivModLike extends BinaryArithmetic {
     } else {
       s"($javaType)(${eval1.value} $symbol ${eval2.value})"
     }
-    lazy val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+    lazy val errorContext = ctx.addReferenceObj("errCtx", queryContext)
     val checkIntegralDivideOverflow = if (checkDivideOverflow) {
       s"""
         |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1)
@@ -743,7 +751,7 @@ case class Pmod(
       } else {
         if (isZero(input2)) {
           // when we reach here, failOnError must bet true.
-          throw QueryExecutionErrors.divideByZeroError(origin.context)
+          throw QueryExecutionErrors.divideByZeroError(queryContext)
         }
         input1 match {
           case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
@@ -768,7 +776,7 @@ case class Pmod(
     }
     val remainder = ctx.freshName("remainder")
     val javaType = CodeGenerator.javaType(dataType)
-    lazy val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+    lazy val errorContext = ctx.addReferenceObj("errCtx", queryContext)
     val result = dataType match {
       case DecimalType.Fixed(_, _) =>
         val decimalAdd = "$plus"
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 87777991cb9..e76ff0b4390 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -359,19 +359,18 @@ class ArithmeticExpressionSuite extends SparkFunSuite 
with ExpressionEvalHelper
   }
 
   test("Remainder/Pmod: exception should contain SQL text context") {
-    Seq(
-      Remainder(Literal(1L, LongType), Literal(0L, LongType), failOnError = 
true),
-      Pmod(Literal(1L, LongType), Literal(0L, LongType), failOnError = 
true)).foreach { expr =>
-        val query = s"1L ${expr.symbol} 0L"
-        val o = Origin(
-          line = Some(1),
-          startPosition = Some(7),
-          startIndex = Some(7),
-          stopIndex = Some(7 + query.length -1),
-          sqlText = Some(s"select $query"))
-        withOrigin(o) {
-          checkExceptionInExpression[ArithmeticException](expr, EmptyRow, 
query)
-        }
+    Seq(("%", Remainder), ("pmod", Pmod)).foreach { case (symbol, exprBuilder) 
=>
+      val query = s"1L $symbol 0L"
+      val o = Origin(
+        line = Some(1),
+        startPosition = Some(7),
+        startIndex = Some(7),
+        stopIndex = Some(7 + query.length -1),
+        sqlText = Some(s"select $query"))
+      withOrigin(o) {
+        val expression = exprBuilder(Literal(1L, LongType), Literal(0L, 
LongType), true)
+        checkExceptionInExpression[ArithmeticException](expression, EmptyRow, 
query)
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cba677da19a..f099d3c015c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4359,6 +4359,26 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
     }
   }
 
+  test("SPARK-39166: Query context should be serialized to executors when WSCG 
is off") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+      SQLConf.ANSI_ENABLED.key -> "true") {
+      withTable("t") {
+        sql("create table t(i int, j int) using parquet")
+        sql("insert into t values(2147483647, 10)")
+        Seq(
+          "select i + j from t",
+          "select -i - j from t",
+          "select i * j from t",
+          "select i / (j - 10) from t").foreach { query =>
+          val msg = intercept[SparkException] {
+            sql(query).collect()
+          }.getMessage
+          assert(msg.contains(query))
+        }
+      }
+    }
+  }
+
   test("SPARK-38589: try_avg should return null if overflow happens before 
merging") {
     val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
       .map(Period.ofMonths)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to