This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2acfc1dbca9 [SPARK-40369][CORE][SQL] Migrate the type check failures of calls via reflection onto error classes 2acfc1dbca9 is described below commit 2acfc1dbca975a2a4a38124fe8ebe464aa1663a9 Author: yangjie01 <yangji...@baidu.com> AuthorDate: Tue Oct 18 20:04:56 2022 +0500 [SPARK-40369][CORE][SQL] Migrate the type check failures of calls via reflection onto error classes ### What changes were proposed in this pull request? This PR replaces `TypeCheckFailure` with `DataTypeMismatch` in `CallMethodViaReflection`. ### Why are the changes needed? Migration onto error classes unifies Spark SQL error messages. ### Does this PR introduce _any_ user-facing change? Yes. The PR changes user-facing error messages. ### How was this patch tested? - Pass GitHub Actions Closes #38294 from LuciferYang/SPARK-40369. Authored-by: yangjie01 <yangji...@baidu.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- core/src/main/resources/error/error-classes.json | 10 +++ .../expressions/CallMethodViaReflection.scala | 72 +++++++++++++++++----- .../expressions/CallMethodViaReflectionSuite.scala | 53 ++++++++++++---- 3 files changed, 105 insertions(+), 30 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 3e97029b154..7f42d8acc53 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -233,6 +233,11 @@ "The lower bound of a window frame must be <comparison> to the upper bound." ] }, + "UNEXPECTED_CLASS_TYPE" : { + "message" : [ + "class <className> not found" + ] + }, "UNEXPECTED_INPUT_TYPE" : { "message" : [ "parameter <paramIndex> requires <requiredType> type, however, <inputSql> is of <inputType> type." 
@@ -243,6 +248,11 @@ "The <exprName> must not be null" ] }, + "UNEXPECTED_STATIC_METHOD" : { + "message" : [ + "cannot find a static method <methodName> that matches the argument types in <className>" + ] + }, "UNSPECIFIED_FRAME" : { "message" : [ "Cannot use an UnspecifiedFrame. This should have been converted during analysis." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 7cb830d1156..db2053707b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -21,7 +21,8 @@ import java.lang.reflect.{Method, Modifier} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -61,20 +62,56 @@ case class CallMethodViaReflection(children: Seq[Expression]) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { - TypeCheckFailure("requires at least two arguments") - } else if (!children.take(2).forall(e => e.dataType == StringType && e.foldable)) { - // The first two arguments must be string type. 
- TypeCheckFailure("first two arguments should be string literals") - } else if (!classExists) { - TypeCheckFailure(s"class $className not found") - } else if (children.slice(2, children.length) - .exists(e => !CallMethodViaReflection.typeMapping.contains(e.dataType))) { - TypeCheckFailure("arguments from the third require boolean, byte, short, " + - "integer, long, float, double or string expressions") - } else if (method == null) { - TypeCheckFailure(s"cannot find a static method that matches the argument types in $className") + DataTypeMismatch( + errorSubClass = "WRONG_NUM_PARAMS", + messageParameters = Map("actualNum" -> children.length.toString)) } else { - TypeCheckSuccess + val unexpectedParameter = children.zipWithIndex.collectFirst { + case (e, 0) if !(e.dataType == StringType && e.foldable) => + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "class", + "inputType" -> toSQLType(StringType), + "inputExpr" -> toSQLExpr(children.head) + ) + ) + case (e, 1) if !(e.dataType == StringType && e.foldable) => + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "method", + "inputType" -> toSQLType(StringType), + "inputExpr" -> toSQLExpr(children(1)) + ) + ) + case (e, idx) if idx > 1 && !CallMethodViaReflection.typeMapping.contains(e.dataType) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> (idx + 1).toString, + "requiredType" -> toSQLType( + TypeCollection(BooleanType, ByteType, ShortType, + IntegerType, LongType, FloatType, DoubleType, StringType)), + "inputSql" -> toSQLExpr(e), + "inputType" -> toSQLType(e.dataType)) + ) + } + + unexpectedParameter match { + case Some(mismatch) => mismatch + case _ if !classExists => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_CLASS_TYPE", + messageParameters = Map("className" -> className) + ) + case _ if method == null => + DataTypeMismatch( + 
errorSubClass = "UNEXPECTED_STATIC_METHOD", + messageParameters = Map("methodName" -> methodName, "className" -> className) + ) + case _ => TypeCheckSuccess + } } } @@ -106,11 +143,12 @@ case class CallMethodViaReflection(children: Seq[Expression]) /** True if the class exists and can be loaded. */ @transient private lazy val classExists = CallMethodViaReflection.classExists(className) + /** Name of the method */ + @transient private lazy val methodName = children(1).eval(null).asInstanceOf[UTF8String].toString + /** The reflection method. */ - @transient lazy val method: Method = { - val methodName = children(1).eval(null).asInstanceOf[UTF8String].toString + @transient lazy val method: Method = CallMethodViaReflection.findMethod(className, methodName, argExprs.map(_.dataType)).orNull - } /** A temporary buffer used to hold intermediate results returned by children. */ @transient private lazy val buffer = new Array[Object](argExprs.length) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala index d8f3ad24246..c8b99f6f026 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types._ /** A static class for 
testing purpose. */ object ReflectStaticClass { @@ -60,24 +61,39 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp } test("class not found") { - val ret = createExpr("some-random-class", "method").checkInputDataTypes() + val wrongClassName = "some-random-class" + val ret = createExpr(wrongClassName, "method").checkInputDataTypes() assert(ret.isFailure) - val errorMsg = ret.asInstanceOf[TypeCheckFailure].message - assert(errorMsg.contains("not found") && errorMsg.contains("class")) + assert(ret == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_CLASS_TYPE", + messageParameters = Map("className" -> wrongClassName) + ) + ) } test("method not found because name does not match") { - val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes() + val wrongMethodName = "notfoundmethod" + val ret = createExpr(staticClassName, wrongMethodName).checkInputDataTypes() assert(ret.isFailure) - val errorMsg = ret.asInstanceOf[TypeCheckFailure].message - assert(errorMsg.contains("cannot find a static method")) + assert(ret == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_STATIC_METHOD", + messageParameters = Map("methodName" -> wrongMethodName, "className" -> staticClassName) + ) + ) } test("method not found because there is no static method") { - val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes() + val wrongMethodName = "method1" + val ret = createExpr(dynamicClassName, wrongMethodName).checkInputDataTypes() assert(ret.isFailure) - val errorMsg = ret.asInstanceOf[TypeCheckFailure].message - assert(errorMsg.contains("cannot find a static method")) + assert(ret == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_STATIC_METHOD", + messageParameters = Map("methodName" -> wrongMethodName, "className" -> dynamicClassName) + ) + ) } test("input type checking") { @@ -91,8 +107,19 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp test("unsupported type checking") { val ret = 
createExpr(staticClassName, "method1", new Timestamp(1)).checkInputDataTypes() assert(ret.isFailure) - val errorMsg = ret.asInstanceOf[TypeCheckFailure].message - assert(errorMsg.contains("arguments from the third require boolean, byte, short")) + assert(ret == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "3", + "requiredType" -> toSQLType( + TypeCollection(BooleanType, ByteType, ShortType, + IntegerType, LongType, FloatType, DoubleType, StringType)), + "inputSql" -> "\"TIMESTAMP '1969-12-31 16:00:00.001'\"", + "inputType" -> "\"TIMESTAMP\"" + ) + ) + ) } test("invoking methods using acceptable types") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org