This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new e9eb28e27d1 [SPARK-39319][CORE][SQL] Make query contexts as a part of `SparkThrowable`

e9eb28e27d1 is described below

commit e9eb28e27d10497c8b36774609823f4bbd2c8500
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Tue Jul 26 20:17:09 2022 +0500

[SPARK-39319][CORE][SQL] Make query contexts as a part of `SparkThrowable`

### What changes were proposed in this pull request?
In this PR, I propose to add a new interface `QueryContext` to Spark core and to allow getting an instance of `QueryContext` from Spark's exceptions of the type `SparkThrowable`. `QueryContext` should help users figure out where an error occurred while executing queries in Spark SQL. This PR also adds `SQLQueryContext`, an implementation of `QueryContext`, to Spark SQL's `Origin`, which carries the context of TreeNodes plus a textual summary of the error. The `context` value in `Origin` will hold all the structural info needed to link an error to the fragment of the SQL query it came from. All of Spark's exceptions are modified to accept an optional `QueryContext` and a pre-built text summary. Accordingly, SQL expressions initialize the new context and pass it to exceptions.

Closes #36702

### Why are the changes needed?
In the future, this will enrich the information carried by error messages. With this change, it becomes possible to pretty-print error messages in a new format, for example:
```sql
> SELECT * FROM v1;
{
  "errorClass" : [ "DIVIDE_BY_ZERO" ],
  "parameters" : [ { "name" : "config", "value" : "spark.sql.ansi.enabled" } ],
  "sqlState" : "42000",
  "context" : {
    "objectType" : "VIEW",
    "objectName" : "default.v1",
    "indexStart" : 36,
    "indexEnd" : 41,
    "fragment" : "1 / 0"
  }
}
```

### Does this PR introduce _any_ user-facing change?
Yes. The PR changes Spark's exceptions by replacing the type of `queryContext` from `String` to `Option[QueryContext]`. Users' code can fail if it uses `queryContext`.

### How was this patch tested?
By running the modified test suites:
```
$ build/sbt "test:testOnly *DecimalExpressionSuite"
$ build/sbt "test:testOnly *TreeNodeSuite"
```
and the affected test suites:
```
$ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite"
```

Authored-by: Max Gekk <max.gekk@gmail.com>
Co-authored-by: Gengliang Wang <gengliang@apache.org>

Closes #37209 from MaxGekk/query-context-in-sparkthrowable.
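
To illustrate how the new API surface is meant to be consumed, here is a minimal, hypothetical sketch (not part of this commit) of application code that catches a failing ANSI query and reads the context exposed through `SparkThrowable.getQueryContext()`. Only `getQueryContext()` and the `QueryContext` accessors (`objectType`, `objectName`, `startIndex`, `stopIndex`, `fragment`) come from this patch; the object name `QueryContextDemo`, the `local[1]` master, and the demo query are assumptions for the example, and exceptions that do not attach a context simply return an empty array.

```scala
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.SparkSession

// Hypothetical demo, assuming a Spark build that includes this patch.
object QueryContextDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // A query context is only attached when the query actually throws,
    // e.g. division by zero under ANSI mode.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    try {
      spark.sql("SELECT 1 / 0").collect()
    } catch {
      case e: Throwable =>
        e match {
          // Error-class based Spark exceptions implement SparkThrowable.
          case st: SparkThrowable =>
            st.getQueryContext.foreach { ctx =>
              println(s"objectType = ${ctx.objectType()}")
              println(s"objectName = ${ctx.objectName()}")
              println(s"startIndex = ${ctx.startIndex()}")
              println(s"stopIndex  = ${ctx.stopIndex()}")
              println(s"fragment   = ${ctx.fragment()}")
            }
          case _ => // not a Spark error-class exception
        }
    } finally {
      spark.stop()
    }
  }
}
```

Per the interface contract below, a failure raised directly by the main query would report empty `objectType`/`objectName` and the offending fragment (here `1 / 0`), while an error raised inside a view would carry `VIEW` and the view's name, matching the JSON example above.
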
Authored-by: Max Gekk <max.g...@gmail.com>
Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../main/java/org/apache/spark/QueryContext.java   |  48 ++++++++
 .../main/java/org/apache/spark/SparkThrowable.java |   2 +
 .../apache/spark/memory/SparkOutOfMemoryError.java |   3 +-
 .../main/scala/org/apache/spark/ErrorInfo.scala    |  16 ++-
 .../scala/org/apache/spark/SparkException.scala    |  58 +++++----
 .../spark/sql/catalyst/expressions/Cast.scala      |  20 ++--
 .../sql/catalyst/expressions/Expression.scala      |   6 +-
 .../catalyst/expressions/aggregate/Average.scala   |  14 +--
 .../sql/catalyst/expressions/aggregate/Sum.scala   |  18 +--
 .../sql/catalyst/expressions/arithmetic.scala      |  27 ++---
 .../expressions/collectionOperations.scala         |   8 +-
 .../expressions/complexTypeExtractors.scala        |  13 ++-
 .../catalyst/expressions/decimalExpressions.scala  |  25 ++--
 .../catalyst/expressions/intervalExpressions.scala |  22 ++--
 .../sql/catalyst/expressions/mathExpressions.scala |   2 +-
 .../catalyst/expressions/stringExpressions.scala   |  10 +-
 .../spark/sql/catalyst/trees/SQLQueryContext.scala | 130 +++++++++++++++++++++
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  99 ++--------------
 .../spark/sql/catalyst/util/DateTimeUtils.scala    |  22 ++--
 .../spark/sql/catalyst/util/IntervalUtils.scala    |   2 +-
 .../apache/spark/sql/catalyst/util/MathUtils.scala |  38 +++---
 .../spark/sql/catalyst/util/UTF8StringUtils.scala  |  25 ++--
 .../apache/spark/sql/errors/QueryErrorsBase.scala  |   5 +
 .../spark/sql/errors/QueryExecutionErrors.scala    |  74 +++++-----
 .../scala/org/apache/spark/sql/types/Decimal.scala |   7 +-
 .../expressions/DecimalExpressionSuite.scala       |   4 +-
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   |  15 ++-
 27 files changed, 445 insertions(+), 268 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/QueryContext.java b/core/src/main/java/org/apache/spark/QueryContext.java
new file mode 100644
index 00000000000..de5b29d0295
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/QueryContext.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Query context of a {@link SparkThrowable}. It helps users understand where error occur
+ * while executing queries.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public interface QueryContext {
+  // The object type of the query which throws the exception.
+  // If the exception is directly from the main query, it should be an empty string.
+  // Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
+  String objectType();
+
+  // The object name of the query which throws the exception.
+  // If the exception is directly from the main query, it should be an empty string.
+  // Otherwise, it should be the object name.
For example, a view name "V1". + String objectName(); + + // The starting index in the query text which throws the exception. The index starts from 0. + int startIndex(); + + // The stopping index in the query which throws the exception. The index starts from 0. + int stopIndex(); + + // The corresponding fragment of the query which throws the exception. + String fragment(); +} diff --git a/core/src/main/java/org/apache/spark/SparkThrowable.java b/core/src/main/java/org/apache/spark/SparkThrowable.java index 581e1f6eebb..52fd64135a9 100644 --- a/core/src/main/java/org/apache/spark/SparkThrowable.java +++ b/core/src/main/java/org/apache/spark/SparkThrowable.java @@ -59,4 +59,6 @@ public interface SparkThrowable { default String[] getParameterNames() { return SparkThrowableHelper.getParameterNames(this.getErrorClass(), this.getErrorSubClass()); } + + default QueryContext[] getQueryContext() { return new QueryContext[0]; } } diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java index 9d2739018a0..3ff3094456e 100644 --- a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -39,8 +39,7 @@ public final class SparkOutOfMemoryError extends OutOfMemoryError implements Spa } public SparkOutOfMemoryError(String errorClass, String[] messageParameters) { - super(SparkThrowableHelper.getMessage(errorClass, null, - messageParameters, "")); + super(SparkThrowableHelper.getMessage(errorClass, null, messageParameters)); this.errorClass = errorClass; this.messageParameters = messageParameters; } diff --git a/core/src/main/scala/org/apache/spark/ErrorInfo.scala b/core/src/main/scala/org/apache/spark/ErrorInfo.scala index d98f5296fee..6c72a27aa4b 100644 --- a/core/src/main/scala/org/apache/spark/ErrorInfo.scala +++ b/core/src/main/scala/org/apache/spark/ErrorInfo.scala @@ -71,11 +71,18 @@ private[spark] object SparkThrowableHelper { mapper.readValue(errorClassesUrl, new TypeReference[SortedMap[String, ErrorInfo]]() {}) } + def getMessage( + errorClass: String, + errorSubClass: String, + messageParameters: Array[String]): String = { + getMessage(errorClass, errorSubClass, messageParameters, "") + } + def getMessage( errorClass: String, errorSubClass: String, messageParameters: Array[String], - queryContext: String = ""): String = { + context: String): String = { val errorInfo = errorClassToInfoMap.getOrElse(errorClass, throw new IllegalArgumentException(s"Cannot find error class '$errorClass'")) val (displayClass, displayMessageParameters, displayFormat) = if (errorInfo.subClass.isEmpty) { @@ -93,11 +100,8 @@ private[spark] object SparkThrowableHelper { val displayMessage = String.format( displayFormat.replaceAll("<[a-zA-Z0-9_-]+>", "%s"), displayMessageParameters : _*) - val displayQueryContext = if (queryContext.isEmpty) { - "" - } else { - s"\n$queryContext" - } + val displayQueryContext = (if (context.isEmpty) "" else "\n") + context + s"[$displayClass] $displayMessage$displayQueryContext" } diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 739d2aab23b..d6add48ffb1 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -119,15 +119,17 @@ private[spark] class SparkArithmeticException( errorClass: String, errorSubClass: Option[String] = None, 
messageParameters: Array[String], - queryContext: String = "") + context: Option[QueryContext], + summary: String) extends ArithmeticException( - SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, - messageParameters, queryContext)) + SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass - override def getErrorSubClass: String = errorSubClass.orNull} + override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context.toArray +} /** * Unsupported operation exception thrown from Spark with an error class. @@ -193,15 +195,17 @@ private[spark] class SparkDateTimeException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - queryContext: String = "") + context: Option[QueryContext], + summary: String) extends DateTimeException( - SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, - messageParameters, queryContext)) + SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass - override def getErrorSubClass: String = errorSubClass.orNull} + override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context.toArray +} /** * Hadoop file already exists exception thrown from Spark with an error class. @@ -240,15 +244,17 @@ private[spark] class SparkNumberFormatException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - queryContext: String) + context: Option[QueryContext], + summary: String) extends NumberFormatException( - SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, - messageParameters, queryContext)) + SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass - override def getErrorSubClass: String = errorSubClass.orNull} + override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context.toArray +} /** * No such method exception thrown from Spark with an error class. 
@@ -317,10 +323,10 @@ private[spark] class SparkRuntimeException( errorSubClass: Option[String] = None, messageParameters: Array[String], cause: Throwable = null, - queryContext: String = "") + context: Option[QueryContext] = None, + summary: String = "") extends RuntimeException( - SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, - messageParameters, queryContext), + SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary), cause) with SparkThrowable { @@ -328,12 +334,12 @@ private[spark] class SparkRuntimeException( errorSubClass: String, messageParameters: Array[String], cause: Throwable, - queryContext: String) + context: Option[QueryContext]) = this(errorClass = errorClass, errorSubClass = Some(errorSubClass), messageParameters = messageParameters, cause = cause, - queryContext = queryContext) + context = context) def this(errorClass: String, errorSubClass: String, @@ -342,11 +348,12 @@ private[spark] class SparkRuntimeException( errorSubClass = Some(errorSubClass), messageParameters = messageParameters, cause = null, - queryContext = "") + context = None) override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context.toArray } /** @@ -372,16 +379,16 @@ private[spark] class SparkArrayIndexOutOfBoundsException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - queryContext: String) + context: Option[QueryContext], + summary: String) extends ArrayIndexOutOfBoundsException( - // scalastyle:off line.size.limit - SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, queryContext)) - // scalastyle:on line.size.limit + SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context.toArray } /** @@ -413,15 +420,16 @@ private[spark] class SparkNoSuchElementException( errorClass: String, errorSubClass: Option[String] = None, messageParameters: Array[String], - queryContext: String) + context: Option[QueryContext], + summary: String) extends NoSuchElementException( - SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, - messageParameters, queryContext)) + SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { override def getMessageParameters: Array[String] = messageParameters override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull + override def getQueryContext: Array[QueryContext] = context.toArray } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8afd3c13461..82de2a0de14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import 
org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -481,10 +481,10 @@ case class Cast( override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType) - override def initQueryContext(): String = if (ansiEnabled) { - origin.context + override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + Some(origin.context) } else { - "" + None } // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; @@ -995,9 +995,12 @@ case class Cast( * If overflow occurs, if `spark.sql.ansi.enabled` is false, null is returned; * otherwise, an `ArithmeticException` is thrown. */ - private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = + private[this] def toPrecision( + value: Decimal, + decimalType: DecimalType, + context: Option[SQLQueryContext]): Decimal = value.toPrecision( - decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled) + decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context) private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { @@ -1010,14 +1013,15 @@ case class Cast( buildCast[UTF8String](_, s => changePrecision(Decimal.fromStringANSI(s, target, queryContext), target)) case BooleanType => - buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) + buildCast[Boolean](_, + b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target, queryContext)) case DateType => buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. 
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case dt: DecimalType => - b => toPrecision(b.asInstanceOf[Decimal], target) + b => toPrecision(b.asInstanceOf[Decimal], target, queryContext) case t: IntegralType => b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) case x: FractionalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d620c5d7392..d623357b9da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.QueryExecutionErrors @@ -593,9 +593,9 @@ abstract class UnaryExpression extends Expression with UnaryLike[Expression] { * to executors. It will also be kept after rule transforms. */ trait SupportQueryContext extends Expression with Serializable { - protected var queryContext: String = initQueryContext() + protected var queryContext: Option[SQLQueryContext] = initQueryContext() - def initQueryContext(): String + def initQueryContext(): Option[SQLQueryContext] // Note: Even though query contexts are serialized to executors, it will be regenerated from an // empty "Origin" during rule transforms since "Origin"s are not serialized to executors diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index e64f76bdb0a..b749dfdaea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern} -import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -86,7 +86,7 @@ abstract class AverageBase // If all input are nulls, count will be 0 and we will get null after the division. // We can't directly use `/` as it throws an exception under ansi mode. 
- protected def getEvaluateExpression(queryContext: String) = child.dataType match { + protected def getEvaluateExpression(context: Option[SQLQueryContext]) = child.dataType match { case _: DecimalType => If(EqualTo(count, Literal(0L)), Literal(null, resultType), @@ -94,7 +94,7 @@ abstract class AverageBase sum, count.cast(DecimalType.LongDecimal), resultType.asInstanceOf[DecimalType], - queryContext, + context, !useAnsiAdd)) case _: YearMonthIntervalType => If(EqualTo(count, Literal(0L)), @@ -143,10 +143,10 @@ case class Average( override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext) - override def initQueryContext(): String = if (useAnsiAdd) { - origin.context + override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) { + Some(origin.context) } else { - "" + None } } @@ -206,7 +206,7 @@ case class TryAverage(child: Expression) extends AverageBase { } override lazy val evaluateExpression: Expression = { - addTryEvalIfNeeded(getEvaluateExpression("")) + addTryEvalIfNeeded(getEvaluateExpression(None)) } override protected def withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 1c4297d735b..9230bd9bf44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern} -import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -148,10 +148,10 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate * So now, if ansi is enabled, then throw exception, if not then return null. * If sum is not null, then return the sum. 
*/ - protected def getEvaluateExpression(queryContext: String): Expression = resultType match { + protected def getEvaluateExpression( + context: Option[SQLQueryContext]): Expression = resultType match { case d: DecimalType => - val checkOverflowInSum = - CheckOverflowInSum(sum, d, !useAnsiAdd, queryContext) + val checkOverflowInSum = CheckOverflowInSum(sum, d, !useAnsiAdd, context) If(isEmpty, Literal.create(null, resultType), checkOverflowInSum) case _ if shouldTrackIsEmpty => If(isEmpty, Literal.create(null, resultType), sum) @@ -194,10 +194,10 @@ case class Sum( override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext) - override def initQueryContext(): String = if (useAnsiAdd) { - origin.context + override def initQueryContext(): Option[SQLQueryContext] = if (useAnsiAdd) { + Some(origin.context) } else { - "" + None } } @@ -255,9 +255,9 @@ case class TrySum(child: Expression) extends SumBase(child) { override lazy val evaluateExpression: Expression = if (useAnsiAdd) { - TryEval(getEvaluateExpression("")) + TryEval(getEvaluateExpression(None)) } else { - getEvaluateExpression("") + getEvaluateExpression(None) } override protected def withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 97411c05d5f..7bbe5d15b91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, UNARY_POSITIVE} import org.apache.spark.sql.catalyst.util.{IntervalUtils, MathUtils, TypeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors @@ -255,11 +256,11 @@ abstract class BinaryArithmetic extends BinaryOperator override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess - override def initQueryContext(): String = { + override def initQueryContext(): Option[SQLQueryContext] = { if (failOnError) { - origin.context + Some(origin.context) } else { - "" + None } } @@ -287,7 +288,7 @@ abstract class BinaryArithmetic extends BinaryOperator val errorContextCode = if (failOnError) { ctx.addReferenceObj("errCtx", queryContext) } else { - "\"\"" + "scala.None$.MODULE$" } val updateIsNull = if (failOnError) { "" @@ -602,17 +603,17 @@ trait DivModLike extends BinaryArithmetic { s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val errorContext = if (failOnError) { + val errorContextCode = if (failOnError) { ctx.addReferenceObj("errCtx", queryContext) } else { - "\"\"" + "scala.None$.MODULE$" } val operation = super.dataType match { case DecimalType.Fixed(precision, scale) => val decimalValue = ctx.freshName("decimalValue") s""" |Decimal $decimalValue = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision( - | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContext); + | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode); |if 
($decimalValue != null) { | ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")}; |} else { @@ -624,7 +625,7 @@ trait DivModLike extends BinaryArithmetic { val checkIntegralDivideOverflow = if (checkDivideOverflow) { s""" |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1) - | throw QueryExecutionErrors.overflowInIntegralDivideError($errorContext); + | throw QueryExecutionErrors.overflowInIntegralDivideError($errorContextCode); |""".stripMargin } else { "" @@ -633,7 +634,7 @@ trait DivModLike extends BinaryArithmetic { // evaluate right first as we have a chance to skip left if right is 0 if (!left.nullable && !right.nullable) { val divByZero = if (failOnError) { - s"throw QueryExecutionErrors.divideByZeroError($errorContext);" + s"throw QueryExecutionErrors.divideByZeroError($errorContextCode);" } else { s"${ev.isNull} = true;" } @@ -651,7 +652,7 @@ trait DivModLike extends BinaryArithmetic { } else { val nullOnErrorCondition = if (failOnError) "" else s" || $isZero" val failOnErrorBranch = if (failOnError) { - s"if ($isZero) throw QueryExecutionErrors.divideByZeroError($errorContext);" + s"if ($isZero) throw QueryExecutionErrors.divideByZeroError($errorContextCode);" } else { "" } @@ -978,11 +979,7 @@ case class Pmod( } val remainder = ctx.freshName("remainder") val javaType = CodeGenerator.javaType(dataType) - val errorContext = if (failOnError) { - ctx.addReferenceObj("errCtx", queryContext) - } else { - "\"\"" - } + val errorContext = ctx.addReferenceObj("errCtx", queryContext) val result = dataType match { case DecimalType.Fixed(precision, scale) => val decimalAdd = "$plus" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c74f8e1a685..79e6144c9f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -2266,10 +2266,10 @@ case class ElementAt( override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) - override def initQueryContext(): String = if (failOnError) { - origin.context + override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + Some(origin.context) } else { - "" + None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 198fd0cd1f2..fedfcfb978f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -314,10 +315,10 @@ case class GetArrayItem( newLeft: Expression, newRight: Expression): GetArrayItem = copy(child = newLeft, ordinal = newRight) - override def initQueryContext(): String = if (failOnError) { - origin.context + override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + Some(origin.context) } else { - "" + None } } @@ -503,9 +504,9 @@ case class GetMapValue( newLeft: Expression, newRight: Expression): GetMapValue = copy(child = newLeft, key = newRight) - override def initQueryContext(): String = if (failOnError) { - origin.context + override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + Some(origin.context) } else { - "" + None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 2dd60a9d9ca..e672fffda19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -126,7 +127,7 @@ case class CheckOverflow( override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val errorContextCode = if (nullOnOverflow) { - "\"\"" + "scala.None$.MODULE$" } else { ctx.addReferenceObj("errCtx", queryContext) } @@ -148,10 +149,10 @@ case class CheckOverflow( override protected def withNewChildInternal(newChild: Expression): CheckOverflow = copy(child = newChild) - override def initQueryContext(): String = if (nullOnOverflow) { - "" + override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) { + Some(origin.context) } else { - origin.context + None } } @@ -160,7 +161,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - queryContext: String = "") extends UnaryExpression { + context: Option[SQLQueryContext] = None) extends UnaryExpression { override def nullable: Boolean = true @@ -168,23 +169,23 @@ case class CheckOverflowInSum( val value = child.eval(input) if (value == null) { if (nullOnOverflow) null - else throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext) + else throw QueryExecutionErrors.overflowInSumOfDecimalError(context) } else { 
value.asInstanceOf[Decimal].toPrecision( dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow, - queryContext) + context) } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) val errorContextCode = if (nullOnOverflow) { - "\"\"" + "scala.None$.MODULE$" } else { - ctx.addReferenceObj("errCtx", queryContext) + ctx.addReferenceObj("errCtx", context) } val nullHandling = if (nullOnOverflow) { "" @@ -260,12 +261,12 @@ case class DecimalDivideWithOverflowCheck( left: Expression, right: Expression, override val dataType: DecimalType, - avgQueryContext: String, + context: Option[SQLQueryContext], nullOnOverflow: Boolean) extends BinaryExpression with ExpectsInputTypes with SupportQueryContext { override def nullable: Boolean = nullOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType) - override def initQueryContext(): String = avgQueryContext + override def initQueryContext(): Option[SQLQueryContext] = context def decimalMethod: String = "$div" override def eval(input: InternalRow): Any = { @@ -286,7 +287,7 @@ case class DecimalDivideWithOverflowCheck( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val errorContextCode = if (nullOnOverflow) { - "\"\"" + "scala.None$.MODULE$" } else { ctx.addReferenceObj("errCtx", queryContext) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 3b06f811546..0a275d0760f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -23,6 +23,7 @@ import java.util.Locale import com.google.common.math.{DoubleMath, IntMath, LongMath} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ @@ -603,7 +604,7 @@ trait IntervalDivide { minValue: Any, num: Expression, numValue: Any, - context: String): Unit = { + context: Option[SQLQueryContext]): Unit = { if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { throw QueryExecutionErrors.overflowInIntegralDivideError(context) @@ -611,7 +612,10 @@ trait IntervalDivide { } } - def divideByZeroCheck(dataType: DataType, num: Any, context: String): Unit = dataType match { + def divideByZeroCheck( + dataType: DataType, + num: Any, + context: Option[SQLQueryContext]): Unit = dataType match { case _: DecimalType => if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError(context) case _ => if (num == 0) throw QueryExecutionErrors.divideByZeroError(context) @@ -656,13 +660,14 @@ case class DivideYMInterval( } override def nullSafeEval(interval: Any, num: Any): Any = { - checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num, origin.context) - divideByZeroCheck(right.dataType, num, origin.context) + checkDivideOverflow( + interval.asInstanceOf[Int], Int.MinValue, right, num, Some(origin.context)) + divideByZeroCheck(right.dataType, num, Some(origin.context)) 
evalFunc(interval.asInstanceOf[Int], num) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errorContext = ctx.addReferenceObj("errCtx", origin.context) + val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context)) right.dataType match { case t: IntegralType => val math = t match { @@ -733,13 +738,14 @@ case class DivideDTInterval( } override def nullSafeEval(interval: Any, num: Any): Any = { - checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num, origin.context) - divideByZeroCheck(right.dataType, num, origin.context) + checkDivideOverflow( + interval.asInstanceOf[Long], Long.MinValue, right, num, Some(origin.context)) + divideByZeroCheck(right.dataType, num, Some(origin.context)) evalFunc(interval.asInstanceOf[Long], num) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errorContext = ctx.addReferenceObj("errCtx", origin.context) + val errorContext = ctx.addReferenceObj("errCtx", Some(origin.context)) right.dataType match { case _: IntegralType => val math = classOf[LongMath].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 228b2a974e2..55ff36e9863 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1520,7 +1520,7 @@ abstract class RoundBase(child: Expression, scale: Expression, if (_scale >= 0) { s""" ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, - Decimal.$modeStr(), true, ""); + Decimal.$modeStr(), true, scala.None$$.MODULE$$); ${ev.isNull} = ${ev.value} == null;""" } else { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index bc24a12f083..815eb8977b6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -28,12 +28,12 @@ import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegist import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext} import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{StringType, _} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -382,10 +382,10 @@ case class Elt( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt = copy(children = newChildren) - override def initQueryContext(): String = if (failOnError) { - origin.context + override def initQueryContext(): Option[SQLQueryContext] = 
if (failOnError) { + Some(origin.context) } else { - "" + None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala new file mode 100644 index 00000000000..8f75079fcf9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.trees + +import org.apache.spark.QueryContext + +/** The class represents error context of a SQL query. */ +case class SQLQueryContext( + line: Option[Int], + startPosition: Option[Int], + originStartIndex: Option[Int], + originStopIndex: Option[Int], + sqlText: Option[String], + originObjectType: Option[String], + originObjectName: Option[String]) extends QueryContext { + + override val objectType = originObjectType.getOrElse("") + override val objectName = originObjectName.getOrElse("") + override val startIndex = originStartIndex.getOrElse(-1) + override val stopIndex = originStopIndex.getOrElse(-1) + + /** + * The SQL query context of current node. For example: + * == SQL of VIEW v1(line 1, position 25) == + * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i + * ^^^^^^^^^^^^^^^ + */ + lazy val summary: String = { + // If the query context is missing or incorrect, simply return an empty string. + if (sqlText.isEmpty || originStartIndex.isEmpty || originStopIndex.isEmpty || + originStartIndex.get < 0 || originStopIndex.get >= sqlText.get.length || + originStartIndex.get > originStopIndex.get) { + "" + } else { + val positionContext = if (line.isDefined && startPosition.isDefined) { + // Note that the line number starts from 1, while the start position starts from 0. + // Here we increase the start position by 1 for consistency. + s"(line ${line.get}, position ${startPosition.get + 1})" + } else { + "" + } + val objectContext = if (originObjectType.isDefined && originObjectName.isDefined) { + s" of ${originObjectType.get} ${originObjectName.get}" + } else { + "" + } + val builder = new StringBuilder + builder ++= s"== SQL$objectContext$positionContext ==\n" + + val text = sqlText.get + val start = math.max(originStartIndex.get, 0) + val stop = math.min(originStopIndex.getOrElse(text.length - 1), text.length - 1) + // Ideally we should show all the lines which contains the SQL text context of the current + // node: + // [additional text] [current tree node] [additional text] + // However, we need to truncate the additional text in case it is too long. The following + // variable is to define the max length of additional text. + val maxExtraContextLength = 32 + val truncatedText = "..." 
+ var lineStartIndex = start + // Collect the SQL text within the starting line of current Node. + // The text is truncated if it is too long. + while (lineStartIndex >= 0 && + start - lineStartIndex <= maxExtraContextLength && + text.charAt(lineStartIndex) != '\n') { + lineStartIndex -= 1 + } + val startTruncated = start - lineStartIndex > maxExtraContextLength + var currentIndex = lineStartIndex + if (startTruncated) { + currentIndex -= truncatedText.length + } + + var lineStopIndex = stop + // Collect the SQL text within the ending line of current Node. + // The text is truncated if it is too long. + while (lineStopIndex < text.length && + lineStopIndex - stop <= maxExtraContextLength && + text.charAt(lineStopIndex) != '\n') { + lineStopIndex += 1 + } + val stopTruncated = lineStopIndex - stop > maxExtraContextLength + + val truncatedSubText = (if (startTruncated) truncatedText else "") + + text.substring(lineStartIndex + 1, lineStopIndex) + + (if (stopTruncated) truncatedText else "") + val lines = truncatedSubText.split("\n") + lines.foreach { lineText => + builder ++= lineText + "\n" + currentIndex += 1 + (0 until lineText.length).foreach { _ => + if (currentIndex < start) { + builder ++= " " + } else if (currentIndex >= start && currentIndex <= stop) { + builder ++= "^" + } + currentIndex += 1 + } + builder ++= "\n" + } + builder.result() + } + } + + /** Gets the textual fragment of a SQL query. */ + override lazy val fragment: String = { + if (sqlText.isEmpty || originStartIndex.isEmpty || originStopIndex.isEmpty || + originStartIndex.get < 0 || originStopIndex.get >= sqlText.get.length || + originStartIndex.get > originStopIndex.get) { + "" + } else { + sqlText.get.substring(originStartIndex.get, originStopIndex.get) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 71d8a0740bc..b8cfdcdbe7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -58,95 +58,16 @@ private class MutableInt(var i: Int) * objects which contain SQL text. */ case class Origin( - line: Option[Int] = None, - startPosition: Option[Int] = None, - startIndex: Option[Int] = None, - stopIndex: Option[Int] = None, - sqlText: Option[String] = None, - objectType: Option[String] = None, - objectName: Option[String] = None) { - - /** - * The SQL query context of current node. For example: - * == SQL of VIEW v1(line 1, position 25) == - * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i - * ^^^^^^^^^^^^^^^ - */ - lazy val context: String = { - // If the query context is missing or incorrect, simply return an empty string. - if (sqlText.isEmpty || startIndex.isEmpty || stopIndex.isEmpty || - startIndex.get < 0 || stopIndex.get >= sqlText.get.length || startIndex.get > stopIndex.get) { - "" - } else { - val positionContext = if (line.isDefined && startPosition.isDefined) { - // Note that the line number starts from 1, while the start position starts from 0. - // Here we increase the start position by 1 for consistency. 
- s"(line ${line.get}, position ${startPosition.get + 1})" - } else { - "" - } - val objectContext = if (objectType.isDefined && objectName.isDefined) { - s" of ${objectType.get} ${objectName.get}" - } else { - "" - } - val builder = new StringBuilder - builder ++= s"== SQL$objectContext$positionContext ==\n" - - val text = sqlText.get - val start = math.max(startIndex.get, 0) - val stop = math.min(stopIndex.getOrElse(text.length - 1), text.length - 1) - // Ideally we should show all the lines which contains the SQL text context of the current - // node: - // [additional text] [current tree node] [additional text] - // However, we need to truncate the additional text in case it is too long. The following - // variable is to define the max length of additional text. - val maxExtraContextLength = 32 - val truncatedText = "..." - var lineStartIndex = start - // Collect the SQL text within the starting line of current Node. - // The text is truncated if it is too long. - while (lineStartIndex >= 0 && - start - lineStartIndex <= maxExtraContextLength && - text.charAt(lineStartIndex) != '\n') { - lineStartIndex -= 1 - } - val startTruncated = start - lineStartIndex > maxExtraContextLength - var currentIndex = lineStartIndex - if (startTruncated) { - currentIndex -= truncatedText.length - } - - var lineStopIndex = stop - // Collect the SQL text within the ending line of current Node. - // The text is truncated if it is too long. - while (lineStopIndex < text.length && - lineStopIndex - stop <= maxExtraContextLength && - text.charAt(lineStopIndex) != '\n') { - lineStopIndex += 1 - } - val stopTruncated = lineStopIndex - stop > maxExtraContextLength - - val truncatedSubText = (if (startTruncated) truncatedText else "") + - text.substring(lineStartIndex + 1, lineStopIndex) + - (if (stopTruncated) truncatedText else "") - val lines = truncatedSubText.split("\n") - lines.foreach { lineText => - builder ++= lineText + "\n" - currentIndex += 1 - (0 until lineText.length).foreach { _ => - if (currentIndex < start) { - builder ++= " " - } else if (currentIndex >= start && currentIndex <= stop) { - builder ++= "^" - } - currentIndex += 1 - } - builder ++= "\n" - } - builder.result() - } - } + line: Option[Int] = None, + startPosition: Option[Int] = None, + startIndex: Option[Int] = None, + stopIndex: Option[Int] = None, + sqlText: Option[String] = None, + objectType: Option[String] = None, + objectName: Option[String] = None) { + + val context: SQLQueryContext = SQLQueryContext( + line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d206585ea53..172c2e54034 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import sun.util.calendar.ZoneInfo +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.RebaseDateTime._ import org.apache.spark.sql.errors.QueryExecutionErrors @@ -464,17 +465,20 @@ object DateTimeUtils { } } - def stringToTimestampAnsi(s: UTF8String, timeZoneId: ZoneId, errorContext: String = ""): Long = { + def stringToTimestampAnsi( + s: UTF8String, + timeZoneId: ZoneId, + context: 
Option[SQLQueryContext] = None): Long = { stringToTimestamp(s, timeZoneId).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( - s, StringType, TimestampType, errorContext) + s, StringType, TimestampType, context) } } - def doubleToTimestampAnsi(d: Double, errorContext: String): Long = { + def doubleToTimestampAnsi(d: Double, context: Option[SQLQueryContext]): Long = { if (d.isNaN || d.isInfinite) { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( - d, DoubleType, TimestampType, errorContext) + d, DoubleType, TimestampType, context) } else { DoubleExactNumeric.toLong(d * MICROS_PER_SECOND) } @@ -521,10 +525,12 @@ object DateTimeUtils { stringToTimestampWithoutTimeZone(s, true) } - def stringToTimestampWithoutTimeZoneAnsi(s: UTF8String, errorContext: String): Long = { + def stringToTimestampWithoutTimeZoneAnsi( + s: UTF8String, + context: Option[SQLQueryContext]): Long = { stringToTimestampWithoutTimeZone(s, true).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( - s, StringType, TimestampNTZType, errorContext) + s, StringType, TimestampNTZType, context) } } @@ -640,10 +646,10 @@ object DateTimeUtils { } } - def stringToDateAnsi(s: UTF8String, errorContext: String = ""): Int = { + def stringToDateAnsi(s: UTF8String, context: Option[SQLQueryContext] = None): Int = { stringToDate(s).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError( - s, StringType, DateType, errorContext) + s, StringType, DateType, context) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 721f50208ad..de486157cbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -733,7 +733,7 @@ object IntervalUtils { * @throws ArithmeticException if the result overflows any field value or divided by zero */ def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = { - if (num == 0) throw QueryExecutionErrors.divideByZeroError("") + if (num == 0) throw QueryExecutionErrors.divideByZeroError(None) fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index e5c87a41ea8..6cb3616d4e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.QueryExecutionErrors /** @@ -26,33 +27,39 @@ object MathUtils { def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b)) - def addExact(a: Int, b: Int, errorContext: String): Int = - withOverflow(Math.addExact(a, b), hint = "try_add", errorContext = errorContext) + def addExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = { + withOverflow(Math.addExact(a, b), hint = "try_add", context) + } def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b)) - def addExact(a: Long, b: Long, errorContext: String): Long = - withOverflow(Math.addExact(a, b), hint = "try_add", errorContext = errorContext) + def addExact(a: Long, b: Long, 
context: Option[SQLQueryContext]): Long = { + withOverflow(Math.addExact(a, b), hint = "try_add", context) + } def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Int, b: Int, errorContext: String): Int = - withOverflow(Math.subtractExact(a, b), hint = "try_subtract", errorContext = errorContext) + def subtractExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = { + withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) + } def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Long, b: Long, errorContext: String): Long = - withOverflow(Math.subtractExact(a, b), hint = "try_subtract", errorContext = errorContext) + def subtractExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = { + withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) + } def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Int, b: Int, errorContext: String): Int = - withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", errorContext = errorContext) + def multiplyExact(a: Int, b: Int, context: Option[SQLQueryContext]): Int = { + withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) + } def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Long, b: Long, errorContext: String): Long = - withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", errorContext = errorContext) + def multiplyExact(a: Long, b: Long, context: Option[SQLQueryContext]): Long = { + withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) + } def negateExact(a: Int): Int = withOverflow(Math.negateExact(a)) @@ -68,12 +75,15 @@ object MathUtils { def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b)) - private def withOverflow[A](f: => A, hint: String = "", errorContext: String = ""): A = { + private def withOverflow[A]( + f: => A, + hint: String = "", + context: Option[SQLQueryContext] = None): A = { try { f } catch { case e: ArithmeticException => - throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage, hint, errorContext) + throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage, hint, context) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index d4aac3e88df..503c0e181ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType} import org.apache.spark.unsafe.types.UTF8String @@ -26,24 +27,28 @@ import org.apache.spark.unsafe.types.UTF8String */ object UTF8StringUtils { - def toLongExact(s: UTF8String, errorContext: String): Long = - withException(s.toLongExact, errorContext, LongType, s) + def toLongExact(s: UTF8String, context: Option[SQLQueryContext]): Long = + withException(s.toLongExact, context, LongType, s) - def toIntExact(s: UTF8String, errorContext: String): Int = - withException(s.toIntExact, errorContext, IntegerType, s) + def toIntExact(s: UTF8String, context: 
Option[SQLQueryContext]): Int = + withException(s.toIntExact, context, IntegerType, s) - def toShortExact(s: UTF8String, errorContext: String): Short = - withException(s.toShortExact, errorContext, ShortType, s) + def toShortExact(s: UTF8String, context: Option[SQLQueryContext]): Short = + withException(s.toShortExact, context, ShortType, s) - def toByteExact(s: UTF8String, errorContext: String): Byte = - withException(s.toByteExact, errorContext, ByteType, s) + def toByteExact(s: UTF8String, context: Option[SQLQueryContext]): Byte = + withException(s.toByteExact, context, ByteType, s) - private def withException[A](f: => A, errorContext: String, to: DataType, s: UTF8String): A = { + private def withException[A]( + f: => A, + context: Option[SQLQueryContext], + to: DataType, + s: UTF8String): A = { try { f } catch { case e: NumberFormatException => - throw QueryExecutionErrors.invalidInputInCastToNumberError(to, s, errorContext) + throw QueryExecutionErrors.invalidInputInCastToNumberError(to, s, context) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala index 345fc8e0232..9617f7d4b0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{quoteIdentifier, toPrettySQL} import org.apache.spark.sql.types.{DataType, DoubleType, FloatType} @@ -95,4 +96,8 @@ private[sql] trait QueryErrorsBase { def toSQLExpr(e: Expression): String = { quoteByDefault(toPrettySQL(e)) } + + def getSummary(context: Option[SQLQueryContext]): String = { + context.map(_.summary).getOrElse("") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 1ef31673d6a..e0b08df940d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path} import org.apache.hadoop.fs.permission.FsPermission import org.codehaus.commons.compiler.{CompileException, InternalCompilerException} -import org.apache.spark.{Partition, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkClassNotFoundException, SparkConcurrentModificationException, SparkDateTimeException, SparkException, SparkFileAlreadyExistsException, SparkFileNotFoundException, SparkIllegalArgumentException, SparkIndexOutOfBoundsException, SparkNoSuchElementException, SparkNoSuchMethodException, SparkNumberFormatException, SparkRuntimeException, SparkSecurityException, SparkSQLException, SparkSQLFea [...] 
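The DateTimeUtils, MathUtils, UTF8StringUtils and QueryErrorsBase hunks above all apply the same mechanical change: evaluation helpers stop taking a pre-rendered `errorContext: String` and instead thread an `Option[SQLQueryContext]` through to the error factory, while the new `getSummary` collapses an absent context to an empty string. A minimal standalone sketch of that shape, using a toy `Ctx` case class and plain JDK exceptions as stand-ins for Spark's `SQLQueryContext` and error classes:

```scala
// Toy stand-in for SQLQueryContext: just enough to render a text summary.
final case class Ctx(fragment: String, summary: String)

object ContextualErrorsSketch {
  // Mirrors the getSummary helper above: an absent context renders as an empty summary,
  // so messages built without an Origin keep their old shape.
  def getSummary(context: Option[Ctx]): String = context.map(_.summary).getOrElse("")

  // Mirrors the withException shape: run a by-name parse and, on failure, rethrow a
  // cast error that appends the optional context's summary.
  private def withException[A](f: => A, context: Option[Ctx], to: String, s: String): A =
    try f
    catch {
      case _: NumberFormatException =>
        throw new NumberFormatException(
          s"[CAST_INVALID_INPUT] The value '$s' cannot be cast to $to." + getSummary(context))
    }

  // Before: toIntExact(s: UTF8String, errorContext: String); after: the caller passes the
  // structured context (or None) and formatting is deferred to the error factory.
  def toIntExact(s: String, context: Option[Ctx] = None): Int =
    withException(s.trim.toInt, context, "INT", s)
}
```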
+import org.apache.spark._ import org.apache.spark.executor.CommitDeniedException import org.apache.spark.launcher.SparkLauncher import org.apache.spark.memory.SparkOutOfMemoryError @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNode} import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode} import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Identifier, Table, TableProvider} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -88,14 +88,16 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLValue(t, from), toSQLType(from), toSQLType(to), - toSQLConf(SQLConf.ANSI_ENABLED.key))) + toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = None, + summary = "") } def cannotChangeDecimalPrecisionError( value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: String): ArithmeticException = { + context: Option[SQLQueryContext] = None): ArithmeticException = { new SparkArithmeticException( errorClass = "CANNOT_CHANGE_DECIMAL_PRECISION", messageParameters = Array( @@ -103,14 +105,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { decimalPrecision.toString, decimalScale.toString, toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = context) + context = context, + summary = getSummary(context)) } def invalidInputInCastToDatetimeError( value: Any, from: DataType, to: DataType, - errorContext: String): Throwable = { + context: Option[SQLQueryContext]): Throwable = { new SparkDateTimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Array( @@ -118,12 +121,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(from), toSQLType(to), toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = errorContext) + context = context, + summary = getSummary(context)) } def invalidInputSyntaxForBooleanError( s: UTF8String, - errorContext: String): SparkRuntimeException = { + context: Option[SQLQueryContext]): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Array( @@ -131,13 +135,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(StringType), toSQLType(BooleanType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = errorContext) + context = context, + summary = getSummary(context)) } def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - errorContext: String): SparkNumberFormatException = { + context: Option[SQLQueryContext]): SparkNumberFormatException = { new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", messageParameters = Array( @@ -145,7 +150,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLType(StringType), toSQLType(to), toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = errorContext) + context = context, + summary = getSummary(context)) } def cannotCastFromNullTypeError(to: DataType): Throwable = { @@ -175,30 +181,32 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Array(funcCls, inputTypes, outputType), e) } - def divideByZeroError(context: String): ArithmeticException = { + def 
divideByZeroError(context: Option[SQLQueryContext]): ArithmeticException = { new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", messageParameters = Array(toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = context) + context = context, + summary = getSummary(context)) } def invalidArrayIndexError( index: Int, numElements: Int, - context: String): ArrayIndexOutOfBoundsException = { + context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX", messageParameters = Array( toSQLValue(index, IntegerType), toSQLValue(numElements, IntegerType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = context) + context = context, + summary = getSummary(context)) } def invalidElementAtIndexError( index: Int, numElements: Int, - context: String): ArrayIndexOutOfBoundsException = { + context: Option[SQLQueryContext]): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", messageParameters = @@ -206,30 +214,39 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLValue(index, IntegerType), toSQLValue(numElements, IntegerType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = context) + context = context, + summary = getSummary(context)) } - def mapKeyNotExistError(key: Any, dataType: DataType, context: String): NoSuchElementException = { + def mapKeyNotExistError( + key: Any, + dataType: DataType, + context: Option[SQLQueryContext]): NoSuchElementException = { new SparkNoSuchElementException( errorClass = "MAP_KEY_DOES_NOT_EXIST", messageParameters = Array( toSQLValue(key, dataType), toSQLConf(SQLConf.ANSI_ENABLED.key)), - queryContext = context) + context = context, + summary = getSummary(context)) } def invalidFractionOfSecondError(): DateTimeException = { new SparkDateTimeException( errorClass = "INVALID_FRACTION_OF_SECOND", errorSubClass = None, - Array(toSQLConf(SQLConf.ANSI_ENABLED.key))) + Array(toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = None, + summary = "") } def ansiDateTimeParseError(e: Exception): SparkDateTimeException = { new SparkDateTimeException( errorClass = "CANNOT_PARSE_TIMESTAMP", errorSubClass = None, - Array(e.getMessage, toSQLConf(SQLConf.ANSI_ENABLED.key))) + Array(e.getMessage, toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = None, + summary = "") } def ansiDateTimeError(e: DateTimeException): DateTimeException = { @@ -254,11 +271,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { ansiIllegalArgumentError(e.getMessage) } - def overflowInSumOfDecimalError(context: String): ArithmeticException = { - arithmeticOverflowError("Overflow in sum of decimals", errorContext = context) + def overflowInSumOfDecimalError(context: Option[SQLQueryContext]): ArithmeticException = { + arithmeticOverflowError("Overflow in sum of decimals", context = context) } - def overflowInIntegralDivideError(context: String): ArithmeticException = { + def overflowInIntegralDivideError(context: Option[SQLQueryContext]): ArithmeticException = { arithmeticOverflowError("Overflow in integral divide", "try_divide", context) } @@ -474,14 +491,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { def arithmeticOverflowError( message: String, hint: String = "", - errorContext: String = ""): ArithmeticException = { + context: Option[SQLQueryContext] = None): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL 
instead." } else "" new SparkArithmeticException( errorClass = "ARITHMETIC_OVERFLOW", messageParameters = Array(message, alternative, SQLConf.ANSI_ENABLED.key), - queryContext = errorContext) + context = context, + summary = getSummary(context)) } def unaryMinusCauseOverflowError(originValue: Int): ArithmeticException = { @@ -2019,7 +2037,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { errorClass = "DATETIME_OVERFLOW", messageParameters = Array( s"add ${toSQLValue(amount, IntegerType)} $unit to " + - s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}")) + s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}"), + context = None, + summary = "") } def invalidBucketFile(path: String): Throwable = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f4f54d2f93b..00172f69fda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -22,6 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String @@ -366,7 +367,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { scale: Int, roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, nullOnOverflow: Boolean = true, - context: String = ""): Decimal = { + context: Option[SQLQueryContext] = None): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -631,7 +632,7 @@ object Decimal { def fromStringANSI( str: UTF8String, to: DecimalType = DecimalType.USER_DEFAULT, - errorContext: String = ""): Decimal = { + context: Option[SQLQueryContext] = None): Decimal = { try { val bigDecimal = stringToJavaBigDecimal(str) // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. 
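The QueryExecutionErrors and Decimal hunks above repeat one constructor change: each exception now receives the structured `context` plus a separately rendered `summary` instead of a single pre-rendered `queryContext: String`, which is why call sites without any query context (for example `IntervalUtils.divideExact` earlier in the diff) now pass `None` rather than `""`. A rough self-contained sketch of that pair, where `Ctx` and `ArithmeticErrorWithContext` are illustrative stand-ins rather than `SQLQueryContext` and `SparkArithmeticException`:

```scala
// Illustrative stand-in for the structured part of SQLQueryContext.
final case class Ctx(
    objectType: String,
    objectName: String,
    startIndex: Int,
    stopIndex: Int,
    fragment: String,
    summary: String)

// An exception that keeps the machine-readable context for callers while still
// appending the human-readable summary to the message text.
class ArithmeticErrorWithContext(
    errorClass: String,
    message: String,
    val context: Option[Ctx],
    summary: String)
  extends ArithmeticException(
    s"[$errorClass] $message" + (if (summary.isEmpty) "" else "\n" + summary))

object OverflowErrorsSketch {
  def getSummary(context: Option[Ctx]): String = context.map(_.summary).getOrElse("")

  // Mirrors the arithmeticOverflowError shape: an optional "try_*" hint names the
  // error-tolerant variant, and the context is attached only when the caller has one.
  def arithmeticOverflowError(
      message: String,
      hint: String = "",
      context: Option[Ctx] = None): ArithmeticException = {
    val alternative =
      if (hint.nonEmpty) s" Use '$hint' to tolerate overflow and return NULL instead." else ""
    new ArithmeticErrorWithContext(
      "ARITHMETIC_OVERFLOW", message + "." + alternative, context, getSummary(context))
  }
}
```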
@@ -644,7 +645,7 @@ object Decimal { } } catch { case _: NumberFormatException => - throw QueryExecutionErrors.invalidInputInCastToNumberError(to, str, errorContext) + throw QueryExecutionErrors.invalidInputInCastToNumberError(to, str, context) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index 637510d81b0..d96ca4b87f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -85,13 +85,13 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { startIndex = Some(7), stopIndex = Some(30), sqlText = Some(query)) - val expr1 = withOrigin(origin) { CheckOverflow(Literal(d), DecimalType(4, 3), false) } checkExceptionInExpression[ArithmeticException](expr1, query) - val expr2 = CheckOverflowInSum(Literal(d), DecimalType(4, 3), false, queryContext = query) + val expr2 = CheckOverflowInSum( + Literal(d), DecimalType(4, 3), false, context = Some(origin.context)) checkExceptionInExpression[ArithmeticException](expr2, query) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1e1206c0e1e..442bd01aa5b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -875,7 +875,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { sqlText = Some(text), objectType = Some("VIEW"), objectName = Some("some_view")) - val expected = + val expectedSummary = """== SQL of VIEW some_view(line 3, position 39) == |...7890 + 1234567890 + 1234567890, cast('a' | ^^^^^^^^ @@ -885,7 +885,16 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { |^^^^^ |""".stripMargin - assert(origin.context == expected) + val expectedFragment = + """cast('a' + |as /* comment */ + |int)""".stripMargin + assert(origin.context.summary == expectedSummary) + assert(origin.context.startIndex == origin.startIndex.get) + assert(origin.context.stopIndex == origin.stopIndex.get) + assert(origin.context.objectType == origin.objectType.get) + assert(origin.context.objectName == origin.objectName.get) + assert(origin.context.fragment == expectedFragment) } test("SPARK-39046: Return an empty context string if TreeNode.origin is wrongly set") { @@ -921,7 +930,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { stopIndex = Some(1), sqlText = text) Seq(origin1, origin2, origin3, origin4, origin5, origin6).foreach { origin => - assert(origin.context.isEmpty) + assert(origin.context.summary.isEmpty) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
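The TreeNodeSuite changes above show the user-visible effect of the refactoring: `Origin.context` is no longer a plain string but a structured value whose `summary`, `fragment` and index fields can be checked independently. A compact standalone illustration of that style of assertion, with toy types in place of Spark's `Origin` and `SQLQueryContext`:

```scala
// Toy structured context derived from a source position; the summary is rendered
// from the structured fields rather than stored as an opaque string.
final case class Ctx(
    objectType: String,
    objectName: String,
    startIndex: Int,
    stopIndex: Int,
    fragment: String) {
  def summary: String = s"== SQL of $objectType $objectName ==\n$fragment"
}

// Toy stand-in for Origin: it knows the full SQL text and the span of the failing node.
final case class OriginLike(
    sqlText: String,
    start: Int,
    stop: Int,
    objectType: String,
    objectName: String) {
  def context: Ctx = Ctx(objectType, objectName, start, stop, sqlText.substring(start, stop + 1))
}

object StructuredContextDemo extends App {
  val sql = "SELECT 1234567890 + cast('a' as int) FROM t"
  val origin = OriginLike(sql, start = 20, stop = 35, objectType = "VIEW", objectName = "some_view")

  // Tests can assert on the structured pieces separately from the rendered summary,
  // which is what the rewritten TreeNodeSuite assertions do.
  assert(origin.context.fragment == "cast('a' as int)")
  assert(origin.context.startIndex == 20)
  assert(origin.context.objectName == "some_view")
  println(origin.context.summary)
}
```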