This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 63f20c526be [SPARK-39166][SQL] Provide runtime error query context for
binary arithmetic when WSCG is off
63f20c526be is described below
commit 63f20c526bed8346fe3399aff6c0b2f7a78b441e
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri May 13 13:50:27 2022 +0800
[SPARK-39166][SQL] Provide runtime error query context for binary
arithmetic when WSCG is off
### What changes were proposed in this pull request?
Currently, in most cases, the project
https://issues.apache.org/jira/browse/SPARK-38615 is able to show where the
runtime errors happen within the original query.
However, after trying it in production, I found that the following queries
won't show where the divide-by-0 error happens:
```
create table aggTest(i int, j int, k int, d date) using parquet
insert into aggTest values(1, 2, 0, date'2022-01-01')
select sum(j)/sum(k),percentile(i, 0.9) from aggTest group by d
```
With `percentile` function in the query, the plan can't execute with whole
stage codegen. Thus the child plan of `Project` is serialized to executors for
execution, from ProjectExec:
```
protected override def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
val project = UnsafeProjection.create(projectList, child.output)
project.initialize(index)
iter.map(project)
}
}
```
Note that the `TreeNode.origin` is not serialized to executors since
`TreeNode` doesn't extend the trait `Serializable`, which results in an empty
query context on errors. For more details, please read
https://issues.apache.org/jira/browse/SPARK-39140
A naive fix is to make `TreeNode` extend the trait `Serializable`. However,
it can cause a performance regression if the query text is long (every `TreeNode`
carries it for serialization).
A better fix is to introduce a new trait `SupportQueryContext` and
materialize the truncated query context for special expressions. This PR
targets binary arithmetic expressions only. I will create follow-ups for the
remaining expressions which support runtime error query context.
### Why are the changes needed?
Improve the error context framework and make sure it works when whole stage
codegen is not available.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests
Closes #36525 from gengliangwang/serializeContext.
Lead-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit e336567c8a9704b500efecd276abaf5bd3988679)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../sql/catalyst/expressions/Expression.scala | 22 +++++++++++++
.../sql/catalyst/expressions/arithmetic.scala | 38 +++++++++++++---------
.../expressions/ArithmeticExpressionSuite.scala | 25 +++++++-------
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 20 ++++++++++++
4 files changed, 77 insertions(+), 28 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index e1d8c2e43e2..d620c5d7392 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -588,6 +588,28 @@ abstract class UnaryExpression extends Expression with
UnaryLike[Expression] {
}
}
+/**
+ * An expression with SQL query context. The context string can be serialized
from the Driver
+ * to executors. It will also be kept after rule transforms.
+ */
+trait SupportQueryContext extends Expression with Serializable {
+ protected var queryContext: String = initQueryContext()
+
+ def initQueryContext(): String
+
+ // Note: Even though query contexts are serialized to executors, it will be
regenerated from an
+ // empty "Origin" during rule transforms since "Origin"s are not
serialized to executors
+ // for better performance. Thus, we need to copy the original query
context during
+ // transforms. The query context string is considered as a "tag" on
the expression here.
+ override def copyTagsFrom(other: Expression): Unit = {
+ other match {
+ case s: SupportQueryContext =>
+ queryContext = s.queryContext
+ case _ =>
+ }
+ super.copyTagsFrom(other)
+ }
+}
object UnaryExpression {
def unapply(e: UnaryExpression): Option[Expression] = Some(e.child)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index c6d66d8e607..153187f9e30 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC,
TreePattern,
- UNARY_POSITIVE}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC,
TreePattern, UNARY_POSITIVE}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, MathUtils, TypeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -209,7 +208,8 @@ case class Abs(child: Expression, failOnError: Boolean =
SQLConf.get.ansiEnabled
override protected def withNewChildInternal(newChild: Expression): Abs =
copy(child = newChild)
}
-abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
+abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant
+ with SupportQueryContext {
protected val failOnError: Boolean
@@ -219,6 +219,14 @@ abstract class BinaryArithmetic extends BinaryOperator
with NullIntolerant {
override lazy val resolved: Boolean = childrenResolved &&
checkInputDataTypes().isSuccess
+ override def initQueryContext(): String = {
+ if (failOnError) {
+ origin.context
+ } else {
+ ""
+ }
+ }
+
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
throw
QueryExecutionErrors.notOverrideExpectedMethodsError("BinaryArithmetics",
@@ -270,7 +278,7 @@ abstract class BinaryArithmetic extends BinaryOperator with
NullIntolerant {
})
case IntegerType | LongType if failOnError && exactMathMethod.isDefined =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
- val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+ val errorContext = ctx.addReferenceObj("errCtx", queryContext)
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
s"""
|${ev.value} = $mathUtils.${exactMathMethod.get}($eval1, $eval2,
$errorContext);
@@ -331,9 +339,9 @@ case class Add(
case _: YearMonthIntervalType =>
MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
case _: IntegerType if failOnError =>
- MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int],
origin.context)
+ MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int],
queryContext)
case _: LongType if failOnError =>
- MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long],
origin.context)
+ MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long],
queryContext)
case _ => numeric.plus(input1, input2)
}
@@ -381,9 +389,9 @@ case class Subtract(
case _: YearMonthIntervalType =>
MathUtils.subtractExact(input1.asInstanceOf[Int],
input2.asInstanceOf[Int])
case _: IntegerType if failOnError =>
- MathUtils.subtractExact(input1.asInstanceOf[Int],
input2.asInstanceOf[Int], origin.context)
+ MathUtils.subtractExact(input1.asInstanceOf[Int],
input2.asInstanceOf[Int], queryContext)
case _: LongType if failOnError =>
- MathUtils.subtractExact(input1.asInstanceOf[Long],
input2.asInstanceOf[Long], origin.context)
+ MathUtils.subtractExact(input1.asInstanceOf[Long],
input2.asInstanceOf[Long], queryContext)
case _ => numeric.minus(input1, input2)
}
@@ -418,9 +426,9 @@ case class Multiply(
protected override def nullSafeEval(input1: Any, input2: Any): Any =
dataType match {
case _: IntegerType if failOnError =>
- MathUtils.multiplyExact(input1.asInstanceOf[Int],
input2.asInstanceOf[Int], origin.context)
+ MathUtils.multiplyExact(input1.asInstanceOf[Int],
input2.asInstanceOf[Int], queryContext)
case _: LongType if failOnError =>
- MathUtils.multiplyExact(input1.asInstanceOf[Long],
input2.asInstanceOf[Long], origin.context)
+ MathUtils.multiplyExact(input1.asInstanceOf[Long],
input2.asInstanceOf[Long], queryContext)
case _ => numeric.times(input1, input2)
}
@@ -457,10 +465,10 @@ trait DivModLike extends BinaryArithmetic {
} else {
if (isZero(input2)) {
// when we reach here, failOnError must be true.
- throw QueryExecutionErrors.divideByZeroError(origin.context)
+ throw QueryExecutionErrors.divideByZeroError(queryContext)
}
if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
- throw
QueryExecutionErrors.overflowInIntegralDivideError(origin.context)
+ throw
QueryExecutionErrors.overflowInIntegralDivideError(queryContext)
}
evalOperation(input1, input2)
}
@@ -487,7 +495,7 @@ trait DivModLike extends BinaryArithmetic {
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
- lazy val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+ lazy val errorContext = ctx.addReferenceObj("errCtx", queryContext)
val checkIntegralDivideOverflow = if (checkDivideOverflow) {
s"""
|if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1)
@@ -743,7 +751,7 @@ case class Pmod(
} else {
if (isZero(input2)) {
// when we reach here, failOnError must be true.
- throw QueryExecutionErrors.divideByZeroError(origin.context)
+ throw QueryExecutionErrors.divideByZeroError(queryContext)
}
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
@@ -768,7 +776,7 @@ case class Pmod(
}
val remainder = ctx.freshName("remainder")
val javaType = CodeGenerator.javaType(dataType)
- lazy val errorContext = ctx.addReferenceObj("errCtx", origin.context)
+ lazy val errorContext = ctx.addReferenceObj("errCtx", queryContext)
val result = dataType match {
case DecimalType.Fixed(_, _) =>
val decimalAdd = "$plus"
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 87777991cb9..e76ff0b4390 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -359,19 +359,18 @@ class ArithmeticExpressionSuite extends SparkFunSuite
with ExpressionEvalHelper
}
test("Remainder/Pmod: exception should contain SQL text context") {
- Seq(
- Remainder(Literal(1L, LongType), Literal(0L, LongType), failOnError =
true),
- Pmod(Literal(1L, LongType), Literal(0L, LongType), failOnError =
true)).foreach { expr =>
- val query = s"1L ${expr.symbol} 0L"
- val o = Origin(
- line = Some(1),
- startPosition = Some(7),
- startIndex = Some(7),
- stopIndex = Some(7 + query.length -1),
- sqlText = Some(s"select $query"))
- withOrigin(o) {
- checkExceptionInExpression[ArithmeticException](expr, EmptyRow,
query)
- }
+ Seq(("%", Remainder), ("pmod", Pmod)).foreach { case (symbol, exprBuilder)
=>
+ val query = s"1L $symbol 0L"
+ val o = Origin(
+ line = Some(1),
+ startPosition = Some(7),
+ startIndex = Some(7),
+ stopIndex = Some(7 + query.length -1),
+ sqlText = Some(s"select $query"))
+ withOrigin(o) {
+ val expression = exprBuilder(Literal(1L, LongType), Literal(0L,
LongType), true)
+ checkExceptionInExpression[ArithmeticException](expression, EmptyRow,
query)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cba677da19a..f099d3c015c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4359,6 +4359,26 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
}
}
+ test("SPARK-39166: Query context should be serialized to executors when WSCG
is off") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+ sql("create table t(i int, j int) using parquet")
+ sql("insert into t values(2147483647, 10)")
+ Seq(
+ "select i + j from t",
+ "select -i - j from t",
+ "select i * j from t",
+ "select i / (j - 10) from t").foreach { query =>
+ val msg = intercept[SparkException] {
+ sql(query).collect()
+ }.getMessage
+ assert(msg.contains(query))
+ }
+ }
+ }
+ }
+
test("SPARK-38589: try_avg should return null if overflow happens before
merging") {
val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
.map(Period.ofMonths)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]