This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch branch-3.3 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 4fb7fe2a406 [SPARK-39208][SQL] Fix query context bugs in decimal overflow under codegen mode
4fb7fe2a406 is described below

commit 4fb7fe2a40623526ed22311eac16c937450031e5
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Tue May 17 22:31:30 2022 +0800

    [SPARK-39208][SQL] Fix query context bugs in decimal overflow under codegen mode

    ### What changes were proposed in this pull request?

    1. Fix logical bugs in adding query contexts as references under codegen mode.
    https://github.com/apache/spark/pull/36040/files#diff-4a70d2f3a4b99f58796b87192143f9838f4c4cf469f3313eb30af79c4e07153aR145
    The code
    ```
        val errorContextCode = if (nullOnOverflow) {
          ctx.addReferenceObj("errCtx", queryContext)
        } else {
          "\"\""
        }
    ```
    should be
    ```
        val errorContextCode = if (nullOnOverflow) {
          "\"\""
        } else {
          ctx.addReferenceObj("errCtx", queryContext)
        }
    ```
    2. Similar to https://github.com/apache/spark/pull/36557, make `CheckOverflowInSum` support query context when WSCG is not available.

    ### Why are the changes needed?

    Bugfix and enhancement in the query context of decimal expressions.

    ### Does this PR introduce _any_ user-facing change?

    No, the query context is not released yet.

    ### How was this patch tested?

    New UT

    Closes #36577 from gengliangwang/fixDecimalSumOverflow.
Authored-by: Gengliang Wang <gengli...@apache.org> Signed-off-by: Gengliang Wang <gengli...@apache.org> (cherry picked from commit 191e535b975e5813719d3143797c9fcf86321368) Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../sql/catalyst/expressions/aggregate/Sum.scala | 21 ++++++++++++++------- .../catalyst/expressions/decimalExpressions.scala | 15 ++++++++------- .../expressions/DecimalExpressionSuite.scala | 19 +++++++++++++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 19 +++++++++++-------- 4 files changed, 52 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index f2c6925b837..fa43565d807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -143,10 +143,11 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate * So now, if ansi is enabled, then throw exception, if not then return null. * If sum is not null, then return the sum. 
*/ - protected def getEvaluateExpression: Expression = resultType match { + protected def getEvaluateExpression(queryContext: String): Expression = resultType match { case d: DecimalType => - If(isEmpty, Literal.create(null, resultType), - CheckOverflowInSum(sum, d, !useAnsiAdd)) + val checkOverflowInSum = + CheckOverflowInSum(sum, d, !useAnsiAdd, queryContext) + If(isEmpty, Literal.create(null, resultType), checkOverflowInSum) case _ if shouldTrackIsEmpty => If(isEmpty, Literal.create(null, resultType), sum) case _ => sum @@ -172,7 +173,7 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate case class Sum( child: Expression, useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) - extends SumBase(child) { + extends SumBase(child) with SupportQueryContext { def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled) override def shouldTrackIsEmpty: Boolean = resultType match { @@ -186,7 +187,13 @@ case class Sum( override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions - override lazy val evaluateExpression: Expression = getEvaluateExpression + override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext) + + override def initQueryContext(): String = if (useAnsiAdd) { + origin.context + } else { + "" + } } // scalastyle:off line.size.limit @@ -243,9 +250,9 @@ case class TrySum(child: Expression) extends SumBase(child) { override lazy val evaluateExpression: Expression = if (useAnsiAdd) { - TryEval(getEvaluateExpression) + TryEval(getEvaluateExpression("")) } else { - getEvaluateExpression + getEvaluateExpression("") } override protected def withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 2cdd784ea4d..7d25df5ae9c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -142,9 +142,9 @@ case class CheckOverflow( override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val errorContextCode = if (nullOnOverflow) { - ctx.addReferenceObj("errCtx", queryContext) - } else { "\"\"" + } else { + ctx.addReferenceObj("errCtx", queryContext) } nullSafeCodeGen(ctx, ev, eval => { // scalastyle:off line.size.limit @@ -175,7 +175,8 @@ case class CheckOverflow( case class CheckOverflowInSum( child: Expression, dataType: DecimalType, - nullOnOverflow: Boolean) extends UnaryExpression { + nullOnOverflow: Boolean, + queryContext: String = "") extends UnaryExpression { override def nullable: Boolean = true @@ -183,23 +184,23 @@ case class CheckOverflowInSum( val value = child.eval(input) if (value == null) { if (nullOnOverflow) null - else throw QueryExecutionErrors.overflowInSumOfDecimalError(origin.context) + else throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext) } else { value.asInstanceOf[Decimal].toPrecision( dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow, - origin.context) + queryContext) } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) val errorContextCode = if (nullOnOverflow) { - ctx.addReferenceObj("errCtx", origin.context) - } else { "\"\"" + } else { + ctx.addReferenceObj("errCtx", queryContext) } val nullHandling = if (nullOnOverflow) { "" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index 36bc3db5804..1a8cd63aed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{Decimal, DecimalType, LongType} @@ -83,4 +85,21 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CheckOverflow( Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), false), null) } + + test("SPARK-39208: CheckOverflow & CheckOverflowInSum support query context in runtime errors") { + val d = Decimal(101, 3, 1) + val query = "select cast(d as decimal(4, 3)) from t" + val origin = Origin( + startIndex = Some(7), + stopIndex = Some(30), + sqlText = Some(query)) + + val expr1 = withOrigin(origin) { + CheckOverflow(Literal(d), DecimalType(4, 3), false) + } + checkExceptionInExpression[ArithmeticException](expr1, query) + + val expr2 = CheckOverflowInSum(Literal(d), DecimalType(4, 3), false, queryContext = query) + checkExceptionInExpression[ArithmeticException](expr2, query) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f6998fe5c1c..422ba7c2a9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4423,18 +4423,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - test("SPARK-39190: Query context of decimal overflow error should be serialized to executors" + - " when WSCG is off") { + test("SPARK-39190, SPARK-39208: Query context of decimal overflow error should be serialized " + + "to executors when WSCG is off") { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> 
"false", SQLConf.ANSI_ENABLED.key -> "true") { withTable("t") { sql("create table t(d decimal(38, 0)) using parquet") - sql("insert into t values (2e37BD)") - val query = "select d / 0.1 from t" - val msg = intercept[SparkException] { - sql(query).collect() - }.getMessage - assert(msg.contains(query)) + sql("insert into t values (6e37BD),(6e37BD)") + Seq( + "select d / 0.1 from t", + "select sum(d) from t").foreach { query => + val msg = intercept[SparkException] { + sql(query).collect() + }.getMessage + assert(msg.contains(query)) + } } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org