This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2bf26dfbd70 [SPARK-40530][SQL] Add error-related developer APIs 2bf26dfbd70 is described below commit 2bf26dfbd708cbbbe8a51dda6972296055661e07 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Mon Sep 26 15:35:38 2022 +0800 [SPARK-40530][SQL] Add error-related developer APIs ### What changes were proposed in this pull request? Third-party Spark plugins may define their own errors using the same framework as Spark: put error definitions in JSON files. This PR moves some error-related code to a new file and marks them as developer APIs, so that others can reuse them instead of writing their own JSON reader. ### Why are the changes needed? Make it easier for Spark plugins to define errors. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #37969 from cloud-fan/error. Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../org/apache/spark/ErrorClassesJSONReader.scala | 118 +++++++++++++++++++++ .../org/apache/spark/SparkThrowableHelper.scala | 105 ++---------------- .../org/apache/spark/SparkThrowableSuite.scala | 54 +++++++--- 3 files changed, 163 insertions(+), 114 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala new file mode 100644 index 00000000000..8d4ae3a877d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.net.URL + +import scala.collection.JavaConverters._ +import scala.collection.immutable.SortedMap + +import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.core.`type`.TypeReference +import com.fasterxml.jackson.databind.json.JsonMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.apache.commons.text.StringSubstitutor + +import org.apache.spark.annotation.DeveloperApi + +/** + * A reader to load error information from one or more JSON files. Note that, if one error appears + * in more than one JSON files, the latter wins. Please read core/src/main/resources/error/README.md + * for more details. 
+ */ +@DeveloperApi +class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) { + assert(jsonFileURLs.nonEmpty) + + private def readAsMap(url: URL): SortedMap[String, ErrorInfo] = { + val mapper: JsonMapper = JsonMapper.builder() + .addModule(DefaultScalaModule) + .build() + mapper.readValue(url, new TypeReference[SortedMap[String, ErrorInfo]]() {}) + } + + // Exposed for testing + private[spark] val errorInfoMap = jsonFileURLs.map(readAsMap).reduce(_ ++ _) + + def getErrorMessage(errorClass: String, messageParameters: Map[String, String]): String = { + val messageTemplate = getMessageTemplate(errorClass) + val sub = new StringSubstitutor(messageParameters.asJava) + sub.setEnableUndefinedVariableException(true) + try { + sub.replace(messageTemplate.replaceAll("<([a-zA-Z0-9_-]+)>", "\\$\\{$1\\}")) + } catch { + case _: IllegalArgumentException => throw SparkException.internalError( + s"Undefined error message parameter for error class: '$errorClass'. " + + s"Parameters: $messageParameters") + } + } + + def getMessageTemplate(errorClass: String): String = { + val errorClasses = errorClass.split("\\.") + assert(errorClasses.length == 1 || errorClasses.length == 2) + + val mainErrorClass = errorClasses.head + val subErrorClass = errorClasses.tail.headOption + val errorInfo = errorInfoMap.getOrElse( + mainErrorClass, + throw SparkException.internalError(s"Cannot find main error class '$errorClass'")) + assert(errorInfo.subClass.isDefined == subErrorClass.isDefined) + + if (subErrorClass.isEmpty) { + errorInfo.messageFormat + } else { + val errorSubInfo = errorInfo.subClass.get.getOrElse( + subErrorClass.get, + throw SparkException.internalError(s"Cannot find sub error class '$errorClass'")) + errorInfo.messageFormat + " " + errorSubInfo.messageFormat + } + } + + def getSqlState(errorClass: String): String = { + Option(errorClass).flatMap(errorInfoMap.get).flatMap(_.sqlState).orNull + } +} + +/** + * Information associated with an error class. 
+ * + * @param sqlState SQLSTATE associated with this class. + * @param subClass SubClass associated with this class. + * @param message C-style message format compatible with printf. + * The error message is constructed by concatenating the lines with newlines. + */ +private case class ErrorInfo( + message: Seq[String], + subClass: Option[Map[String, ErrorSubInfo]], + sqlState: Option[String]) { + // For compatibility with multi-line error messages + @JsonIgnore + val messageFormat: String = message.mkString("\n") +} + +/** + * Information associated with an error subclass. + * + * @param message C-style message format compatible with printf. + * The error message is constructed by concatenating the lines with newlines. + */ +private case class ErrorSubInfo(message: Seq[String]) { + // For compatibility with multi-line error messages + @JsonIgnore + val messageFormat: String = message.mkString("\n") +} diff --git a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 1a50c1acd8d..d503f400d00 100644 --- a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -17,50 +17,12 @@ package org.apache.spark -import java.net.URL - import scala.collection.JavaConverters._ -import scala.collection.immutable.SortedMap - -import com.fasterxml.jackson.annotation.JsonIgnore -import com.fasterxml.jackson.core.`type`.TypeReference -import com.fasterxml.jackson.databind.json.JsonMapper -import com.fasterxml.jackson.module.scala.DefaultScalaModule -import org.apache.commons.text.StringSubstitutor import org.apache.spark.util.JsonProtocol.toJsonString import org.apache.spark.util.Utils -/** - * Information associated with an error subclass. - * - * @param message C-style message format compatible with printf. - * The error message is constructed by concatenating the lines with newlines. 
- */ -private[spark] case class ErrorSubInfo(message: Seq[String]) { - // For compatibility with multi-line error messages - @JsonIgnore - val messageFormat: String = message.mkString("\n") -} - -/** - * Information associated with an error class. - * - * @param sqlState SQLSTATE associated with this class. - * @param subClass SubClass associated with this class. - * @param message C-style message format compatible with printf. - * The error message is constructed by concatenating the lines with newlines. - */ -private[spark] case class ErrorInfo( - message: Seq[String], - subClass: Option[Map[String, ErrorSubInfo]], - sqlState: Option[String]) { - // For compatibility with multi-line error messages - @JsonIgnore - val messageFormat: String = message.mkString("\n") -} - -object ErrorMessageFormat extends Enumeration { +private[spark] object ErrorMessageFormat extends Enumeration { val PRETTY, MINIMAL, STANDARD = Value } @@ -69,37 +31,8 @@ object ErrorMessageFormat extends Enumeration { * construct error messages. 
*/ private[spark] object SparkThrowableHelper { - val errorClassesUrl: URL = - Utils.getSparkClassLoader.getResource("error/error-classes.json") - val errorClassToInfoMap: SortedMap[String, ErrorInfo] = { - val mapper: JsonMapper = JsonMapper.builder() - .addModule(DefaultScalaModule) - .build() - mapper.readValue(errorClassesUrl, new TypeReference[SortedMap[String, ErrorInfo]]() {}) - } - - def getParameterNames(errorClass: String, errorSubCLass: String): Array[String] = { - val errorInfo = errorClassToInfoMap.getOrElse(errorClass, - throw new IllegalArgumentException(s"Cannot find error class '$errorClass'")) - if (errorInfo.subClass.isEmpty && errorSubCLass != null) { - throw new IllegalArgumentException(s"'$errorClass' has no subclass") - } - if (errorInfo.subClass.isDefined && errorSubCLass == null) { - throw new IllegalArgumentException(s"'$errorClass' requires subclass") - } - var parameterizedMessage = errorInfo.messageFormat - if (errorInfo.subClass.isDefined) { - val givenSubClass = errorSubCLass - val errorSubInfo = errorInfo.subClass.get.getOrElse(givenSubClass, - throw new IllegalArgumentException(s"Cannot find sub error class '$givenSubClass'")) - parameterizedMessage = parameterizedMessage + errorSubInfo.messageFormat - } - val pattern = "<[a-zA-Z0-9_-]+>".r - val matches = pattern.findAllIn(parameterizedMessage) - val parameterSeq = matches.toArray - val parameterNames = parameterSeq.map(p => p.stripPrefix("<").stripSuffix(">")) - parameterNames - } + val errorReader = new ErrorClassesJsonReader( + Seq(Utils.getSparkClassLoader.getResource("error/error-classes.json"))) def getMessage( errorClass: String, @@ -120,36 +53,15 @@ private[spark] object SparkThrowableHelper { errorSubClass: String, messageParameters: Map[String, String], context: String): String = { - val errorInfo = errorClassToInfoMap.getOrElse(errorClass, - throw new IllegalArgumentException(s"Cannot find error class '$errorClass'")) - val (displayClass, displayFormat) = if 
(errorInfo.subClass.isEmpty) { - (errorClass, errorInfo.messageFormat) - } else { - val subClasses = errorInfo.subClass.get - if (errorSubClass == null) { - throw new IllegalArgumentException(s"Subclass required for error class '$errorClass'") - } - val errorSubInfo = subClasses.getOrElse(errorSubClass, - throw new IllegalArgumentException(s"Cannot find sub error class '$errorSubClass'")) - (errorClass + "." + errorSubClass, - errorInfo.messageFormat + " " + errorSubInfo.messageFormat) - } - val sub = new StringSubstitutor(messageParameters.asJava) - sub.setEnableUndefinedVariableException(true) - val displayMessage = try { - sub.replace(displayFormat.replaceAll("<([a-zA-Z0-9_-]+)>", "\\$\\{$1\\}")) - } catch { - case _: IllegalArgumentException => throw SparkException.internalError( - s"Undefined an error message parameter: $messageParameters") - } + val displayClass = errorClass + Option(errorSubClass).map("." + _).getOrElse("") + val displayMessage = errorReader.getErrorMessage(displayClass, messageParameters) val displayQueryContext = (if (context.isEmpty) "" else "\n") + context val prefix = if (displayClass.startsWith("_LEGACY_ERROR_TEMP_")) "" else s"[$displayClass] " - s"$prefix$displayMessage$displayQueryContext" } def getSqlState(errorClass: String): String = { - Option(errorClass).flatMap(errorClassToInfoMap.get).flatMap(_.sqlState).orNull + errorReader.getSqlState(errorClass) } def isInternalError(errorClass: String): Boolean = { @@ -179,9 +91,8 @@ private[spark] object SparkThrowableHelper { val errorSubClass = e.getErrorSubClass if (errorSubClass != null) g.writeStringField("errorSubClass", errorSubClass) if (format == STANDARD) { - val errorInfo = errorClassToInfoMap.getOrElse(errorClass, - throw SparkException.internalError(s"Cannot find the error class '$errorClass'")) - g.writeStringField("message", errorInfo.messageFormat) + val finalClass = errorClass + Option(errorSubClass).map("." 
+ _).getOrElse("") + g.writeStringField("message", errorReader.getMessageTemplate(finalClass)) } val sqlState = e.getSqlState if (sqlState != null) g.writeStringField("sqlState", sqlState) diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 2012133a74d..266683b1eca 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -44,8 +44,10 @@ class SparkThrowableSuite extends SparkFunSuite { "core/testOnly *SparkThrowableSuite -- -t \"Error classes are correctly formatted\"" }}} */ - private val errorClassDir = getWorkspaceFilePath( - "core", "src", "main", "resources", "error").toFile + private val errorJsonFilePath = getWorkspaceFilePath( + "core", "src", "main", "resources", "error", "error-classes.json") + + private val errorReader = new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL)) override def beforeAll(): Unit = { super.beforeAll() @@ -68,11 +70,11 @@ class SparkThrowableSuite extends SparkFunSuite { .addModule(DefaultScalaModule) .enable(STRICT_DUPLICATE_DETECTION) .build() - mapper.readValue(errorClassesUrl, new TypeReference[Map[String, ErrorInfo]]() {}) + mapper.readValue(errorJsonFilePath.toUri.toURL, new TypeReference[Map[String, ErrorInfo]]() {}) } test("Error classes are correctly formatted") { - val errorClassFileContents = IOUtils.toString(errorClassesUrl.openStream()) + val errorClassFileContents = IOUtils.toString(errorJsonFilePath.toUri.toURL.openStream()) val mapper = JsonMapper.builder() .addModule(DefaultScalaModule) .enable(SerializationFeature.INDENT_OUTPUT) @@ -82,11 +84,11 @@ class SparkThrowableSuite extends SparkFunSuite { val rewrittenString = mapper.configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true) .setSerializationInclusion(Include.NON_ABSENT) .writer(prettyPrinter) - .writeValueAsString(errorClassToInfoMap) + 
.writeValueAsString(errorReader.errorInfoMap) if (regenerateGoldenFiles) { if (rewrittenString.trim != errorClassFileContents.trim) { - val errorClassesFile = new File(errorClassDir, new File(errorClassesUrl.getPath).getName) + val errorClassesFile = errorJsonFilePath.toFile logInfo(s"Regenerating error class file $errorClassesFile") Files.delete(errorClassesFile.toPath) FileUtils.writeStringToFile(errorClassesFile, rewrittenString, StandardCharsets.UTF_8) @@ -97,7 +99,7 @@ class SparkThrowableSuite extends SparkFunSuite { } test("SQLSTATE invariants") { - val sqlStates = errorClassToInfoMap.values.toSeq.flatMap(_.sqlState) + val sqlStates = errorReader.errorInfoMap.values.toSeq.flatMap(_.sqlState) val errorClassReadMe = Utils.getSparkClassLoader.getResource("error/README.md") val errorClassReadMeContents = IOUtils.toString(errorClassReadMe.openStream()) val sqlStateTableRegex = @@ -112,7 +114,7 @@ class SparkThrowableSuite extends SparkFunSuite { } test("Message invariants") { - val messageSeq = errorClassToInfoMap.values.toSeq.flatMap { i => + val messageSeq = errorReader.errorInfoMap.values.toSeq.flatMap { i => Seq(i.message) ++ i.subClass.getOrElse(Map.empty).values.toSeq.map(_.message) } messageSeq.foreach { message => @@ -124,7 +126,7 @@ class SparkThrowableSuite extends SparkFunSuite { } test("Message format invariants") { - val messageFormats = errorClassToInfoMap + val messageFormats = errorReader.errorInfoMap .filterKeys(!_.startsWith("_LEGACY_ERROR_TEMP_")) .values.toSeq.flatMap { i => Seq(i.messageFormat) } checkCondition(messageFormats, s => s != null) @@ -137,22 +139,22 @@ class SparkThrowableSuite extends SparkFunSuite { .addModule(DefaultScalaModule) .enable(SerializationFeature.INDENT_OUTPUT) .build() - mapper.writeValue(tmpFile, errorClassToInfoMap) + mapper.writeValue(tmpFile, errorReader.errorInfoMap) val rereadErrorClassToInfoMap = mapper.readValue( tmpFile, new TypeReference[Map[String, ErrorInfo]]() {}) - assert(rereadErrorClassToInfoMap == 
errorClassToInfoMap) + assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap) } test("Check if error class is missing") { - val ex1 = intercept[IllegalArgumentException] { + val ex1 = intercept[SparkException] { getMessage("", null, Map.empty[String, String]) } - assert(ex1.getMessage == "Cannot find error class ''") + assert(ex1.getMessage.contains("Cannot find main error class")) - val ex2 = intercept[IllegalArgumentException] { + val ex2 = intercept[SparkException] { getMessage("LOREM_IPSUM", null, Map.empty[String, String]) } - assert(ex2.getMessage == "Cannot find error class 'LOREM_IPSUM'") + assert(ex2.getMessage.contains("Cannot find main error class")) } test("Check if message parameters match message format") { @@ -161,7 +163,7 @@ class SparkThrowableSuite extends SparkFunSuite { getMessage("UNRESOLVED_COLUMN", "WITHOUT_SUGGESTION", Map.empty[String, String]) } assert(e.getErrorClass === "INTERNAL_ERROR") - assert(e.getMessageParameters().get("message").contains("Undefined an error message parameter")) + assert(e.getMessageParameters().get("message").contains("Undefined error message parameter")) // Does not fail with too many args (expects 0 args) assert(getMessage("DIVIDE_BY_ZERO", null, Map("config" -> "foo", "a" -> "bar")) == @@ -302,7 +304,7 @@ class SparkThrowableSuite extends SparkFunSuite { """{ | "errorClass" : "UNSUPPORTED_SAVE_MODE", | "errorSubClass" : "EXISTENT_PATH", - | "message" : "The save mode <saveMode> is not supported for:", + | "message" : "The save mode <saveMode> is not supported for: an existent path.", | "messageParameters" : { | "saveMode" : "UNSUPPORTED_MODE" | } @@ -321,4 +323,22 @@ class SparkThrowableSuite extends SparkFunSuite { | } |}""".stripMargin) } + + test("overwrite error classes") { + withTempDir { dir => + val json = new File(dir, "errors.json") + FileUtils.writeStringToFile(json, + """ + |{ + | "DIVIDE_BY_ZERO" : { + | "message" : [ + | "abc" + | ] + | } + |} + |""".stripMargin) + val reader = new 
ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURL)) + assert(reader.getErrorMessage("DIVIDE_BY_ZERO", Map.empty) == "abc") + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org