This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4863be5632f [SPARK-45207][SQL][CONNECT] Implement Error Enrichment for Scala Client
4863be5632f is described below

commit 4863be5632f3165a5699a525235ea118c1e1f7eb
Author: Yihong He <yihong...@databricks.com>
AuthorDate: Mon Sep 25 09:35:33 2023 +0900

    [SPARK-45207][SQL][CONNECT] Implement Error Enrichment for Scala Client
    
    ### What changes were proposed in this pull request?
    
    - Implemented the reconstruction of the complete exception (un-truncated error messages, cause exceptions, and the server-side stacktrace) based on the responses of the FetchErrorDetails RPC.
    
    ### Why are the changes needed?
    
    - Cause exceptions play an important role in the current control flow, such as in StreamingQueryException. They are also valuable for debugging.
    - Un-truncated error messages are useful for debugging.
    - Providing server-side stack traces aids in effectively diagnosing server-related issues.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    - `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"`
    - `build/sbt "connect-client-jvm/testOnly *ClientStreamingQuerySuite"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #42987 from heyihong/SPARK-45207.
    
    Authored-by: Yihong He <yihong...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  59 ++++++-
 .../sql/streaming/ClientStreamingQuerySuite.scala  |  41 ++++-
 .../client/CustomSparkConnectBlockingStub.scala    |  44 ++++-
 .../connect/client/GrpcExceptionConverter.scala    | 192 +++++++++++++++++----
 4 files changed, 292 insertions(+), 44 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 21892542eab..ec9b1698a4e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, PrintStream}
 import java.nio.file.Files
+import java.time.DateTimeException
 import java.util.Properties
 
 import scala.collection.JavaConverters._
@@ -29,7 +30,7 @@ import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalactic.TolerantNumerics
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkArithmeticException, SparkException}
+import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
 import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
@@ -44,6 +45,62 @@ import org.apache.spark.sql.types._
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
 
+  for (enrichErrorEnabled <- Seq(false, true)) {
+    test(s"cause exception - ${enrichErrorEnabled}") {
+      withSQLConf("spark.sql.connect.enrichError.enabled" -> enrichErrorEnabled.toString) {
+        val ex = intercept[SparkUpgradeException] {
+          spark
+            .sql("""
+                |select from_json(
+                |  '{"d": "02-29"}',
+                |  'd date',
+                |  map('dateFormat', 'MM-dd'))
+                |""".stripMargin)
+            .collect()
+        }
+        if (enrichErrorEnabled) {
+          assert(ex.getCause.isInstanceOf[DateTimeException])
+        } else {
+          assert(ex.getCause == null)
+        }
+      }
+    }
+  }
+
+  test(s"throw SparkException with large cause exception") {
+    withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") {
+      val session = spark
+      import session.implicits._
+
+      val throwException =
+        udf((_: String) => throw new SparkException("test" * 10000))
+
+      val ex = intercept[SparkException] {
+        Seq("1").toDS.withColumn("udf_val", throwException($"value")).collect()
+      }
+
+      assert(ex.getCause.isInstanceOf[SparkException])
+      assert(ex.getCause.getMessage.contains("test" * 10000))
+    }
+  }
+
+  for (isServerStackTraceEnabled <- Seq(false, true)) {
+    test(s"server-side stack trace is set in exceptions - 
${isServerStackTraceEnabled}") {
+      withSQLConf(
+        "spark.sql.connect.serverStacktrace.enabled" -> 
isServerStackTraceEnabled.toString,
+        "spark.sql.pyspark.jvmStacktrace.enabled" -> "false") {
+        val ex = intercept[AnalysisException] {
+          spark.sql("select x").collect()
+        }
+        assert(
+          ex.getStackTrace
+            .find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
+            .isDefined
+            == isServerStackTraceEnabled)
+      }
+    }
+  }
+
   test("throw SparkArithmeticException") {
     withSQLConf("spark.sql.ansi.enabled" -> "true") {
       intercept[SparkArithmeticException] {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index dc4d441ec30..5d281cfbfeb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -27,11 +27,11 @@ import org.scalatest.concurrent.Eventually.eventually
 import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.SparkException
 import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.window
+import org.apache.spark.sql.functions.{col, udf, window}
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
 import org.apache.spark.sql.test.{QueryTest, SQLHelper}
 import org.apache.spark.util.SparkFileUtils
@@ -175,6 +175,43 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
     }
   }
 
+  test("throw exception in streaming") {
+    // Disable spark.sql.pyspark.jvmStacktrace.enabled to avoid hitting the
+    // netty header limit.
+    withSQLConf("spark.sql.pyspark.jvmStacktrace.enabled" -> "false") {
+      val session = spark
+      import session.implicits._
+
+      val checkForTwo = udf((value: Int) => {
+        if (value == 2) {
+          throw new RuntimeException("Number 2 encountered!")
+        }
+        value
+      })
+
+      val query = spark.readStream
+        .format("rate")
+        .option("rowsPerSecond", "1")
+        .load()
+        .select(checkForTwo($"value").as("checkedValue"))
+        .writeStream
+        .outputMode("append")
+        .format("console")
+        .start()
+
+      val exception = intercept[SparkException] {
+        query.awaitTermination()
+      }
+
+      assert(exception.getCause.isInstanceOf[SparkException])
+      assert(exception.getCause.getCause.isInstanceOf[SparkException])
+      assert(exception.getCause.getCause.getCause.isInstanceOf[SparkException])
+      assert(
+        exception.getCause.getCause.getCause.getMessage
+          .contains("java.lang.RuntimeException: Number 2 encountered!"))
+    }
+  }
+
   test("foreach Row") {
     val writer = new TestForeachWriter[Row]
 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 80edcfa8be1..f02704b2a02 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -27,11 +27,21 @@ private[connect] class CustomSparkConnectBlockingStub(
     retryPolicy: GrpcRetryHandler.RetryPolicy) {
 
   private val stub = SparkConnectServiceGrpc.newBlockingStub(channel)
+
   private val retryHandler = new GrpcRetryHandler(retryPolicy)
 
+  // GrpcExceptionConverter with a GRPC stub for fetching error details from server.
+  private val grpcExceptionConverter = new GrpcExceptionConverter(stub)
+
  def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = {
-    GrpcExceptionConverter.convert {
-      GrpcExceptionConverter.convertIterator[ExecutePlanResponse](
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
+      grpcExceptionConverter.convertIterator[ExecutePlanResponse](
+        request.getSessionId,
+        request.getUserContext,
+        request.getClientType,
         retryHandler.RetryIterator[ExecutePlanRequest, ExecutePlanResponse](
           request,
           r => CloseableIterator(stub.executePlan(r).asScala)))
@@ -40,15 +50,24 @@ private[connect] class CustomSparkConnectBlockingStub(
 
   def executePlanReattachable(
       request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = {
-    GrpcExceptionConverter.convert {
-      GrpcExceptionConverter.convertIterator[ExecutePlanResponse](
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
+      grpcExceptionConverter.convertIterator[ExecutePlanResponse](
+        request.getSessionId,
+        request.getUserContext,
+        request.getClientType,
         // Don't use retryHandler - own retry handling is inside.
        new ExecutePlanResponseReattachableIterator(request, channel, retryPolicy))
     }
   }
 
   def analyzePlan(request: AnalyzePlanRequest): AnalyzePlanResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.analyzePlan(request)
       }
@@ -56,7 +75,10 @@ private[connect] class CustomSparkConnectBlockingStub(
   }
 
   def config(request: ConfigRequest): ConfigResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.config(request)
       }
@@ -64,7 +86,10 @@ private[connect] class CustomSparkConnectBlockingStub(
   }
 
   def interrupt(request: InterruptRequest): InterruptResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.interrupt(request)
       }
@@ -72,7 +97,10 @@ private[connect] class CustomSparkConnectBlockingStub(
   }
 
  def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.artifactStatus(request)
       }
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index fe9f6dc2b4a..edbc434ef96 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -24,49 +24,145 @@ import scala.reflect.ClassTag
 import com.google.rpc.ErrorInfo
 import io.grpc.StatusRuntimeException
 import io.grpc.protobuf.StatusProto
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods
 
 import org.apache.spark.{SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException}
+import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext}
+import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.util.JsonUtils
 
-private[client] object GrpcExceptionConverter extends JsonUtils {
-  def convert[T](f: => T): T = {
+/**
+ * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions.
+ * It does so by utilizing the ErrorInfo defined in error_details.proto and making an additional
+ * FetchErrorDetails RPC call to retrieve the full error message and optionally the server-side
+ * stacktrace.
+ *
+ * If the FetchErrorDetails RPC call succeeds, the exceptions will be constructed based on the
+ * response. If the RPC call fails, the exception will be constructed based on the ErrorInfo. If
+ * the ErrorInfo is missing, the exception will be constructed based on the StatusRuntimeException
+ * itself.
+ */
+private[client] class GrpcExceptionConverter(grpcStub: SparkConnectServiceBlockingStub)
+    extends Logging {
+  import GrpcExceptionConverter._
+
+  def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = {
     try {
       f
     } catch {
       case e: StatusRuntimeException =>
-        throw toThrowable(e)
+        throw toThrowable(e, sessionId, userContext, clientType)
     }
   }
 
-  def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = {
+  def convertIterator[T](
+      sessionId: String,
+      userContext: UserContext,
+      clientType: String,
+      iter: CloseableIterator[T]): CloseableIterator[T] = {
     new WrappedCloseableIterator[T] {
 
       override def innerIterator: Iterator[T] = iter
 
       override def hasNext: Boolean = {
-        convert {
+        convert(sessionId, userContext, clientType) {
           iter.hasNext
         }
       }
 
       override def next(): T = {
-        convert {
+        convert(sessionId, userContext, clientType) {
           iter.next()
         }
       }
 
       override def close(): Unit = {
-        convert {
+        convert(sessionId, userContext, clientType) {
           iter.close()
         }
       }
     }
   }
 
+  /**
+   * Fetches enriched errors with full exception message and optionally stacktrace by issuing an
+   * additional RPC call to fetch error details. The RPC call is best-effort at-most-once.
+   */
+  private def fetchEnrichedError(
+      info: ErrorInfo,
+      sessionId: String,
+      userContext: UserContext,
+      clientType: String): Option[Throwable] = {
+    val errorId = info.getMetadataOrDefault("errorId", null)
+    if (errorId == null) {
+      logWarning("Unable to fetch enriched error since errorId is missing")
+      return None
+    }
+
+    try {
+      val errorDetailsResponse = grpcStub.fetchErrorDetails(
+        FetchErrorDetailsRequest
+          .newBuilder()
+          .setSessionId(sessionId)
+          .setErrorId(errorId)
+          .setUserContext(userContext)
+          .setClientType(clientType)
+          .build())
+
+      if (!errorDetailsResponse.hasRootErrorIdx) {
+        logWarning("Unable to fetch enriched error since error is not found")
+        return None
+      }
+
+      Some(
+        errorsToThrowable(
+          errorDetailsResponse.getRootErrorIdx,
+          errorDetailsResponse.getErrorsList.asScala.toSeq))
+    } catch {
+      case e: StatusRuntimeException =>
+        logWarning("Unable to fetch enriched error", e)
+        None
+    }
+  }
+
+  private def toThrowable(
+      ex: StatusRuntimeException,
+      sessionId: String,
+      userContext: UserContext,
+      clientType: String): Throwable = {
+    val status = StatusProto.fromThrowable(ex)
+
+    // Extract the ErrorInfo from the StatusProto, if present.
+    val errorInfoOpt = status.getDetailsList.asScala
+      .find(_.is(classOf[ErrorInfo]))
+      .map(_.unpack(classOf[ErrorInfo]))
+
+    if (errorInfoOpt.isDefined) {
+      // If ErrorInfo is found, try to fetch enriched error details by an additional RPC.
+      val enrichedErrorOpt =
+        fetchEnrichedError(errorInfoOpt.get, sessionId, userContext, clientType)
+      if (enrichedErrorOpt.isDefined) {
+        return enrichedErrorOpt.get
+      }
+
+      // If fetching enriched error details fails, convert ErrorInfo to a Throwable.
+      // Unlike enriched errors above, the message from status may be truncated,
+      // and no cause exceptions or server-side stack traces will be reconstructed.
+      return errorInfoToThrowable(errorInfoOpt.get, status.getMessage)
+    }
+
+    // If no ErrorInfo is found, create a SparkException based on the StatusRuntimeException.
+    new SparkException(ex.toString, ex.getCause)
+  }
+}
+
+private object GrpcExceptionConverter {
+
   private def errorConstructor[T <: Throwable: ClassTag](
       throwableCtr: (String, Option[Throwable]) => T)
       : (String, (String, Option[Throwable]) => Throwable) = {
@@ -93,33 +189,63 @@ private[client] object GrpcExceptionConverter extends JsonUtils {
       new SparkArrayIndexOutOfBoundsException(message)),
    errorConstructor[DateTimeException]((message, _) => new SparkDateTimeException(message)),
    errorConstructor((message, cause) => new SparkRuntimeException(message, cause)),
-    errorConstructor((message, cause) => new SparkUpgradeException(message, cause)))
-
-  private def errorInfoToThrowable(info: ErrorInfo, message: String): Option[Throwable] = {
-    val classes =
-      mapper.readValue(info.getMetadataOrDefault("classes", "[]"), classOf[Array[String]])
+    errorConstructor((message, cause) => new SparkUpgradeException(message, cause)),
+    errorConstructor((message, cause) => new SparkException(message, cause.orNull)))
+
+  /**
+   * errorsToThrowable reconstructs the exception based on a list of protobuf messages
+   * FetchErrorDetailsResponse.Error with un-truncated error messages and server-side stacktrace
+   * (if set).
+   */
+  private def errorsToThrowable(
+      errorIdx: Int,
+      errors: Seq[FetchErrorDetailsResponse.Error]): Throwable = {
+
+    val error = errors(errorIdx)
+
+    val classHierarchy = error.getErrorTypeHierarchyList.asScala
+
+    val constructor =
+      classHierarchy
+        .flatMap(errorFactory.get)
+        .headOption
+        .getOrElse((message: String, cause: Option[Throwable]) =>
+          new SparkException(s"${classHierarchy.head}: ${message}", 
cause.orNull))
+
+    val causeOpt =
+      if (error.hasCauseIdx) Some(errorsToThrowable(error.getCauseIdx, errors)) else None
+
+    val exception = constructor(error.getMessage, causeOpt)
+
+    if (!error.getStackTraceList.isEmpty) {
+      exception.setStackTrace(error.getStackTraceList.asScala.toArray.map { stackTraceElement =>
+        new StackTraceElement(
+          stackTraceElement.getDeclaringClass,
+          stackTraceElement.getMethodName,
+          stackTraceElement.getFileName,
+          stackTraceElement.getLineNumber)
+      })
+    }
 
-    classes
-      .find(errorFactory.contains)
-      .map { cls =>
-        val constructor = errorFactory.get(cls).get
-        constructor(message, None)
-      }
+    exception
   }
 
-  private def toThrowable(ex: StatusRuntimeException): Throwable = {
-    val status = StatusProto.fromThrowable(ex)
-
-    val fallbackEx = new SparkException(ex.toString, ex.getCause)
-
-    val errorInfoOpt = status.getDetailsList.asScala
-      .find(_.is(classOf[ErrorInfo]))
-
-    if (errorInfoOpt.isEmpty) {
-      return fallbackEx
-    }
-
-    errorInfoToThrowable(errorInfoOpt.get.unpack(classOf[ErrorInfo]), status.getMessage)
-      .getOrElse(fallbackEx)
+  /**
+   * errorInfoToThrowable reconstructs the exception based on the error classes hierarchy and the
+   * truncated error message.
+   */
+  private def errorInfoToThrowable(info: ErrorInfo, message: String): Throwable = {
+    implicit val formats = DefaultFormats
+    val classes =
+      JsonMethods.parse(info.getMetadataOrDefault("classes", "[]")).extract[Array[String]]
+
+    errorsToThrowable(
+      0,
+      Seq(
+        FetchErrorDetailsResponse.Error
+          .newBuilder()
+          .setMessage(message)
+          .addAllErrorTypeHierarchy(classes.toIterable.asJava)
+          .build()))
   }
 }

