This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 74668e2bf14 [SPARK-40751][SQL] Migrate type check failures of high order functions onto error classes
74668e2bf14 is described below

commit 74668e2bf14760dbc60509f7736f410c09084697
Author: panbingkun <pbk1...@gmail.com>
AuthorDate: Thu Oct 27 13:47:54 2022 +0300

    [SPARK-40751][SQL] Migrate type check failures of high order functions onto error classes
    
    ### What changes were proposed in this pull request?
    This PR aims to replace `TypeCheckFailure` with `DataTypeMismatch` in the type checks of the higher-order function expressions (see the sketch after this list), namely:
    - 1. ArraySort (2):
    
https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L403-L407
    - 2. ArrayAggregate (1):
    
https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L807
    - 3. MapZipWith (1):
    
https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L1028
    
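    Each migration follows the same pattern: the free-form failure message becomes a structured `DataTypeMismatch` with an error subclass and typed message parameters, and the message template itself moves into `error-classes.json`. A condensed before/after sketch of the `ArraySort` case, taken from this diff:

    ```scala
    // Before: a free-form string that callers cannot match on reliably.
    TypeCheckResult.TypeCheckFailure(
      "Return type of the given function has to be IntegerType")

    // After: a structured error keyed by DATATYPE_MISMATCH.UNEXPECTED_RETURN_TYPE,
    // with the dynamic pieces passed as named parameters.
    DataTypeMismatch(
      errorSubClass = "UNEXPECTED_RETURN_TYPE",
      messageParameters = Map(
        "functionName" -> toSQLId(function.prettyName),
        "expectedType" -> toSQLType(IntegerType),
        "actualType" -> toSQLType(function.dataType)))
    ```
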
    ### Why are the changes needed?
    Migration onto error classes unifies Spark SQL error messages.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR changes user-facing error messages.
    
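    For example, the `map_zip_with` key-type mismatch changes as follows (abridged from the updated `mapZipWith.sql.out` golden file in this diff; the `queryContext` block is omitted here):

    ```
    -- before
    cannot resolve 'map_zip_with(...)' due to argument data type mismatch:
    The input to function map_zip_with should have been two maps with
    compatible key types, but the key types are [decimal(36,0), decimal(36,35)].

    -- after
    {
      "errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
      "messageParameters" : {
        "functionName" : "`map_zip_with`",
        "leftType" : "\"DECIMAL(36,0)\"",
        "rightType" : "\"DECIMAL(36,35)\""
      }
    }
    ```
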
    ### How was this patch tested?
    - Updated the existing unit tests (a representative assertion follows this list).
    - Passed GA.
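
    A representative assertion from the updated `DataFrameFunctionsSuite`, taken from this diff; `checkErrorMatchPVals` matches parameters as regexes because lambda variables carry generated numeric suffixes:

    ```scala
    val df = Seq(1, 2, 3).toDF("a")
    checkErrorMatchPVals(
      exception = intercept[AnalysisException] {
        df.select(array_sort(col("a"), (x, y) => x - y))
      },
      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
      parameters = Map(
        "sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), x_\d+, y_\d+\)\)"""",
        "paramIndex" -> "1",
        "requiredType" -> "\"ARRAY\"",
        "inputSql" -> "\"a\"",
        "inputType" -> "\"INT\""))
    ```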
    
    Closes #38359 from panbingkun/SPARK-40751.
    
    Authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   | 10 +++
 .../scala/org/apache/spark/SparkFunSuite.scala     |  6 ++
 .../expressions/higherOrderFunctions.scala         | 43 ++++++++---
 .../expressions/HigherOrderFunctionsSuite.scala    | 18 +++++
 .../results/typeCoercion/native/mapZipWith.sql.out | 35 ++++++++-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 89 +++++++++++++++++-----
 6 files changed, 171 insertions(+), 30 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index d72eeece82e..015d86171d7 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -200,6 +200,11 @@
           "The <functionName> accepts only arrays of pair structs, but 
<childExpr> is of <childType>."
         ]
       },
+      "MAP_ZIP_WITH_DIFF_TYPES" : {
+        "message" : [
+          "Input to the <functionName> should have been two maps with 
compatible key types, but it's [<leftType>, <rightType>]."
+        ]
+      },
       "NON_FOLDABLE_INPUT" : {
         "message" : [
           "the input <inputName> should be a foldable <inputType> expression; 
however, got <inputExpr>."
@@ -275,6 +280,11 @@
           "The <exprName> must not be null"
         ]
       },
+      "UNEXPECTED_RETURN_TYPE" : {
+        "message" : [
+          "The <functionName> requires return <expectedType> type, but the 
actual is <actualType> type."
+        ]
+      },
       "UNEXPECTED_STATIC_METHOD" : {
         "message" : [
           "cannot find a static method <methodName> that matches the argument 
types in <className>"
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 46b62d879cf..7a08de9c181 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -370,6 +370,12 @@ abstract class SparkFunSuite
     checkError(exception, errorClass, sqlState, parameters,
       false, Array(context))
 
+  protected def checkErrorMatchPVals(
+      exception: SparkThrowable,
+      errorClass: String,
+      parameters: Map[String, String]): Unit =
+    checkError(exception, errorClass, None, parameters, matchPVals = true)
+
   protected def checkErrorMatchPVals(
       exception: SparkThrowable,
       errorClass: String,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 98513fb5ddd..b59860ff181 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -24,6 +24,8 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -400,11 +402,25 @@ case class ArraySort(
             if (function.dataType == IntegerType) {
               TypeCheckResult.TypeCheckSuccess
             } else {
-              TypeCheckResult.TypeCheckFailure("Return type of the given 
function has to be " +
-                "IntegerType")
+              DataTypeMismatch(
+                errorSubClass = "UNEXPECTED_RETURN_TYPE",
+                messageParameters = Map(
+                  "functionName" -> toSQLId(function.prettyName),
+                  "expectedType" -> toSQLType(IntegerType),
+                  "actualType" -> toSQLType(function.dataType)
+                )
+              )
             }
           case _ =>
-            TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array 
input.")
+            DataTypeMismatch(
+              errorSubClass = "UNEXPECTED_INPUT_TYPE",
+              messageParameters = Map(
+                "paramIndex" -> "1",
+                "requiredType" -> toSQLType(ArrayType),
+                "inputSql" -> toSQLExpr(argument),
+                "inputType" -> toSQLType(argument.dataType)
+              )
+            )
         }
       case failure => failure
     }
@@ -804,9 +820,13 @@ case class ArrayAggregate(
       case TypeCheckResult.TypeCheckSuccess =>
         if (!DataType.equalsStructurally(
             zero.dataType, merge.dataType, ignoreNullability = true)) {
-          TypeCheckResult.TypeCheckFailure(
-            s"argument 3 requires ${zero.dataType.simpleString} type, " +
-              s"however, '${merge.sql}' is of ${merge.dataType.catalogString} 
type.")
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> "3",
+              "requiredType" -> toSQLType(zero.dataType),
+              "inputSql" -> toSQLExpr(merge),
+              "inputType" -> toSQLType(merge.dataType)))
         } else {
           TypeCheckResult.TypeCheckSuccess
         }
@@ -1025,9 +1045,14 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
         if (leftKeyType.sameType(rightKeyType)) {
           TypeUtils.checkForOrderingExpr(leftKeyType, prettyName)
         } else {
-          TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName 
should have " +
-            s"been two ${MapType.simpleString}s with compatible key types, but 
the key types are " +
-            s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].")
+          DataTypeMismatch(
+            errorSubClass = "MAP_ZIP_WITH_DIFF_TYPES",
+            messageParameters = Map(
+              "functionName" -> toSQLId(prettyName),
+              "leftType" -> toSQLType(leftKeyType),
+              "rightType" -> toSQLType(rightKeyType)
+            )
+          )
         }
       case failure => failure
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index a6546d8a5db..5f62dc97086 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -859,4 +861,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
         Seq(1, 1, 2, 3))
     }
   }
+
+  test("Return type of the given function has to be IntegerType") {
+    val comparator = {
+      val comp = ArraySort.comparator _
+      (left: Expression, right: Expression) => Literal.create("hello", StringType)
+    }
+
+    val result = arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator).checkInputDataTypes()
+    assert(result == DataTypeMismatch(
+      errorSubClass = "UNEXPECTED_RETURN_TYPE",
+      messageParameters = Map(
+        "functionName" -> toSQLId("lambdafunction"),
+        "expectedType" -> toSQLType(IntegerType),
+        "actualType" -> toSQLType(StringType)
+      )))
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index 2f176951df8..09c6e10f762 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -82,8 +82,22 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.decimal_map1, various_maps.decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
-
+{
+  "errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
+  "messageParameters" : {
+    "functionName" : "`map_zip_with`",
+    "leftType" : "\"DECIMAL(36,0)\"",
+    "rightType" : "\"DECIMAL(36,35)\"",
+    "sqlExpr" : "\"map_zip_with(decimal_map1, decimal_map2, 
lambdafunction(struct(k, v1, v2), k, v1, v2))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 81,
+    "fragment" : "map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> 
struct(k, v1, v2))"
+  } ]
+}
 
 -- !query
 SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
@@ -110,7 +124,22 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.decimal_map2, various_maps.int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
+  "messageParameters" : {
+    "functionName" : "`map_zip_with`",
+    "leftType" : "\"DECIMAL(36,35)\"",
+    "rightType" : "\"INT\"",
+    "sqlExpr" : "\"map_zip_with(decimal_map2, int_map, 
lambdafunction(struct(k, v1, v2), k, v1, v2))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 76,
+    "fragment" : "map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, 
v1, v2))"
+  } ]
+}
 
 
 -- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 85877c97ed5..3f02429fe62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -533,6 +533,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
   }
 
+  test("The given function only supports array input") {
+    val df = Seq(1, 2, 3).toDF("a")
+    checkErrorMatchPVals(
+      exception = intercept[AnalysisException] {
+        df.select(array_sort(col("a"), (x, y) => x - y))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), 
x_\d+, y_\d+\)\)"""",
+        "paramIndex" -> "1",
+        "requiredType" -> "\"ARRAY\"",
+        "inputSql" -> "\"a\"",
+        "inputType" -> "\"INT\""
+      ))
+  }
+
   test("sort_array/array_sort functions") {
     val df = Seq(
       (Array[Int](2, 1, 3), Array("b", "c", "a")),
@@ -3492,15 +3508,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         "requiredType" -> "\"ARRAY\""))
     // scalastyle:on line.size.limit
 
-    val ex4 = intercept[AnalysisException] {
-      df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
-    }
-    assert(ex4.getMessage.contains("data type mismatch: argument 3 requires 
int type"))
+    // scalastyle:off line.size.limit
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), 
namedlambdavariable(), namedlambdavariable()), 
lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
+        "paramIndex" -> "3",
+        "inputSql" -> "\"lambdafunction(namedlambdavariable(), 
namedlambdavariable(), namedlambdavariable())\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "\"INT\""
+      ))
+    // scalastyle:on line.size.limit
 
-    val ex4a = intercept[AnalysisException] {
-      df.select(aggregate(col("s"), lit(0), (acc, x) => x))
-    }
-    assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires 
int type"))
+    // scalastyle:off line.size.limit
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(aggregate(col("s"), lit(0), (acc, x) => x))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), 
namedlambdavariable(), namedlambdavariable()), 
lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
+        "paramIndex" -> "3",
+        "inputSql" -> "\"lambdafunction(namedlambdavariable(), 
namedlambdavariable(), namedlambdavariable())\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "\"INT\""
+      ))
+    // scalastyle:on line.size.limit
 
     checkError(
       exception =
@@ -3570,17 +3606,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     }
     assert(ex1.getMessage.contains("The number of lambda function arguments 
'2' does not match"))
 
-    val ex2 = intercept[AnalysisException] {
-      df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))")
-    }
-    assert(ex2.getMessage.contains("The input to function map_zip_with should 
have " +
-      "been two maps with compatible key types"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))")
+      },
+      errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"map_zip_with(mis, mmi, lambdafunction(concat(x, y, z), 
x, y, z))\"",
+        "functionName" -> "`map_zip_with`",
+        "leftType" -> "\"INT\"",
+        "rightType" -> "\"MAP<INT, INT>\""),
+      context = ExpectedContext(
+        fragment = "map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))",
+        start = 0,
+        stop = 51))
 
-    val ex2a = intercept[AnalysisException] {
-      df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, 
z)))
-    }
-    assert(ex2a.getMessage.contains("The input to function map_zip_with should 
have " +
-      "been two maps with compatible key types"))
+    // scalastyle:off line.size.limit
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, 
y, z)))
+      },
+      errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
+      matchPVals = true,
+      parameters = Map(
+        "sqlExpr" -> """"map_zip_with\(mis, mmi, 
lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""",
+        "functionName" -> "`map_zip_with`",
+        "leftType" -> "\"INT\"",
+        "rightType" -> "\"MAP<INT, INT>\""))
+    // scalastyle:on line.size.limit
 
     checkError(
       exception = intercept[AnalysisException] {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
