This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bf26dfbd70 [SPARK-40530][SQL] Add error-related developer APIs
2bf26dfbd70 is described below

commit 2bf26dfbd708cbbbe8a51dda6972296055661e07
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Mon Sep 26 15:35:38 2022 +0800

    [SPARK-40530][SQL] Add error-related developer APIs
    
    ### What changes were proposed in this pull request?
    
    Third-party Spark plugins may define their own errors using the same
    framework as Spark: put error definitions in JSON files. This PR moves some
    error-related code to a new file and marks it as a developer API, so that
    others can reuse it instead of writing their own JSON reader.
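
    As a sketch of the intended use (not code from this PR): a plugin can bundle
    its own error JSON file and load it with the new reader. The resource name
    `error/my-plugin-errors.json`, the error class `MY_PLUGIN_ERROR`, and the
    `<item>` parameter below are hypothetical; only `ErrorClassesJsonReader` and
    its methods come from this change.

    ```scala
    // Hypothetical plugin-side usage of the new developer API.
    // error/my-plugin-errors.json (same layout as Spark's error-classes.json):
    // {
    //   "MY_PLUGIN_ERROR" : {
    //     "message" : [ "Something went wrong with <item>." ],
    //     "sqlState" : "42000"
    //   }
    // }
    import java.net.URL

    import org.apache.spark.ErrorClassesJsonReader

    object MyPluginErrors {
      // Load the plugin's own error definitions from its classpath.
      private val pluginErrorsUrl: URL =
        getClass.getClassLoader.getResource("error/my-plugin-errors.json")

      val reader = new ErrorClassesJsonReader(Seq(pluginErrorsUrl))
    }

    // MyPluginErrors.reader.getErrorMessage("MY_PLUGIN_ERROR", Map("item" -> "the config"))
    //   returns "Something went wrong with the config."
    // MyPluginErrors.reader.getSqlState("MY_PLUGIN_ERROR") returns "42000"
    ```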
    
    ### Why are the changes needed?
    
    Make it easier for Spark plugins to define errors.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    existing tests
    
    Closes #37969 from cloud-fan/error.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/ErrorClassesJSONReader.scala  | 118 +++++++++++++++++++++
 .../org/apache/spark/SparkThrowableHelper.scala    | 105 ++----------------
 .../org/apache/spark/SparkThrowableSuite.scala     |  54 +++++++---
 3 files changed, 163 insertions(+), 114 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
new file mode 100644
index 00000000000..8d4ae3a877d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.URL
+
+import scala.collection.JavaConverters._
+import scala.collection.immutable.SortedMap
+
+import com.fasterxml.jackson.annotation.JsonIgnore
+import com.fasterxml.jackson.core.`type`.TypeReference
+import com.fasterxml.jackson.databind.json.JsonMapper
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
+import org.apache.commons.text.StringSubstitutor
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * A reader to load error information from one or more JSON files. Note that, if one error appears
+ * in more than one JSON files, the latter wins. Please read core/src/main/resources/error/README.md
+ * for more details.
+ */
+@DeveloperApi
+class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
+  assert(jsonFileURLs.nonEmpty)
+
+  private def readAsMap(url: URL): SortedMap[String, ErrorInfo] = {
+    val mapper: JsonMapper = JsonMapper.builder()
+      .addModule(DefaultScalaModule)
+      .build()
+    mapper.readValue(url, new TypeReference[SortedMap[String, ErrorInfo]]() {})
+  }
+
+  // Exposed for testing
+  private[spark] val errorInfoMap = jsonFileURLs.map(readAsMap).reduce(_ ++ _)
+
+  def getErrorMessage(errorClass: String, messageParameters: Map[String, String]): String = {
+    val messageTemplate = getMessageTemplate(errorClass)
+    val sub = new StringSubstitutor(messageParameters.asJava)
+    sub.setEnableUndefinedVariableException(true)
+    try {
+      sub.replace(messageTemplate.replaceAll("<([a-zA-Z0-9_-]+)>", "\\$\\{$1\\}"))
+    } catch {
+      case _: IllegalArgumentException => throw SparkException.internalError(
+        s"Undefined error message parameter for error class: '$errorClass'. " +
+          s"Parameters: $messageParameters")
+    }
+  }
+
+  def getMessageTemplate(errorClass: String): String = {
+    val errorClasses = errorClass.split("\\.")
+    assert(errorClasses.length == 1 || errorClasses.length == 2)
+
+    val mainErrorClass = errorClasses.head
+    val subErrorClass = errorClasses.tail.headOption
+    val errorInfo = errorInfoMap.getOrElse(
+      mainErrorClass,
+      throw SparkException.internalError(s"Cannot find main error class '$errorClass'"))
+    assert(errorInfo.subClass.isDefined == subErrorClass.isDefined)
+
+    if (subErrorClass.isEmpty) {
+      errorInfo.messageFormat
+    } else {
+      val errorSubInfo = errorInfo.subClass.get.getOrElse(
+        subErrorClass.get,
+        throw SparkException.internalError(s"Cannot find sub error class '$errorClass'"))
+      errorInfo.messageFormat + " " + errorSubInfo.messageFormat
+    }
+  }
+
+  def getSqlState(errorClass: String): String = {
+    Option(errorClass).flatMap(errorInfoMap.get).flatMap(_.sqlState).orNull
+  }
+}
+
+/**
+ * Information associated with an error class.
+ *
+ * @param sqlState SQLSTATE associated with this class.
+ * @param subClass SubClass associated with this class.
+ * @param message C-style message format compatible with printf.
+ *                The error message is constructed by concatenating the lines with newlines.
+ */
+private case class ErrorInfo(
+    message: Seq[String],
+    subClass: Option[Map[String, ErrorSubInfo]],
+    sqlState: Option[String]) {
+  // For compatibility with multi-line error messages
+  @JsonIgnore
+  val messageFormat: String = message.mkString("\n")
+}
+
+/**
+ * Information associated with an error subclass.
+ *
+ * @param message C-style message format compatible with printf.
+ *                The error message is constructed by concatenating the lines with newlines.
+ */
+private case class ErrorSubInfo(message: Seq[String]) {
+  // For compatibility with multi-line error messages
+  @JsonIgnore
+  val messageFormat: String = message.mkString("\n")
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
index 1a50c1acd8d..d503f400d00 100644
--- a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
+++ b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
@@ -17,50 +17,12 @@
 
 package org.apache.spark
 
-import java.net.URL
-
 import scala.collection.JavaConverters._
-import scala.collection.immutable.SortedMap
-
-import com.fasterxml.jackson.annotation.JsonIgnore
-import com.fasterxml.jackson.core.`type`.TypeReference
-import com.fasterxml.jackson.databind.json.JsonMapper
-import com.fasterxml.jackson.module.scala.DefaultScalaModule
-import org.apache.commons.text.StringSubstitutor
 
 import org.apache.spark.util.JsonProtocol.toJsonString
 import org.apache.spark.util.Utils
 
-/**
- * Information associated with an error subclass.
- *
- * @param message C-style message format compatible with printf.
- *                The error message is constructed by concatenating the lines with newlines.
- */
-private[spark] case class ErrorSubInfo(message: Seq[String]) {
-  // For compatibility with multi-line error messages
-  @JsonIgnore
-  val messageFormat: String = message.mkString("\n")
-}
-
-/**
- * Information associated with an error class.
- *
- * @param sqlState SQLSTATE associated with this class.
- * @param subClass SubClass associated with this class.
- * @param message C-style message format compatible with printf.
- *                The error message is constructed by concatenating the lines with newlines.
- */
-private[spark] case class ErrorInfo(
-    message: Seq[String],
-    subClass: Option[Map[String, ErrorSubInfo]],
-    sqlState: Option[String]) {
-  // For compatibility with multi-line error messages
-  @JsonIgnore
-  val messageFormat: String = message.mkString("\n")
-}
-
-object ErrorMessageFormat extends Enumeration {
+private[spark] object ErrorMessageFormat extends Enumeration {
   val PRETTY, MINIMAL, STANDARD = Value
 }
 
@@ -69,37 +31,8 @@ object ErrorMessageFormat extends Enumeration {
  * construct error messages.
  */
 private[spark] object SparkThrowableHelper {
-  val errorClassesUrl: URL =
-    Utils.getSparkClassLoader.getResource("error/error-classes.json")
-  val errorClassToInfoMap: SortedMap[String, ErrorInfo] = {
-    val mapper: JsonMapper = JsonMapper.builder()
-      .addModule(DefaultScalaModule)
-      .build()
-    mapper.readValue(errorClassesUrl, new TypeReference[SortedMap[String, ErrorInfo]]() {})
-  }
-
-  def getParameterNames(errorClass: String, errorSubCLass: String): Array[String] = {
-    val errorInfo = errorClassToInfoMap.getOrElse(errorClass,
-      throw new IllegalArgumentException(s"Cannot find error class '$errorClass'"))
-    if (errorInfo.subClass.isEmpty && errorSubCLass != null) {
-      throw new IllegalArgumentException(s"'$errorClass' has no subclass")
-    }
-    if (errorInfo.subClass.isDefined && errorSubCLass == null) {
-      throw new IllegalArgumentException(s"'$errorClass' requires subclass")
-    }
-    var parameterizedMessage = errorInfo.messageFormat
-    if (errorInfo.subClass.isDefined) {
-      val givenSubClass = errorSubCLass
-      val errorSubInfo = errorInfo.subClass.get.getOrElse(givenSubClass,
-        throw new IllegalArgumentException(s"Cannot find sub error class '$givenSubClass'"))
-      parameterizedMessage = parameterizedMessage + errorSubInfo.messageFormat
-    }
-    val pattern = "<[a-zA-Z0-9_-]+>".r
-    val matches = pattern.findAllIn(parameterizedMessage)
-    val parameterSeq = matches.toArray
-    val parameterNames = parameterSeq.map(p => p.stripPrefix("<").stripSuffix(">"))
-    parameterNames
-  }
+  val errorReader = new ErrorClassesJsonReader(
+    Seq(Utils.getSparkClassLoader.getResource("error/error-classes.json")))
 
   def getMessage(
       errorClass: String,
@@ -120,36 +53,15 @@ private[spark] object SparkThrowableHelper {
       errorSubClass: String,
       messageParameters: Map[String, String],
       context: String): String = {
-    val errorInfo = errorClassToInfoMap.getOrElse(errorClass,
-      throw new IllegalArgumentException(s"Cannot find error class '$errorClass'"))
-    val (displayClass, displayFormat) = if (errorInfo.subClass.isEmpty) {
-      (errorClass, errorInfo.messageFormat)
-    } else {
-      val subClasses = errorInfo.subClass.get
-      if (errorSubClass == null) {
-        throw new IllegalArgumentException(s"Subclass required for error class '$errorClass'")
-      }
-      val errorSubInfo = subClasses.getOrElse(errorSubClass,
-        throw new IllegalArgumentException(s"Cannot find sub error class '$errorSubClass'"))
-      (errorClass + "." + errorSubClass,
-        errorInfo.messageFormat + " " + errorSubInfo.messageFormat)
-    }
-    val sub = new StringSubstitutor(messageParameters.asJava)
-    sub.setEnableUndefinedVariableException(true)
-    val displayMessage = try {
-      sub.replace(displayFormat.replaceAll("<([a-zA-Z0-9_-]+)>", "\\$\\{$1\\}"))
-    } catch {
-      case _: IllegalArgumentException => throw SparkException.internalError(
-        s"Undefined an error message parameter: $messageParameters")
-    }
+    val displayClass = errorClass + Option(errorSubClass).map("." + _).getOrElse("")
+    val displayMessage = errorReader.getErrorMessage(displayClass, messageParameters)
     val displayQueryContext = (if (context.isEmpty) "" else "\n") + context
+    val prefix = if (displayClass.startsWith("_LEGACY_ERROR_TEMP_")) "" else s"[$displayClass] "
-
     s"$prefix$displayMessage$displayQueryContext"
   }
 
   def getSqlState(errorClass: String): String = {
-    Option(errorClass).flatMap(errorClassToInfoMap.get).flatMap(_.sqlState).orNull
+    errorReader.getSqlState(errorClass)
   }
 
   def isInternalError(errorClass: String): Boolean = {
@@ -179,9 +91,8 @@ private[spark] object SparkThrowableHelper {
           val errorSubClass = e.getErrorSubClass
           if (errorSubClass != null) g.writeStringField("errorSubClass", errorSubClass)
           if (format == STANDARD) {
-            val errorInfo = errorClassToInfoMap.getOrElse(errorClass,
-              throw SparkException.internalError(s"Cannot find the error class '$errorClass'"))
-            g.writeStringField("message", errorInfo.messageFormat)
+            val finalClass = errorClass + Option(errorSubClass).map("." + _).getOrElse("")
+            g.writeStringField("message", errorReader.getMessageTemplate(finalClass))
           }
           val sqlState = e.getSqlState
           if (sqlState != null) g.writeStringField("sqlState", sqlState)
diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
index 2012133a74d..266683b1eca 100644
--- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
@@ -44,8 +44,10 @@ class SparkThrowableSuite extends SparkFunSuite {
         "core/testOnly *SparkThrowableSuite -- -t \"Error classes are correctly formatted\""
    }}}
    */
-  private val errorClassDir = getWorkspaceFilePath(
-    "core", "src", "main", "resources", "error").toFile
+  private val errorJsonFilePath = getWorkspaceFilePath(
+    "core", "src", "main", "resources", "error", "error-classes.json")
+
+  private val errorReader = new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL))
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -68,11 +70,11 @@ class SparkThrowableSuite extends SparkFunSuite {
       .addModule(DefaultScalaModule)
       .enable(STRICT_DUPLICATE_DETECTION)
       .build()
-    mapper.readValue(errorClassesUrl, new TypeReference[Map[String, ErrorInfo]]() {})
+    mapper.readValue(errorJsonFilePath.toUri.toURL, new TypeReference[Map[String, ErrorInfo]]() {})
   }
 
   test("Error classes are correctly formatted") {
-    val errorClassFileContents = IOUtils.toString(errorClassesUrl.openStream())
+    val errorClassFileContents = IOUtils.toString(errorJsonFilePath.toUri.toURL.openStream())
     val mapper = JsonMapper.builder()
       .addModule(DefaultScalaModule)
       .enable(SerializationFeature.INDENT_OUTPUT)
@@ -82,11 +84,11 @@ class SparkThrowableSuite extends SparkFunSuite {
     val rewrittenString = mapper.configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true)
       .setSerializationInclusion(Include.NON_ABSENT)
       .writer(prettyPrinter)
-      .writeValueAsString(errorClassToInfoMap)
+      .writeValueAsString(errorReader.errorInfoMap)
 
     if (regenerateGoldenFiles) {
       if (rewrittenString.trim != errorClassFileContents.trim) {
-        val errorClassesFile = new File(errorClassDir, new File(errorClassesUrl.getPath).getName)
+        val errorClassesFile = errorJsonFilePath.toFile
         logInfo(s"Regenerating error class file $errorClassesFile")
         Files.delete(errorClassesFile.toPath)
         FileUtils.writeStringToFile(errorClassesFile, rewrittenString, StandardCharsets.UTF_8)
@@ -97,7 +99,7 @@ class SparkThrowableSuite extends SparkFunSuite {
   }
 
   test("SQLSTATE invariants") {
-    val sqlStates = errorClassToInfoMap.values.toSeq.flatMap(_.sqlState)
+    val sqlStates = errorReader.errorInfoMap.values.toSeq.flatMap(_.sqlState)
     val errorClassReadMe = Utils.getSparkClassLoader.getResource("error/README.md")
     val errorClassReadMeContents = IOUtils.toString(errorClassReadMe.openStream())
     val sqlStateTableRegex =
@@ -112,7 +114,7 @@ class SparkThrowableSuite extends SparkFunSuite {
   }
 
   test("Message invariants") {
-    val messageSeq = errorClassToInfoMap.values.toSeq.flatMap { i =>
+    val messageSeq = errorReader.errorInfoMap.values.toSeq.flatMap { i =>
       Seq(i.message) ++ i.subClass.getOrElse(Map.empty).values.toSeq.map(_.message)
     }
     messageSeq.foreach { message =>
@@ -124,7 +126,7 @@ class SparkThrowableSuite extends SparkFunSuite {
   }
 
   test("Message format invariants") {
-    val messageFormats = errorClassToInfoMap
+    val messageFormats = errorReader.errorInfoMap
       .filterKeys(!_.startsWith("_LEGACY_ERROR_TEMP_"))
       .values.toSeq.flatMap { i => Seq(i.messageFormat) }
     checkCondition(messageFormats, s => s != null)
@@ -137,22 +139,22 @@ class SparkThrowableSuite extends SparkFunSuite {
       .addModule(DefaultScalaModule)
       .enable(SerializationFeature.INDENT_OUTPUT)
       .build()
-    mapper.writeValue(tmpFile, errorClassToInfoMap)
+    mapper.writeValue(tmpFile, errorReader.errorInfoMap)
     val rereadErrorClassToInfoMap = mapper.readValue(
       tmpFile, new TypeReference[Map[String, ErrorInfo]]() {})
-    assert(rereadErrorClassToInfoMap == errorClassToInfoMap)
+    assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap)
   }
 
   test("Check if error class is missing") {
-    val ex1 = intercept[IllegalArgumentException] {
+    val ex1 = intercept[SparkException] {
       getMessage("", null, Map.empty[String, String])
     }
-    assert(ex1.getMessage == "Cannot find error class ''")
+    assert(ex1.getMessage.contains("Cannot find main error class"))
 
-    val ex2 = intercept[IllegalArgumentException] {
+    val ex2 = intercept[SparkException] {
       getMessage("LOREM_IPSUM", null, Map.empty[String, String])
     }
-    assert(ex2.getMessage == "Cannot find error class 'LOREM_IPSUM'")
+    assert(ex2.getMessage.contains("Cannot find main error class"))
   }
 
   test("Check if message parameters match message format") {
@@ -161,7 +163,7 @@ class SparkThrowableSuite extends SparkFunSuite {
       getMessage("UNRESOLVED_COLUMN", "WITHOUT_SUGGESTION", Map.empty[String, String])
     }
     assert(e.getErrorClass === "INTERNAL_ERROR")
-    assert(e.getMessageParameters().get("message").contains("Undefined an error message parameter"))
+    assert(e.getMessageParameters().get("message").contains("Undefined error message parameter"))
 
     // Does not fail with too many args (expects 0 args)
     assert(getMessage("DIVIDE_BY_ZERO", null, Map("config" -> "foo", "a" -> "bar")) ==
@@ -302,7 +304,7 @@ class SparkThrowableSuite extends SparkFunSuite {
       """{
         |  "errorClass" : "UNSUPPORTED_SAVE_MODE",
         |  "errorSubClass" : "EXISTENT_PATH",
-        |  "message" : "The save mode <saveMode> is not supported for:",
+        |  "message" : "The save mode <saveMode> is not supported for: an existent path.",
         |  "messageParameters" : {
         |    "saveMode" : "UNSUPPORTED_MODE"
         |  }
@@ -321,4 +323,22 @@ class SparkThrowableSuite extends SparkFunSuite {
         |  }
         |}""".stripMargin)
   }
+
+  test("overwrite error classes") {
+    withTempDir { dir =>
+      val json = new File(dir, "errors.json")
+      FileUtils.writeStringToFile(json,
+        """
+          |{
+          |  "DIVIDE_BY_ZERO" : {
+          |    "message" : [
+          |      "abc"
+          |    ]
+          |  }
+          |}
+          |""".stripMargin)
+      val reader = new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURL))
+      assert(reader.getErrorMessage("DIVIDE_BY_ZERO", Map.empty) == "abc")
+    }
+  }
 }

