This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 625f76dae0d [SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes
625f76dae0d is described below
commit 625f76dae0d9581428d6c5c4b58bf2958957c8c8
Author: Max Gekk <[email protected]>
AuthorDate: Sun Oct 23 13:32:34 2022 +0500
    [SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes
### What changes were proposed in this pull request?
In the PR, I propose to add new error sub-classes of the error class
`DATATYPE_MISMATCH`, and use them for the type check failures of some
interval expressions.
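
As a minimal sketch of the migration pattern (based on the
`ApproxCountDistinctForIntervals` change below), a check that used to return a
free-form `TypeCheckFailure` string now returns a structured `DataTypeMismatch`
that names an error sub-class and passes named message parameters:
```
// Before: an ad-hoc message string that is hard to search for and to translate.
TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")

// After: a structured result tied to DATATYPE_MISMATCH.WRONG_NUM_ENDPOINTS.
DataTypeMismatch(
  errorSubClass = "WRONG_NUM_ENDPOINTS",
  messageParameters = Map("actualNumber" -> endpoints.length.toString))
```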
### Why are the changes needed?
Migration onto error classes unifies Spark SQL error messages and improves
the searchability of errors.
### Does this PR introduce _any_ user-facing change?
Yes. The PR changes user-facing error messages.
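
For example, with the new `WRONG_NUM_ENDPOINTS` sub-class added below, a call
with a single endpoint surfaces the templated message with the `actualNumber`
parameter filled in (a sketch of the sub-class message text only; the full
error also carries the usual `DATATYPE_MISMATCH` wrapper):
```
[DATATYPE_MISMATCH.WRONG_NUM_ENDPOINTS]
The number of endpoints must be >= 2 to construct intervals but the actual number is 1.
```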
### How was this patch tested?
By running the affected test suites:
```
$ build/sbt "test:testOnly *AnalysisSuite"
$ build/sbt "test:testOnly *ExpressionTypeCheckingSuite"
$ build/sbt "test:testOnly *ApproxCountDistinctForIntervalsSuite"
```
Closes #38237 from MaxGekk/type-check-fails-interval-exprs.
Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
core/src/main/resources/error/error-classes.json | 5 +++
.../ApproxCountDistinctForIntervals.scala | 31 +++++++++++---
.../catalyst/expressions/aggregate/Average.scala | 2 +-
.../sql/catalyst/expressions/aggregate/Sum.scala | 2 +-
.../apache/spark/sql/catalyst/util/TypeUtils.scala | 20 +++++----
.../apache/spark/sql/types/AbstractDataType.scala | 9 ++++
.../sql/catalyst/analysis/AnalysisSuite.scala | 50 ++++++++++++++--------
.../analysis/ExpressionTypeCheckingSuite.scala | 26 +++++++++--
.../ApproxCountDistinctForIntervalsSuite.scala | 21 ++++++---
9 files changed, 123 insertions(+), 43 deletions(-)
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 5f4db145479..0f9b665718c 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -263,6 +263,11 @@
           "The <exprName> must be between <valueRange> (current value = <currentValue>)"
         ]
       },
+      "WRONG_NUM_ENDPOINTS" : {
+        "message" : [
+          "The number of endpoints must be >= 2 to construct intervals but the actual number is <actualNumber>."
+        ]
+      },
       "WRONG_NUM_PARAMS" : {
         "message" : [
           "The <functionName> requires <expectedNum> parameters but the actual number is <actualNum>."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index f3bf251ba0b..0be4e4aa465 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -21,10 +21,11 @@ import java.util
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
@@ -49,7 +50,10 @@ case class ApproxCountDistinctForIntervals(
     relativeSD: Double = 0.05,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes with BinaryLike[Expression] {
+  extends TypedImperativeAggregate[Array[Long]]
+  with ExpectsInputTypes
+  with BinaryLike[Expression]
+  with QueryErrorsBase {
 
   def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
     this(
@@ -77,19 +81,32 @@ case class ApproxCountDistinctForIntervals(
     if (defaultCheck.isFailure) {
       defaultCheck
     } else if (!endpointsExpression.foldable) {
-      TypeCheckFailure("The endpoints provided must be constant literals")
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> toSQLType(endpointsExpression.dataType)))
     } else {
       endpointsExpression.dataType match {
         case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
          _: AnsiIntervalType, _) =>
           if (endpoints.length < 2) {
-            TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
+            DataTypeMismatch(
+              errorSubClass = "WRONG_NUM_ENDPOINTS",
+              messageParameters = Map("actualNumber" -> endpoints.length.toString))
           } else {
             TypeCheckSuccess
           }
-        case _ =>
-          TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
-            "interval year to month or interval day to second) type")
+        case inputType =>
+          val requiredElemTypes = toSQLType(TypeCollection(
+            NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType))
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> "2",
+              "requiredType" -> s"ARRAY OF $requiredElemTypes",
+              "inputSql" -> toSQLExpr(endpointsExpression),
+              "inputType" -> toSQLType(inputType)))
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index ae644e9d663..ce9fa0575f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -54,7 +54,7 @@ case class Average(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 432d4b40b4a..2c892903437 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -67,7 +67,7 @@ case class Sum(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, prettyName)
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 7cb471d14bd..0bb5d29c5c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
-import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.types._
 
 /**
  * Functions to help with checking for valid data types and value comparison of various types.
  */
-object TypeUtils {
+object TypeUtils extends QueryErrorsBase {
 
   def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
     if (RowOrdering.isOrderable(dt)) {
@@ -70,13 +69,18 @@ object TypeUtils {
     }
   }
 
-  def checkForAnsiIntervalOrNumericType(
-      dt: DataType, funcName: String): TypeCheckResult = dt match {
+  def checkForAnsiIntervalOrNumericType(input: Expression): TypeCheckResult = input.dataType match {
     case _: AnsiIntervalType | NullType =>
       TypeCheckResult.TypeCheckSuccess
     case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
-    case other => TypeCheckResult.TypeCheckFailure(
-      s"function $funcName requires numeric or interval types, not ${other.catalogString}")
+    case other =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> Seq(NumericType, AnsiIntervalType).map(toSQLType).mkString(" or "),
+          "inputSql" -> toSQLExpr(input),
+          "inputType" -> toSQLType(other)))
   }
 
   def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index ebcf35a0674..294fb13e48c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -233,3 +233,12 @@ private[sql] abstract class DatetimeType extends AtomicType
  * The interval type which conforms to the ANSI SQL standard.
  */
 private[sql] abstract class AnsiIntervalType extends AtomicType
+
+private[spark] object AnsiIntervalType extends AbstractDataType {
+  override private[sql] def simpleString: String = "ANSI interval"
+
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    other.isInstanceOf[AnsiIntervalType]
+
+  override private[sql] def defaultConcreteType: DataType = DayTimeIntervalType()
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 3036742c83f..6f0e6ef0c11 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1163,25 +1163,39 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") {
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c)
-         |SELECT t.c
-         |FROM t
-         |GROUP BY t.c
-         |HAVING mean(t.c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c)
+           |SELECT t.c
+           |FROM t
+           |GROUP BY t.c
+           |HAVING mean(t.c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c, false d)
-         |SELECT (t.c AND t.d) c
-         |FROM t
-         |GROUP BY t.c, t.d
-         |HAVING mean(c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c, false d)
+           |SELECT (t.c AND t.d) c
+           |FROM t
+           |GROUP BY t.c, t.d
+           |HAVING mean(c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
     assertAnalysisErrorClass(
       inputPlan = parsePlan(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 991721a55ca..b41f627bac9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -396,9 +396,29 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
         "dataType" -> "\"MAP<STRING, BIGINT>\""
       )
     )
-    assertError(Sum($"booleanField"), "function sum requires numeric or interval types")
-    assertError(Average($"booleanField"),
-      "function average requires numeric or interval types")
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Sum($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"sum(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Average($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"avg(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
   }
 
   test("check types for others") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index d00193c4f3b..bb99e1c1e8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -22,7 +22,7 @@ import java.time.LocalDateTime
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, CreateArray, Literal, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
@@ -48,20 +48,31 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The endpoints provided must be constant literals"))
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> "\"ARRAY<DOUBLE>\"")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
+      DataTypeMismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> "1")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
+    // scalastyle:off line.size.limit
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
-        "interval year to month or interval day to second) type"))
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "2",
+          "requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\")",
+          "inputSql" -> "\"array(foobar)\"",
+          "inputType" -> "\"ARRAY<STRING>\"")))
+    // scalastyle:on line.size.limit
   }
 
   /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]