This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2bf26dfbd70 [SPARK-40530][SQL] Add error-related developer APIs
2bf26dfbd70 is described below
commit 2bf26dfbd708cbbbe8a51dda6972296055661e07
Author: Wenchen Fan <[email protected]>
AuthorDate: Mon Sep 26 15:35:38 2022 +0800
[SPARK-40530][SQL] Add error-related developer APIs
### What changes were proposed in this pull request?
Third-party Spark plugins may define their own errors using the same
framework as Spark: put error definitions in JSON files. This PR moves some
error-related code to a new file and marks it as developer API, so that
others can reuse it instead of writing their own JSON reader.
### Why are the changes needed?
Make it easier for Spark plugins to define errors.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
Closes #37969 from cloud-fan/error.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/ErrorClassesJSONReader.scala | 118 +++++++++++++++++++++
.../org/apache/spark/SparkThrowableHelper.scala | 105 ++----------------
.../org/apache/spark/SparkThrowableSuite.scala | 54 +++++++---
3 files changed, 163 insertions(+), 114 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
b/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
new file mode 100644
index 00000000000..8d4ae3a877d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.URL
+
+import scala.collection.JavaConverters._
+import scala.collection.immutable.SortedMap
+
+import com.fasterxml.jackson.annotation.JsonIgnore
+import com.fasterxml.jackson.core.`type`.TypeReference
+import com.fasterxml.jackson.databind.json.JsonMapper
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
+import org.apache.commons.text.StringSubstitutor
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * A reader to load error information from one or more JSON files. Note that,
 if one error appears
+ * in more than one JSON file, the last one wins. Please read
 core/src/main/resources/error/README.md
+ * for more details.
+ */
+@DeveloperApi
+class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
+ assert(jsonFileURLs.nonEmpty)
+
+ private def readAsMap(url: URL): SortedMap[String, ErrorInfo] = {
+ val mapper: JsonMapper = JsonMapper.builder()
+ .addModule(DefaultScalaModule)
+ .build()
+ mapper.readValue(url, new TypeReference[SortedMap[String, ErrorInfo]]() {})
+ }
+
+ // Exposed for testing
+ private[spark] val errorInfoMap = jsonFileURLs.map(readAsMap).reduce(_ ++ _)
+
+ def getErrorMessage(errorClass: String, messageParameters: Map[String,
String]): String = {
+ val messageTemplate = getMessageTemplate(errorClass)
+ val sub = new StringSubstitutor(messageParameters.asJava)
+ sub.setEnableUndefinedVariableException(true)
+ try {
+ sub.replace(messageTemplate.replaceAll("<([a-zA-Z0-9_-]+)>",
"\\$\\{$1\\}"))
+ } catch {
+ case _: IllegalArgumentException => throw SparkException.internalError(
+ s"Undefined error message parameter for error class: '$errorClass'. " +
+ s"Parameters: $messageParameters")
+ }
+ }
+
+ def getMessageTemplate(errorClass: String): String = {
+ val errorClasses = errorClass.split("\\.")
+ assert(errorClasses.length == 1 || errorClasses.length == 2)
+
+ val mainErrorClass = errorClasses.head
+ val subErrorClass = errorClasses.tail.headOption
+ val errorInfo = errorInfoMap.getOrElse(
+ mainErrorClass,
+ throw SparkException.internalError(s"Cannot find main error class
'$errorClass'"))
+ assert(errorInfo.subClass.isDefined == subErrorClass.isDefined)
+
+ if (subErrorClass.isEmpty) {
+ errorInfo.messageFormat
+ } else {
+ val errorSubInfo = errorInfo.subClass.get.getOrElse(
+ subErrorClass.get,
+ throw SparkException.internalError(s"Cannot find sub error class
'$errorClass'"))
+ errorInfo.messageFormat + " " + errorSubInfo.messageFormat
+ }
+ }
+
+ def getSqlState(errorClass: String): String = {
+ Option(errorClass).flatMap(errorInfoMap.get).flatMap(_.sqlState).orNull
+ }
+}
+
+/**
+ * Information associated with an error class.
+ *
+ * @param sqlState SQLSTATE associated with this class.
+ * @param subClass SubClass associated with this class.
+ * @param message Message format containing &lt;param&gt; placeholders.
+ *                The error message is constructed by concatenating the lines
 with newlines.
+ */
+private case class ErrorInfo(
+ message: Seq[String],
+ subClass: Option[Map[String, ErrorSubInfo]],
+ sqlState: Option[String]) {
+ // For compatibility with multi-line error messages
+ @JsonIgnore
+ val messageFormat: String = message.mkString("\n")
+}
+
+/**
+ * Information associated with an error subclass.
+ *
+ * @param message Message format containing &lt;param&gt; placeholders.
+ *                The error message is constructed by concatenating the lines
 with newlines.
+ */
+private case class ErrorSubInfo(message: Seq[String]) {
+ // For compatibility with multi-line error messages
+ @JsonIgnore
+ val messageFormat: String = message.mkString("\n")
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
index 1a50c1acd8d..d503f400d00 100644
--- a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
+++ b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
@@ -17,50 +17,12 @@
package org.apache.spark
-import java.net.URL
-
import scala.collection.JavaConverters._
-import scala.collection.immutable.SortedMap
-
-import com.fasterxml.jackson.annotation.JsonIgnore
-import com.fasterxml.jackson.core.`type`.TypeReference
-import com.fasterxml.jackson.databind.json.JsonMapper
-import com.fasterxml.jackson.module.scala.DefaultScalaModule
-import org.apache.commons.text.StringSubstitutor
import org.apache.spark.util.JsonProtocol.toJsonString
import org.apache.spark.util.Utils
-/**
- * Information associated with an error subclass.
- *
- * @param message C-style message format compatible with printf.
- * The error message is constructed by concatenating the lines
with newlines.
- */
-private[spark] case class ErrorSubInfo(message: Seq[String]) {
- // For compatibility with multi-line error messages
- @JsonIgnore
- val messageFormat: String = message.mkString("\n")
-}
-
-/**
- * Information associated with an error class.
- *
- * @param sqlState SQLSTATE associated with this class.
- * @param subClass SubClass associated with this class.
- * @param message C-style message format compatible with printf.
- * The error message is constructed by concatenating the lines
with newlines.
- */
-private[spark] case class ErrorInfo(
- message: Seq[String],
- subClass: Option[Map[String, ErrorSubInfo]],
- sqlState: Option[String]) {
- // For compatibility with multi-line error messages
- @JsonIgnore
- val messageFormat: String = message.mkString("\n")
-}
-
-object ErrorMessageFormat extends Enumeration {
+private[spark] object ErrorMessageFormat extends Enumeration {
val PRETTY, MINIMAL, STANDARD = Value
}
@@ -69,37 +31,8 @@ object ErrorMessageFormat extends Enumeration {
* construct error messages.
*/
private[spark] object SparkThrowableHelper {
- val errorClassesUrl: URL =
- Utils.getSparkClassLoader.getResource("error/error-classes.json")
- val errorClassToInfoMap: SortedMap[String, ErrorInfo] = {
- val mapper: JsonMapper = JsonMapper.builder()
- .addModule(DefaultScalaModule)
- .build()
- mapper.readValue(errorClassesUrl, new TypeReference[SortedMap[String,
ErrorInfo]]() {})
- }
-
- def getParameterNames(errorClass: String, errorSubCLass: String):
Array[String] = {
- val errorInfo = errorClassToInfoMap.getOrElse(errorClass,
- throw new IllegalArgumentException(s"Cannot find error class
'$errorClass'"))
- if (errorInfo.subClass.isEmpty && errorSubCLass != null) {
- throw new IllegalArgumentException(s"'$errorClass' has no subclass")
- }
- if (errorInfo.subClass.isDefined && errorSubCLass == null) {
- throw new IllegalArgumentException(s"'$errorClass' requires subclass")
- }
- var parameterizedMessage = errorInfo.messageFormat
- if (errorInfo.subClass.isDefined) {
- val givenSubClass = errorSubCLass
- val errorSubInfo = errorInfo.subClass.get.getOrElse(givenSubClass,
- throw new IllegalArgumentException(s"Cannot find sub error class
'$givenSubClass'"))
- parameterizedMessage = parameterizedMessage + errorSubInfo.messageFormat
- }
- val pattern = "<[a-zA-Z0-9_-]+>".r
- val matches = pattern.findAllIn(parameterizedMessage)
- val parameterSeq = matches.toArray
- val parameterNames = parameterSeq.map(p =>
p.stripPrefix("<").stripSuffix(">"))
- parameterNames
- }
+ val errorReader = new ErrorClassesJsonReader(
+ Seq(Utils.getSparkClassLoader.getResource("error/error-classes.json")))
def getMessage(
errorClass: String,
@@ -120,36 +53,15 @@ private[spark] object SparkThrowableHelper {
errorSubClass: String,
messageParameters: Map[String, String],
context: String): String = {
- val errorInfo = errorClassToInfoMap.getOrElse(errorClass,
- throw new IllegalArgumentException(s"Cannot find error class
'$errorClass'"))
- val (displayClass, displayFormat) = if (errorInfo.subClass.isEmpty) {
- (errorClass, errorInfo.messageFormat)
- } else {
- val subClasses = errorInfo.subClass.get
- if (errorSubClass == null) {
- throw new IllegalArgumentException(s"Subclass required for error class
'$errorClass'")
- }
- val errorSubInfo = subClasses.getOrElse(errorSubClass,
- throw new IllegalArgumentException(s"Cannot find sub error class
'$errorSubClass'"))
- (errorClass + "." + errorSubClass,
- errorInfo.messageFormat + " " + errorSubInfo.messageFormat)
- }
- val sub = new StringSubstitutor(messageParameters.asJava)
- sub.setEnableUndefinedVariableException(true)
- val displayMessage = try {
- sub.replace(displayFormat.replaceAll("<([a-zA-Z0-9_-]+)>",
"\\$\\{$1\\}"))
- } catch {
- case _: IllegalArgumentException => throw SparkException.internalError(
- s"Undefined an error message parameter: $messageParameters")
- }
+ val displayClass = errorClass + Option(errorSubClass).map("." +
_).getOrElse("")
+ val displayMessage = errorReader.getErrorMessage(displayClass,
messageParameters)
val displayQueryContext = (if (context.isEmpty) "" else "\n") + context
val prefix = if (displayClass.startsWith("_LEGACY_ERROR_TEMP_")) "" else
s"[$displayClass] "
-
s"$prefix$displayMessage$displayQueryContext"
}
def getSqlState(errorClass: String): String = {
-
Option(errorClass).flatMap(errorClassToInfoMap.get).flatMap(_.sqlState).orNull
+ errorReader.getSqlState(errorClass)
}
def isInternalError(errorClass: String): Boolean = {
@@ -179,9 +91,8 @@ private[spark] object SparkThrowableHelper {
val errorSubClass = e.getErrorSubClass
if (errorSubClass != null) g.writeStringField("errorSubClass",
errorSubClass)
if (format == STANDARD) {
- val errorInfo = errorClassToInfoMap.getOrElse(errorClass,
- throw SparkException.internalError(s"Cannot find the error class
'$errorClass'"))
- g.writeStringField("message", errorInfo.messageFormat)
+ val finalClass = errorClass + Option(errorSubClass).map("." +
_).getOrElse("")
+ g.writeStringField("message",
errorReader.getMessageTemplate(finalClass))
}
val sqlState = e.getSqlState
if (sqlState != null) g.writeStringField("sqlState", sqlState)
diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
index 2012133a74d..266683b1eca 100644
--- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
@@ -44,8 +44,10 @@ class SparkThrowableSuite extends SparkFunSuite {
"core/testOnly *SparkThrowableSuite -- -t \"Error classes are
correctly formatted\""
}}}
*/
- private val errorClassDir = getWorkspaceFilePath(
- "core", "src", "main", "resources", "error").toFile
+ private val errorJsonFilePath = getWorkspaceFilePath(
+ "core", "src", "main", "resources", "error", "error-classes.json")
+
+ private val errorReader = new
ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL))
override def beforeAll(): Unit = {
super.beforeAll()
@@ -68,11 +70,11 @@ class SparkThrowableSuite extends SparkFunSuite {
.addModule(DefaultScalaModule)
.enable(STRICT_DUPLICATE_DETECTION)
.build()
- mapper.readValue(errorClassesUrl, new TypeReference[Map[String,
ErrorInfo]]() {})
+ mapper.readValue(errorJsonFilePath.toUri.toURL, new
TypeReference[Map[String, ErrorInfo]]() {})
}
test("Error classes are correctly formatted") {
- val errorClassFileContents = IOUtils.toString(errorClassesUrl.openStream())
+ val errorClassFileContents =
IOUtils.toString(errorJsonFilePath.toUri.toURL.openStream())
val mapper = JsonMapper.builder()
.addModule(DefaultScalaModule)
.enable(SerializationFeature.INDENT_OUTPUT)
@@ -82,11 +84,11 @@ class SparkThrowableSuite extends SparkFunSuite {
val rewrittenString =
mapper.configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true)
.setSerializationInclusion(Include.NON_ABSENT)
.writer(prettyPrinter)
- .writeValueAsString(errorClassToInfoMap)
+ .writeValueAsString(errorReader.errorInfoMap)
if (regenerateGoldenFiles) {
if (rewrittenString.trim != errorClassFileContents.trim) {
- val errorClassesFile = new File(errorClassDir, new
File(errorClassesUrl.getPath).getName)
+ val errorClassesFile = errorJsonFilePath.toFile
logInfo(s"Regenerating error class file $errorClassesFile")
Files.delete(errorClassesFile.toPath)
FileUtils.writeStringToFile(errorClassesFile, rewrittenString,
StandardCharsets.UTF_8)
@@ -97,7 +99,7 @@ class SparkThrowableSuite extends SparkFunSuite {
}
test("SQLSTATE invariants") {
- val sqlStates = errorClassToInfoMap.values.toSeq.flatMap(_.sqlState)
+ val sqlStates = errorReader.errorInfoMap.values.toSeq.flatMap(_.sqlState)
val errorClassReadMe =
Utils.getSparkClassLoader.getResource("error/README.md")
val errorClassReadMeContents =
IOUtils.toString(errorClassReadMe.openStream())
val sqlStateTableRegex =
@@ -112,7 +114,7 @@ class SparkThrowableSuite extends SparkFunSuite {
}
test("Message invariants") {
- val messageSeq = errorClassToInfoMap.values.toSeq.flatMap { i =>
+ val messageSeq = errorReader.errorInfoMap.values.toSeq.flatMap { i =>
Seq(i.message) ++
i.subClass.getOrElse(Map.empty).values.toSeq.map(_.message)
}
messageSeq.foreach { message =>
@@ -124,7 +126,7 @@ class SparkThrowableSuite extends SparkFunSuite {
}
test("Message format invariants") {
- val messageFormats = errorClassToInfoMap
+ val messageFormats = errorReader.errorInfoMap
.filterKeys(!_.startsWith("_LEGACY_ERROR_TEMP_"))
.values.toSeq.flatMap { i => Seq(i.messageFormat) }
checkCondition(messageFormats, s => s != null)
@@ -137,22 +139,22 @@ class SparkThrowableSuite extends SparkFunSuite {
.addModule(DefaultScalaModule)
.enable(SerializationFeature.INDENT_OUTPUT)
.build()
- mapper.writeValue(tmpFile, errorClassToInfoMap)
+ mapper.writeValue(tmpFile, errorReader.errorInfoMap)
val rereadErrorClassToInfoMap = mapper.readValue(
tmpFile, new TypeReference[Map[String, ErrorInfo]]() {})
- assert(rereadErrorClassToInfoMap == errorClassToInfoMap)
+ assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap)
}
test("Check if error class is missing") {
- val ex1 = intercept[IllegalArgumentException] {
+ val ex1 = intercept[SparkException] {
getMessage("", null, Map.empty[String, String])
}
- assert(ex1.getMessage == "Cannot find error class ''")
+ assert(ex1.getMessage.contains("Cannot find main error class"))
- val ex2 = intercept[IllegalArgumentException] {
+ val ex2 = intercept[SparkException] {
getMessage("LOREM_IPSUM", null, Map.empty[String, String])
}
- assert(ex2.getMessage == "Cannot find error class 'LOREM_IPSUM'")
+ assert(ex2.getMessage.contains("Cannot find main error class"))
}
test("Check if message parameters match message format") {
@@ -161,7 +163,7 @@ class SparkThrowableSuite extends SparkFunSuite {
getMessage("UNRESOLVED_COLUMN", "WITHOUT_SUGGESTION", Map.empty[String,
String])
}
assert(e.getErrorClass === "INTERNAL_ERROR")
- assert(e.getMessageParameters().get("message").contains("Undefined an
error message parameter"))
+ assert(e.getMessageParameters().get("message").contains("Undefined error
message parameter"))
// Does not fail with too many args (expects 0 args)
assert(getMessage("DIVIDE_BY_ZERO", null, Map("config" -> "foo", "a" ->
"bar")) ==
@@ -302,7 +304,7 @@ class SparkThrowableSuite extends SparkFunSuite {
"""{
| "errorClass" : "UNSUPPORTED_SAVE_MODE",
| "errorSubClass" : "EXISTENT_PATH",
- | "message" : "The save mode <saveMode> is not supported for:",
+ | "message" : "The save mode <saveMode> is not supported for: an
existent path.",
| "messageParameters" : {
| "saveMode" : "UNSUPPORTED_MODE"
| }
@@ -321,4 +323,22 @@ class SparkThrowableSuite extends SparkFunSuite {
| }
|}""".stripMargin)
}
+
+ test("overwrite error classes") {
+ withTempDir { dir =>
+ val json = new File(dir, "errors.json")
+ FileUtils.writeStringToFile(json,
+ """
+ |{
+ | "DIVIDE_BY_ZERO" : {
+ | "message" : [
+ | "abc"
+ | ]
+ | }
+ |}
+ |""".stripMargin)
+ val reader = new
ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURL))
+ assert(reader.getErrorMessage("DIVIDE_BY_ZERO", Map.empty) == "abc")
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]