This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f82a5e75cfb [SPARK-40749][SQL] Migrate type check failures of 
generators onto error classes
f82a5e75cfb is described below

commit f82a5e75cfbc0d5dea249029354737e811765e6a
Author: panbingkun <pbk1...@gmail.com>
AuthorDate: Fri Nov 4 12:21:35 2022 +0300

    [SPARK-40749][SQL] Migrate type check failures of generators onto error 
classes
    
    ### What changes were proposed in this pull request?
    This PR aims to:
    A. Check error classes in GeneratorFunctionSuite by using checkError().
    B. Replace TypeCheckFailure with DataTypeMismatch in the type checks of the generator expressions, including:
    
    1. Stack (3): 
https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala#L163-L170
    2. ExplodeBase (1): 
https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala#L299
    3. Inline (1):
    
https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala#L441
    
    ### Why are the changes needed?
    Migration onto error classes unifies Spark SQL error messages.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR changes user-facing error messages.
    
    ### How was this patch tested?
    1. Add new UT
    2. Update existing UT
    3. Pass GA
    
    Closes #38482 from panbingkun/SPARK-40749.
    
    Authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |   5 +
 .../sql/catalyst/expressions/generators.scala      |  74 ++++++--
 .../analysis/ExpressionTypeCheckingSuite.scala     |  28 ++-
 .../apache/spark/sql/GeneratorFunctionSuite.scala  | 198 +++++++++++++++------
 4 files changed, 236 insertions(+), 69 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json 
b/core/src/main/resources/error/error-classes.json
index 7fc806752be..f4b7874217a 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -335,6 +335,11 @@
           "The lower bound of a window frame must be <comparison> to the upper 
bound."
         ]
       },
+      "STACK_COLUMN_DIFF_TYPES" : {
+        "message" : [
+          "The data type of the column (<columnIndex>) do not have the same 
type: <leftType> (<leftParamIndex>) <> <rightType> (<rightParamIndex>)."
+        ]
+      },
       "UNEXPECTED_CLASS_TYPE" : {
         "message" : [
           "class <className> not found"
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d305b4d3700..1d60dd3795e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{GENERATOR, TreePattern}
@@ -160,16 +162,54 @@ case class Stack(children: Seq[Expression]) extends 
Generator {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length <= 1) {
-      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 
arguments.")
-    } else if (children.head.dataType != IntegerType || 
!children.head.foldable || numRows < 1) {
-      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive 
constant integer.")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "> 1",
+          "actualNum" -> children.length.toString)
+      )
+    } else if (children.head.dataType != IntegerType) {
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(IntegerType),
+          "inputSql" -> toSQLExpr(children.head),
+          "inputType" -> toSQLType(children.head.dataType))
+      )
+    } else if (!children.head.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "n",
+          "inputType" -> toSQLType(IntegerType),
+          "inputExpr" -> toSQLExpr(children.head)
+        )
+      )
+    } else if (numRows < 1) {
+      DataTypeMismatch(
+        errorSubClass = "VALUE_OUT_OF_RANGE",
+        messageParameters = Map(
+          "exprName" -> toSQLId("n"),
+          "valueRange" -> s"(0, ${Int.MaxValue}]",
+          "currentValue" -> toSQLValue(numRows, children.head.dataType)
+        )
+      )
     } else {
       for (i <- 1 until children.length) {
         val j = (i - 1) % numFields
         if (children(i).dataType != elementSchema.fields(j).dataType) {
-          return TypeCheckResult.TypeCheckFailure(
-            s"Argument ${j + 1} 
(${elementSchema.fields(j).dataType.catalogString}) != " +
-              s"Argument $i (${children(i).dataType.catalogString})")
+          return DataTypeMismatch(
+            errorSubClass = "STACK_COLUMN_DIFF_TYPES",
+            messageParameters = Map(
+              "columnIndex" -> j.toString,
+              "leftParamIndex" -> (j + 1).toString,
+              "leftType" -> toSQLType(elementSchema.fields(j).dataType),
+              "rightParamIndex" -> i.toString,
+              "rightType" -> toSQLType(children(i).dataType)
+            )
+          )
         }
       }
       TypeCheckResult.TypeCheckSuccess
@@ -296,9 +336,14 @@ abstract class ExplodeBase extends UnaryExpression with 
CollectionGenerator with
     case _: ArrayType | _: MapType =>
       TypeCheckResult.TypeCheckSuccess
     case _ =>
-      TypeCheckResult.TypeCheckFailure(
-        "input to function explode should be array or map type, " +
-          s"not ${child.dataType.catalogString}")
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(TypeCollection(ArrayType, MapType)),
+          "inputSql" -> toSQLExpr(child),
+          "inputType" -> toSQLType(child.dataType))
+      )
   }
 
   // hive-compatible default alias for explode function ("col" for array, 
"key", "value" for map)
@@ -438,9 +483,14 @@ case class Inline(child: Expression) extends 
UnaryExpression with CollectionGene
     case ArrayType(st: StructType, _) =>
       TypeCheckResult.TypeCheckSuccess
     case _ =>
-      TypeCheckResult.TypeCheckFailure(
-        s"input to function $prettyName should be array of struct type, " +
-          s"not ${child.dataType.catalogString}")
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType("ARRAY<STRUCT>"),
+          "inputSql" -> toSQLExpr(child),
+          "inputType" -> toSQLType(child.dataType))
+      )
   }
 
   override def elementSchema: StructType = child.dataType match {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index eb2ebce3a5f..f656131c8e7 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -519,10 +519,30 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite 
with SQLHelper with Quer
         "expectedNum" -> "> 0",
         "actualNum" -> "0"))
 
-    assertError(Explode($"intField"),
-      "input to function explode should be array or map type")
-    assertError(PosExplode($"intField"),
-      "input to function explode should be array or map type")
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Explode($"intField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"explode(intField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"intField\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "(\"ARRAY\" or \"MAP\")"))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(PosExplode($"intField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"posexplode(intField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"intField\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "(\"ARRAY\" or \"MAP\")")
+    )
   }
 
   test("check types for CreateNamedStruct") {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 3fb66f08cea..abec582d43a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -55,36 +55,101 @@ class GeneratorFunctionSuite extends QueryTest with 
SharedSparkSession {
       Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)
 
     // The first argument must be a positive constant integer.
-    val m = intercept[AnalysisException] {
-      df.selectExpr("stack(1.1, 1, 2, 3)")
-    }.getMessage
-    assert(m.contains("The number of rows must be a positive constant 
integer."))
-    val m2 = intercept[AnalysisException] {
-      df.selectExpr("stack(-1, 1, 2, 3)")
-    }.getMessage
-    assert(m2.contains("The number of rows must be a positive constant 
integer."))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("stack(1.1, 1, 2, 3)")
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(1.1, 1, 2, 3)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"1.1\"",
+        "inputType" -> "\"DECIMAL(2,1)\"",
+        "requiredType" -> "\"INT\""),
+      context = ExpectedContext(
+        fragment = "stack(1.1, 1, 2, 3)",
+        start = 0,
+        stop = 18
+      )
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("stack(-1, 1, 2, 3)")
+      },
+      errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(-1, 1, 2, 3)\"",
+        "exprName" -> "`n`",
+        "valueRange" -> "(0, 2147483647]",
+        "currentValue" -> "-1"),
+      context = ExpectedContext(
+        fragment = "stack(-1, 1, 2, 3)",
+        start = 0,
+        stop = 17
+      )
+    )
 
     // The data for the same column should have the same type.
-    val m3 = intercept[AnalysisException] {
-      df.selectExpr("stack(2, 1, '2.2')")
-    }.getMessage
-    assert(m3.contains("data type mismatch: Argument 1 (int) != Argument 2 
(string)"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("stack(2, 1, '2.2')")
+      },
+      errorClass = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(2, 1, 2.2)\"",
+        "columnIndex" -> "0",
+        "leftParamIndex" -> "1",
+        "leftType" -> "\"INT\"",
+        "rightParamIndex" -> "2",
+        "rightType" -> "\"STRING\""),
+      context = ExpectedContext(
+        fragment = "stack(2, 1, '2.2')",
+        start = 0,
+        stop = 17
+      )
+    )
 
     // stack on column data
     val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c")
     checkAnswer(df2.selectExpr("stack(2, a, b, c)"), Row(1, 2) :: Row(3, null) 
:: Nil)
 
-    val m4 = intercept[AnalysisException] {
-      df2.selectExpr("stack(n, a, b, c)")
-    }.getMessage
-    assert(m4.contains("The number of rows must be a positive constant 
integer."))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("stack(n, a, b, c)")
+      },
+      errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(n, a, b, c)\"",
+        "inputName" -> "n",
+        "inputType" -> "\"INT\"",
+        "inputExpr" -> "\"n\""),
+      context = ExpectedContext(
+        fragment = "stack(n, a, b, c)",
+        start = 0,
+        stop = 16
+      )
+    )
 
     val df3 = Seq((2, 1, 2.0)).toDF("n", "a", "b")
-    val m5 = intercept[AnalysisException] {
-      df3.selectExpr("stack(2, a, b)")
-    }.getMessage
-    assert(m5.contains("data type mismatch: Argument 1 (int) != Argument 2 
(double)"))
-
+    checkError(
+      exception = intercept[AnalysisException] {
+        df3.selectExpr("stack(2, a, b)")
+      },
+      errorClass = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(2, a, b)\"",
+        "columnIndex" -> "0",
+        "leftParamIndex" -> "1",
+        "leftType" -> "\"INT\"",
+        "rightParamIndex" -> "2",
+        "rightType" -> "\"DOUBLE\""),
+      context = ExpectedContext(
+        fragment = "stack(2, a, b)",
+        start = 0,
+        stop = 13
+      )
+    )
   }
 
   test("single explode") {
@@ -218,10 +283,18 @@ class GeneratorFunctionSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("inline raises exception on array of null type") {
-    val m = intercept[AnalysisException] {
-      spark.range(2).select(inline(array()))
-    }.getMessage
-    assert(m.contains("data type mismatch"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.range(2).select(inline(array()))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"inline(array())\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"array()\"",
+        "inputType" -> "\"ARRAY<VOID>\"",
+        "requiredType" -> "\"ARRAY<STRUCT>\"")
+    )
   }
 
   test("inline with empty table") {
@@ -250,20 +323,30 @@ class GeneratorFunctionSuite extends QueryTest with 
SharedSparkSession {
       Row(1, 2) :: Row(1, 2) :: Nil)
 
     // Spark think [struct<a:int>, struct<b:int>] is heterogeneous due to name 
difference.
-    val m = intercept[AnalysisException] {
-      df.select(inline(array(struct('a), struct('b))))
-    }.getMessage
-    assert(m.contains("data type mismatch"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(inline(array(struct('a), struct('b))))
+      },
+      errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"array(struct(a), struct(b))\"",
+        "functionName" -> "`array`",
+        "dataType" -> "(\"STRUCT<a: INT>\" or \"STRUCT<b: INT>\")"))
 
     checkAnswer(
       df.select(inline(array(struct('a), struct('b.alias("a"))))),
       Row(1) :: Row(2) :: Nil)
 
     // Spark think [struct<a:int>, struct<col1:int>] is heterogeneous due to 
name difference.
-    val m2 = intercept[AnalysisException] {
-      df.select(inline(array(struct('a), struct(lit(2)))))
-    }.getMessage
-    assert(m2.contains("data type mismatch"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(inline(array(struct('a), struct(lit(2)))))
+      },
+      errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"array(struct(a), struct(2))\"",
+        "functionName" -> "`array`",
+        "dataType" -> "(\"STRUCT<a: INT>\" or \"STRUCT<col1: INT>\")"))
 
     checkAnswer(
       df.select(inline(array(struct('a), struct(lit(2).alias("a"))))),
@@ -330,30 +413,39 @@ class GeneratorFunctionSuite extends QueryTest with 
SharedSparkSession {
         Row(1, 2) :: Row(1, 3) :: Nil
       )
 
-      val msg1 = intercept[AnalysisException] {
-        sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1")
-      }.getMessage
-      assert(msg1.contains("The generator is not supported: nested in 
expressions"))
-
-      val msg2 = intercept[AnalysisException] {
-        sql(
-          """select
-            |  explode(array(min(c2), max(c2))),
-            |  posexplode(array(min(c2), max(c2)))
-            |from t1 group by c1
-          """.stripMargin)
-      }.getMessage
-      assert(msg2.contains("The generator is not supported: " +
-        "only one generator allowed per aggregate clause"))
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by 
c1")
+        },
+        errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS",
+        parameters = Map(
+          "expression" -> "\"(1 + explode(array(min(c2), max(c2))))\""))
+
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(
+            """select
+              |  explode(array(min(c2), max(c2))),
+              |  posexplode(array(min(c2), max(c2)))
+              |from t1 group by c1""".stripMargin)
+        },
+        errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
+        parameters = Map(
+          "clause" -> "aggregate",
+          "num" -> "2",
+          "generators" -> ("\"explode(array(min(c2), max(c2)))\", " +
+            "\"posexplode(array(min(c2), max(c2)))\"")))
     }
   }
 
   test("SPARK-30998: Unsupported nested inner generators") {
-    val errMsg = intercept[AnalysisException] {
-      sql("SELECT array(array(1, 2), array(3)) 
v").select(explode(explode($"v"))).collect
-    }.getMessage
-    assert(errMsg.contains("The generator is not supported: " +
-      """nested in expressions "explode(explode(v))""""))
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("SELECT array(array(1, 2), array(3)) 
v").select(explode(explode($"v"))).collect
+      },
+      errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS",
+      parameters = Map("expression" -> "\"explode(explode(v))\""))
   }
 
   test("SPARK-30997: generators in aggregate expressions for dataframe") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to