This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 48b1a283a2eb [SPARK-44622][SQL][CONNECT] Implement FetchErrorDetails RPC
48b1a283a2eb is described below
commit 48b1a283a2eba9f70149d5980d074fad2743c4ff
Author: Yihong He <[email protected]>
AuthorDate: Wed Sep 20 00:14:44 2023 -0400
[SPARK-44622][SQL][CONNECT] Implement FetchErrorDetails RPC
### What changes were proposed in this pull request?
- Introduced the FetchErrorDetails RPC for retrieving comprehensive error details. A client enriches an error by issuing this separate RPC call with the `errorId` field carried in the ErrorInfo metadata.
- Introduced error enrichment, which uses an additional RPC to fetch untruncated exception messages and, optionally, server-side stack traces. The enrichment can be enabled or disabled via the flag `spark.sql.connect.enrichError.enabled`, which is true by default.
- Implemented attaching server-side stack traces to exceptions on the client side via the FetchErrorDetails RPC to aid debugging. The feature is enabled or disabled via the flag `spark.sql.connect.serverStacktrace.enabled`, which is true by default (see the sketch after this list).
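As an illustrative sketch of the resulting client-side flow (not part of this patch; the generated `SparkConnectServiceGrpc` blocking stub name is an assumption based on standard grpc-java codegen, and the `"errorId"` metadata lookup mirrors the ErrorUtils change below):

    // Hedged sketch: enrich a failed request by fetching full error details.
    // Assumes `stub` is a generated SparkConnectServiceGrpc blocking stub and
    // `errorInfo` is the com.google.rpc.ErrorInfo detail unpacked from the
    // StatusRuntimeException returned by the server.
    import org.apache.spark.connect.proto

    def enrich(
        stub: proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub,
        userId: String,
        sessionId: String,
        errorInfo: com.google.rpc.ErrorInfo): Option[proto.FetchErrorDetailsResponse] = {
      // The server only attaches the "errorId" metadata key when
      // spark.sql.connect.enrichError.enabled is true.
      Option(errorInfo.getMetadataMap.get("errorId")).map { errorId =>
        stub.fetchErrorDetails(
          proto.FetchErrorDetailsRequest
            .newBuilder()
            .setUserContext(proto.UserContext.newBuilder().setUserId(userId))
            .setSessionId(sessionId)
            .setErrorId(errorId)
            .build())
      }
    }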
### Why are the changes needed?
- Attaching full exception messages to the error-details protobuf can quickly hit the 8 KB gRPC Netty header limit. Fetching comprehensive error information through a separate RPC is more reliable.
- Providing server-side stack traces helps diagnose server-side issues effectively.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- `build/sbt "connect/testOnly *FetchErrorDetailsHandlerSuite"`
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #42377 from heyihong/SPARK-44622.
Authored-by: Yihong He <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../src/main/protobuf/spark/connect/base.proto | 57 +++++++
.../apache/spark/sql/connect/config/Connect.scala | 18 +++
.../spark/sql/connect/service/SessionHolder.scala | 21 ++-
.../SparkConnectFetchErrorDetailsHandler.scala | 59 +++++++
.../sql/connect/service/SparkConnectService.scala | 14 ++
.../spark/sql/connect/utils/ErrorUtils.scala | 103 ++++++++++--
.../service/FetchErrorDetailsHandlerSuite.scala | 166 +++++++++++++++++++
python/pyspark/sql/connect/proto/base_pb2.py | 14 +-
python/pyspark/sql/connect/proto/base_pb2.pyi | 180 +++++++++++++++++++++
python/pyspark/sql/connect/proto/base_pb2_grpc.py | 45 ++++++
10 files changed, 659 insertions(+), 18 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 65e2493f8368..cf1355f7ebc1 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -778,6 +778,60 @@ message ReleaseExecuteResponse {
optional string operation_id = 2;
}
+message FetchErrorDetailsRequest {
+
+ // (Required)
+ // The session_id specifies a Spark session for a user identified by user_context.user_id.
+ // The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`.
+ string session_id = 1;
+
+ // User context
+ UserContext user_context = 2;
+
+ // (Required)
+ // The id of the error.
+ string error_id = 3;
+}
+
+message FetchErrorDetailsResponse {
+
+ message StackTraceElement {
+ // The fully qualified name of the class containing the execution point.
+ string declaring_class = 1;
+
+ // The name of the method containing the execution point.
+ string method_name = 2;
+
+ // The name of the file containing the execution point.
+ string file_name = 3;
+
+ // The line number of the source line containing the execution point.
+ int32 line_number = 4;
+ }
+
+ // Error defines the schema for the representing exception.
+ message Error {
+ // The fully qualified names of the exception class and its parent classes.
+ repeated string error_type_hierarchy = 1;
+
+ // The detailed message of the exception.
+ string message = 2;
+
+ // The stackTrace of the exception. It will be set
+ // if the SQLConf spark.sql.connect.serverStacktrace.enabled is true.
+ repeated StackTraceElement stack_trace = 3;
+
+ // The index of the cause error in errors.
+ optional int32 cause_idx = 4;
+ }
+
+ // The index of the root error in errors. The field will not be set if the error is not found.
+ optional int32 root_error_idx = 1;
+
+ // A list of errors.
+ repeated Error errors = 2;
+}
+
// Main interface for the SparkConnect service.
service SparkConnectService {
@@ -813,5 +867,8 @@ service SparkConnectService {
// Non reattachable executions are released automatically and immediately after the ExecutePlan
// RPC and ReleaseExecute may not be used.
rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {}
+
+ // FetchErrorDetails retrieves the matched exception with details based on a provided error id.
+ rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {}
}
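For concreteness, the chain encoding above can be illustrated as follows (a hedged sketch of how the new ErrorUtils helper below populates the message; it mirrors the assertions in FetchErrorDetailsHandlerSuite and is only callable from the org.apache.spark.sql.connect packages because the helper is private[connect]):

    // A two-level exception chain is flattened into the repeated `errors` field.
    val response = ErrorUtils.throwableToFetchErrorDetailsResponse(
      new Exception("outer", new Exception("inner")))

    assert(response.getRootErrorIdx == 0)           // outermost error sits at index 0
    assert(response.getErrors(0).getMessage == "outer")
    assert(response.getErrors(0).getCauseIdx == 1)  // points at the "inner" entry
    assert(!response.getErrors(1).hasCauseIdx)      // the chain ends here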
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index dfd6008ac09a..248444e710d2 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.internal.SQLConf.buildConf
object Connect {
import org.apache.spark.sql.internal.SQLConf.buildStaticConf
@@ -213,4 +214,21 @@ object Connect {
.version("3.5.0")
.intConf
.createWithDefault(200)
+
+ val CONNECT_ENRICH_ERROR_ENABLED =
+ buildConf("spark.sql.connect.enrichError.enabled")
+ .doc("""
+ |When true, it enriches errors with full exception messages and optionally server-side
+ |stacktrace on the client side via an additional RPC.
+ |""".stripMargin)
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val CONNECT_SERVER_STACKTRACE_ENABLED =
+ buildConf("spark.sql.connect.serverStacktrace.enabled")
+ .doc("When true, it sets the server-side stacktrace in the user-facing
Spark exception.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(true)
}
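Since both entries are built with buildConf (session confs rather than static confs) and default to true, they can be toggled on the server-side session, much like the new test suite does (a minimal sketch; `spark` denotes the server-side SparkSession):

    // Disable the enrichment RPC entirely for this session.
    spark.conf.set("spark.sql.connect.enrichError.enabled", "false")
    // Keep enrichment but omit server-side stack traces from client-facing exceptions.
    spark.conf.set("spark.sql.connect.serverStacktrace.enabled", "false")
    // Restore the defaults (both true).
    spark.conf.unset("spark.sql.connect.enrichError.enabled")
    spark.conf.unset("spark.sql.connect.serverStacktrace.enabled")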
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 1cef02d7e346..0748cd237bf0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.connect.service
import java.nio.file.Path
import java.util.UUID
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
import scala.collection.JavaConverters._
import scala.collection.mutable
+import com.google.common.base.Ticker
+import com.google.common.cache.CacheBuilder
+
import org.apache.spark.{JobArtifactSet, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
@@ -32,6 +35,7 @@ import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
+import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.SystemClock
import org.apache.spark.util.Utils
@@ -45,6 +49,15 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
private val executions: ConcurrentMap[String, ExecuteHolder] =
new ConcurrentHashMap[String, ExecuteHolder]()
+ // The cache that maps an error id to a throwable. The cached throwables are
+ // independent of each other.
+ private[connect] val errorIdToError = CacheBuilder
+ .newBuilder()
+ .ticker(Ticker.systemTicker())
+ .maximumSize(ERROR_CACHE_SIZE)
+ .expireAfterAccess(ERROR_CACHE_TIMEOUT_SEC, TimeUnit.SECONDS)
+ .build[String, Throwable]()
+
val eventManager: SessionEventsManager = SessionEventsManager(this, new SystemClock())
// Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like
@@ -265,6 +278,12 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
object SessionHolder {
+ // The maximum number of distinct errors in the cache.
+ private val ERROR_CACHE_SIZE = 20
+
+ // The maximum time for an error to stay in the cache.
+ private val ERROR_CACHE_TIMEOUT_SEC = 60
+
/** Creates a dummy session holder for use in tests. */
def forTesting(session: SparkSession): SessionHolder = {
val ret =
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
new file mode 100644
index 000000000000..17a6e9e434f3
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.FetchErrorDetailsResponse
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler
+ * retrieves the matched error with details from the cache based on a provided error id.
+ *
+ * @param responseObserver
+ */
+class SparkConnectFetchErrorDetailsHandler(
+ responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]) {
+
+ def handle(v: proto.FetchErrorDetailsRequest): Unit = {
+ val sessionHolder =
+ SparkConnectService
+ .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+
+ val response = Option(sessionHolder.errorIdToError.getIfPresent(v.getErrorId))
+ .map { error =>
+ // This error can only be fetched once,
+ // if a connection dies in the middle you cannot repeat.
+ sessionHolder.errorIdToError.invalidate(v.getErrorId)
+
+ ErrorUtils.throwableToFetchErrorDetailsResponse(
+ st = error,
+ serverStackTraceEnabled = sessionHolder.session.conf.get(
+ Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get(
+ SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
+ }
+ .getOrElse(FetchErrorDetailsResponse.newBuilder().build())
+
+ responseObserver.onNext(response)
+
+ responseObserver.onCompleted()
+ }
+}
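Note the one-shot semantics above: the handler invalidates the cached entry before responding, so a repeated FetchErrorDetails call with the same error id falls back to an empty response. A small sketch of that behaviour, using the fetchErrorDetails test helper defined in the suite further below (the combined assertion is inferred from the "invalidate cached exceptions after first request" and "error not found" tests):

    // First fetch returns the cached error and removes it from errorIdToError.
    val first = fetchErrorDetails(userId, sessionId, errorId)
    assert(first.hasRootErrorIdx)

    // A repeated fetch for the same id finds nothing; root_error_idx stays unset.
    val second = fetchErrorDetails(userId, sessionId, errorId)
    assert(!second.hasRootErrorIdx)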
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 269e47609dbf..e82c9cba5626 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -201,6 +201,20 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
sessionId = request.getSessionId)
}
+ override def fetchErrorDetails(
+ request: proto.FetchErrorDetailsRequest,
+ responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit = {
+ try {
+ new SparkConnectFetchErrorDetailsHandler(responseObserver).handle(request)
+ } catch {
+ ErrorUtils.handleError(
+ "getErrorInfo",
+ observer = responseObserver,
+ userId = request.getUserContext.getUserId,
+ sessionId = request.getSessionId)
+ }
+ }
+
private def methodWithCustomMarshallers(methodDesc: MethodDescriptor[MessageLite, MessageLite])
: MethodDescriptor[MessageLite, MessageLite] = {
val recursionLimit =
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 2050ebc01aa0..1abd44608cd0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -17,8 +17,12 @@
package org.apache.spark.sql.connect.utils
+import java.util.UUID
+
import scala.annotation.tailrec
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import com.google.protobuf.{Any => ProtoAny}
@@ -33,13 +37,14 @@ import org.json4s.jackson.JsonMethods
import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
+import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.config.Connect
-import org.apache.spark.sql.connect.service.ExecuteEventsManager
-import org.apache.spark.sql.connect.service.SparkConnectService
+import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SparkConnectService}
import org.apache.spark.sql.internal.SQLConf
private[connect] object ErrorUtils extends Logging {
+
private def allClasses(cl: Class[_]): Seq[Class[_]] = {
val classes = ArrayBuffer.empty[Class[_]]
if (cl != null && !cl.equals(classOf[java.lang.Object])) {
@@ -57,7 +62,67 @@ private[connect] object ErrorUtils extends Logging {
classes.toSeq
}
- private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: Boolean): RPCStatus = {
+ // The maximum length of the error chain.
+ private[connect] val MAX_ERROR_CHAIN_LENGTH = 5
+
+ /**
+ * Convert Throwable to a protobuf message FetchErrorDetailsResponse.
+ * @param st
+ * the Throwable to be converted
+ * @param serverStackTraceEnabled
+ * whether to return the server stack trace.
+ * @return
+ * FetchErrorDetailsResponse
+ */
+ private[connect] def throwableToFetchErrorDetailsResponse(
+ st: Throwable,
+ serverStackTraceEnabled: Boolean = false): FetchErrorDetailsResponse = {
+
+ var currentError = st
+ val buffer = mutable.Buffer.empty[FetchErrorDetailsResponse.Error]
+
+ while (buffer.size < MAX_ERROR_CHAIN_LENGTH && currentError != null) {
+ val builder = FetchErrorDetailsResponse.Error
+ .newBuilder()
+ .setMessage(currentError.getMessage)
+ .addAllErrorTypeHierarchy(
+ ErrorUtils.allClasses(currentError.getClass).map(_.getName).asJava)
+
+ if (serverStackTraceEnabled) {
+ builder.addAllStackTrace(
+ currentError.getStackTrace
+ .map { stackTraceElement =>
+ FetchErrorDetailsResponse.StackTraceElement
+ .newBuilder()
+ .setDeclaringClass(stackTraceElement.getClassName)
+ .setMethodName(stackTraceElement.getMethodName)
+ .setFileName(stackTraceElement.getFileName)
+ .setLineNumber(stackTraceElement.getLineNumber)
+ .build()
+ }
+ .toIterable
+ .asJava)
+ }
+
+ val causeIdx = buffer.size + 1
+
+ if (causeIdx < MAX_ERROR_CHAIN_LENGTH && currentError.getCause != null) {
+ builder.setCauseIdx(causeIdx)
+ }
+
+ buffer.append(builder.build())
+
+ currentError = currentError.getCause
+ }
+
+ FetchErrorDetailsResponse
+ .newBuilder()
+ .setRootErrorIdx(0)
+ .addAllErrors(buffer.asJava)
+ .build()
+ }
+
+ private def buildStatusFromThrowable(st: Throwable, sessionHolder: SessionHolder): RPCStatus = {
val errorInfo = ErrorInfo
.newBuilder()
.setReason(st.getClass.getName)
@@ -66,14 +131,26 @@ private[connect] object ErrorUtils extends Logging {
"classes",
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
- lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
- val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) {
- val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE)
- errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize))
- } else {
- errorInfo
+ if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) {
+ // Generate a new unique key for this exception.
+ val errorId = UUID.randomUUID().toString
+
+ errorInfo.putMetadata("errorId", errorId)
+
+ sessionHolder.errorIdToError
+ .put(errorId, st)
}
+ lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
+ val withStackTrace =
+ if (sessionHolder.session.conf.get(
+ SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) {
+ val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE)
+ errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize))
+ } else {
+ errorInfo
+ }
+
RPCStatus
.newBuilder()
.setCode(RPCCode.INTERNAL_VALUE)
@@ -107,21 +184,19 @@ private[connect] object ErrorUtils extends Logging {
sessionId: String,
events: Option[ExecuteEventsManager] = None,
isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = {
- val session =
+ val sessionHolder =
SparkConnectService
.getOrCreateIsolatedSession(userId, sessionId)
- .session
- val stackTraceEnabled = session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED)
val partial: PartialFunction[Throwable, (Throwable, Throwable)] = {
case se: SparkException if isPythonExecutionException(se) =>
(
se,
StatusProto.toStatusRuntimeException(
- buildStatusFromThrowable(se.getCause, stackTraceEnabled)))
+ buildStatusFromThrowable(se.getCause, sessionHolder)))
case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
- (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled)))
+ (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolder)))
case e: Throwable =>
(
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
new file mode 100644
index 000000000000..c0591dcc9c7b
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import scala.concurrent.Promise
+import scala.concurrent.duration._
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.FetchErrorDetailsResponse
+import org.apache.spark.sql.connect.ResourceHelper
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.ThreadUtils
+
+private class FetchErrorDetailsResponseObserver(p: Promise[FetchErrorDetailsResponse])
+ extends StreamObserver[FetchErrorDetailsResponse] {
+ override def onNext(v: FetchErrorDetailsResponse): Unit = p.success(v)
+ override def onError(throwable: Throwable): Unit = throw throwable
+ override def onCompleted(): Unit = {}
+}
+
+class FetchErrorDetailsHandlerSuite extends SharedSparkSession with ResourceHelper {
+
+ private val userId = "user1"
+
+ private val sessionId = UUID.randomUUID().toString
+
+ private def fetchErrorDetails(
+ userId: String,
+ sessionId: String,
+ errorId: String): FetchErrorDetailsResponse = {
+ val promise = Promise[FetchErrorDetailsResponse]
+ val handler =
+ new SparkConnectFetchErrorDetailsHandler(new FetchErrorDetailsResponseObserver(promise))
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId(userId)
+ .build()
+ val request = proto.FetchErrorDetailsRequest
+ .newBuilder()
+ .setUserContext(context)
+ .setSessionId(sessionId)
+ .setErrorId(errorId)
+ .build()
+ handler.handle(request)
+ ThreadUtils.awaitResult(promise.future, 5.seconds)
+ }
+
+ for (serverStacktraceEnabled <- Seq(false, true)) {
+ test(s"error chain is properly constructed - $serverStacktraceEnabled") {
+ val testError =
+ new Exception("test1", new Exception("test2"))
+ val errorId = UUID.randomUUID().toString()
+
+ val sessionHolder = SparkConnectService
+ .getOrCreateIsolatedSession(userId, sessionId)
+
+ sessionHolder.errorIdToError.put(errorId, testError)
+
+ sessionHolder.session.conf
+ .set(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key, serverStacktraceEnabled)
+ sessionHolder.session.conf
+ .set(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key, false)
+ try {
+ val response = fetchErrorDetails(userId, sessionId, errorId)
+ assert(response.hasRootErrorIdx)
+ assert(response.getRootErrorIdx == 0)
+
+ assert(response.getErrorsCount == 2)
+ assert(response.getErrors(0).getMessage == "test1")
+ assert(response.getErrors(0).getErrorTypeHierarchyCount == 3)
+ assert(response.getErrors(0).getErrorTypeHierarchy(0) == classOf[Exception].getName)
+ assert(response.getErrors(0).getErrorTypeHierarchy(1) == classOf[Throwable].getName)
+ assert(response.getErrors(0).getErrorTypeHierarchy(2) == classOf[Object].getName)
+ assert(response.getErrors(0).hasCauseIdx)
+ assert(response.getErrors(0).getCauseIdx == 1)
+
+ assert(response.getErrors(1).getMessage == "test2")
+ assert(response.getErrors(1).getErrorTypeHierarchyCount == 3)
+ assert(response.getErrors(1).getErrorTypeHierarchy(0) == classOf[Exception].getName)
+ assert(response.getErrors(1).getErrorTypeHierarchy(1) == classOf[Throwable].getName)
+ assert(response.getErrors(1).getErrorTypeHierarchy(2) == classOf[Object].getName)
+ assert(!response.getErrors(1).hasCauseIdx)
+ if (serverStacktraceEnabled) {
+ assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length)
+ assert(
+ response.getErrors(1).getStackTraceCount ==
+ testError.getCause.getStackTrace.length)
+ } else {
+ assert(response.getErrors(0).getStackTraceCount == 0)
+ assert(response.getErrors(1).getStackTraceCount == 0)
+ }
+ } finally {
+ sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key)
+ sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key)
+ }
+ }
+ }
+
+ test("error not found") {
+ val response = fetchErrorDetails(userId, sessionId, UUID.randomUUID().toString())
+ assert(!response.hasRootErrorIdx)
+ }
+
+ test("invalidate cached exceptions after first request") {
+ val testError = new Exception("test1")
+ val errorId = UUID.randomUUID().toString()
+
+ SparkConnectService
+ .getOrCreateIsolatedSession(userId, sessionId)
+ .errorIdToError
+ .put(errorId, testError)
+
+ val response = fetchErrorDetails(userId, sessionId, errorId)
+ assert(response.hasRootErrorIdx)
+ assert(response.getRootErrorIdx == 0)
+
+ assert(response.getErrorsCount == 1)
+ assert(response.getErrors(0).getMessage == "test1")
+
+ assert(
+ SparkConnectService
+ .getOrCreateIsolatedSession(userId, sessionId)
+ .errorIdToError
+ .size() == 0)
+ }
+
+ test("error chain is truncated after reaching max depth") {
+ var testError = new Exception("test")
+ for (i <- 0 until 2 * ErrorUtils.MAX_ERROR_CHAIN_LENGTH) {
+ val errorId = UUID.randomUUID().toString()
+
+ SparkConnectService
+ .getOrCreateIsolatedSession(userId, sessionId)
+ .errorIdToError
+ .put(errorId, testError)
+
+ val response = fetchErrorDetails(userId, sessionId, errorId)
+ val expectedErrorCount = Math.min(i + 1, ErrorUtils.MAX_ERROR_CHAIN_LENGTH)
+ assert(response.getErrorsCount == expectedErrorCount)
+ assert(response.getErrors(expectedErrorCount - 1).hasCauseIdx == false)
+
+ testError = new Exception(s"test$i", testError)
+ }
+ }
+}
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 731f4445e150..2bde0677e4b7 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -197,6 +197,14 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11157
_RELEASEEXECUTERESPONSE._serialized_start = 11186
_RELEASEEXECUTERESPONSE._serialized_end = 11298
- _SPARKCONNECTSERVICE._serialized_start = 11301
- _SPARKCONNECTSERVICE._serialized_end = 12044
+ _FETCHERRORDETAILSREQUEST._serialized_start = 11301
+ _FETCHERRORDETAILSREQUEST._serialized_end = 11448
+ _FETCHERRORDETAILSRESPONSE._serialized_start = 11451
+ _FETCHERRORDETAILSRESPONSE._serialized_end = 11997
+ _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11596
+ _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11751
+ _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 11754
+ _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 11978
+ _SPARKCONNECTSERVICE._serialized_start = 12000
+ _SPARKCONNECTSERVICE._serialized_end = 12849
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 3dca29230ef2..43254ceb2560 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -2728,3 +2728,183 @@ class ReleaseExecuteResponse(google.protobuf.message.Message):
) -> typing_extensions.Literal["operation_id"] | None: ...
global___ReleaseExecuteResponse = ReleaseExecuteResponse
+
+class FetchErrorDetailsRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ USER_CONTEXT_FIELD_NUMBER: builtins.int
+ ERROR_ID_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ """(Required)
+ The session_id specifies a Spark session for a user identified by user_context.user_id.
+ The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`.
+ """
+ @property
+ def user_context(self) -> global___UserContext:
+ """User context"""
+ error_id: builtins.str
+ """(Required)
+ The id of the error.
+ """
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ user_context: global___UserContext | None = ...,
+ error_id: builtins.str = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["user_context",
b"user_context"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "error_id", b"error_id", "session_id", b"session_id",
"user_context", b"user_context"
+ ],
+ ) -> None: ...
+
+global___FetchErrorDetailsRequest = FetchErrorDetailsRequest
+
+class FetchErrorDetailsResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class StackTraceElement(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ DECLARING_CLASS_FIELD_NUMBER: builtins.int
+ METHOD_NAME_FIELD_NUMBER: builtins.int
+ FILE_NAME_FIELD_NUMBER: builtins.int
+ LINE_NUMBER_FIELD_NUMBER: builtins.int
+ declaring_class: builtins.str
+ """The fully qualified name of the class containing the execution
point."""
+ method_name: builtins.str
+ """The name of the method containing the execution point."""
+ file_name: builtins.str
+ """The name of the file containing the execution point."""
+ line_number: builtins.int
+ """The line number of the source line containing the execution
point."""
+ def __init__(
+ self,
+ *,
+ declaring_class: builtins.str = ...,
+ method_name: builtins.str = ...,
+ file_name: builtins.str = ...,
+ line_number: builtins.int = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "declaring_class",
+ b"declaring_class",
+ "file_name",
+ b"file_name",
+ "line_number",
+ b"line_number",
+ "method_name",
+ b"method_name",
+ ],
+ ) -> None: ...
+
+ class Error(google.protobuf.message.Message):
+ """Error defines the schema for the representing exception."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ ERROR_TYPE_HIERARCHY_FIELD_NUMBER: builtins.int
+ MESSAGE_FIELD_NUMBER: builtins.int
+ STACK_TRACE_FIELD_NUMBER: builtins.int
+ CAUSE_IDX_FIELD_NUMBER: builtins.int
+ @property
+ def error_type_hierarchy(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """The fully qualified names of the exception class and its parent
classes."""
+ message: builtins.str
+ """The detailed message of the exception."""
+ @property
+ def stack_trace(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___FetchErrorDetailsResponse.StackTraceElement
+ ]:
+ """The stackTrace of the exception. It will be set
+ if the SQLConf spark.sql.connect.serverStacktrace.enabled is true.
+ """
+ cause_idx: builtins.int
+ """The index of the cause error in errors."""
+ def __init__(
+ self,
+ *,
+ error_type_hierarchy: collections.abc.Iterable[builtins.str] | None = ...,
+ message: builtins.str = ...,
+ stack_trace: collections.abc.Iterable[
+ global___FetchErrorDetailsResponse.StackTraceElement
+ ]
+ | None = ...,
+ cause_idx: builtins.int | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_cause_idx", b"_cause_idx", "cause_idx", b"cause_idx"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_cause_idx",
+ b"_cause_idx",
+ "cause_idx",
+ b"cause_idx",
+ "error_type_hierarchy",
+ b"error_type_hierarchy",
+ "message",
+ b"message",
+ "stack_trace",
+ b"stack_trace",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_cause_idx",
b"_cause_idx"]
+ ) -> typing_extensions.Literal["cause_idx"] | None: ...
+
+ ROOT_ERROR_IDX_FIELD_NUMBER: builtins.int
+ ERRORS_FIELD_NUMBER: builtins.int
+ root_error_idx: builtins.int
+ """The index of the root error in errors. The field will not be set if the
error is not found."""
+ @property
+ def errors(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___FetchErrorDetailsResponse.Error
+ ]:
+ """A list of errors."""
+ def __init__(
+ self,
+ *,
+ root_error_idx: builtins.int | None = ...,
+ errors: collections.abc.Iterable[global___FetchErrorDetailsResponse.Error] | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_root_error_idx", b"_root_error_idx", "root_error_idx",
b"root_error_idx"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_root_error_idx",
+ b"_root_error_idx",
+ "errors",
+ b"errors",
+ "root_error_idx",
+ b"root_error_idx",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_root_error_idx",
b"_root_error_idx"]
+ ) -> typing_extensions.Literal["root_error_idx"] | None: ...
+
+global___FetchErrorDetailsResponse = FetchErrorDetailsResponse
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index e6bfda8a40a8..f6c5573ded6b 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -70,6 +70,11 @@ class SparkConnectServiceStub(object):
request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString,
response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString,
)
+ self.FetchErrorDetails = channel.unary_unary(
+ "/spark.connect.SparkConnectService/FetchErrorDetails",
+ request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString,
+ response_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString,
+ )
class SparkConnectServiceServicer(object):
@@ -136,6 +141,12 @@ class SparkConnectServiceServicer(object):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
+ def FetchErrorDetails(self, request, context):
+ """FetchErrorDetails retrieves the matched exception with details
based on a provided error id."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
def add_SparkConnectServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -179,6 +190,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server):
request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString,
),
+ "FetchErrorDetails": grpc.unary_unary_rpc_method_handler(
+ servicer.FetchErrorDetails,
+ request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString,
+ response_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
"spark.connect.SparkConnectService", rpc_method_handlers
@@ -421,3 +437,32 @@ class SparkConnectService(object):
timeout,
metadata,
)
+
+ @staticmethod
+ def FetchErrorDetails(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/FetchErrorDetails",
+ spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString,
+ spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]