This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 1ca543b5595e [SPARK-45808][CONNECT][PYTHON] Better error handling for SQL Exceptions
1ca543b5595e is described below
commit 1ca543b5595ebfff4c46500df0ef7715c440c050
Author: Martin Grund <[email protected]>
AuthorDate: Tue Nov 7 10:12:16 2023 -0800
[SPARK-45808][CONNECT][PYTHON] Better error handling for SQL Exceptions
### What changes were proposed in this pull request?
This patch improves the handling of errors reported back to Python. First,
it properly allows the extraction of the `ERROR_CLASS` and the `SQL_STATE` and
gives simpler access to the stack trace.
It also makes sure that whether the stack trace is displayed is no longer
decided only on the server side but becomes a local usability property.
In addition, the following methods on `SparkConnectGrpcException` now return
meaningful values (see the usage sketch after the list):
* `getSqlState()`
* `getErrorClass()`
* `getStackTrace()`
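A minimal, hypothetical usage sketch (assumptions: an active Spark Connect session bound to `spark`; the failing query, the column name, and the example values in the comments are illustrative only, not taken from this patch):
```python
from pyspark.errors.exceptions.connect import AnalysisException

try:
    spark.sql("select some_missing_column").collect()
except AnalysisException as e:
    # Structured error information extracted from the gRPC error metadata.
    print(e.getErrorClass())  # e.g. an error class such as "UNRESOLVED_COLUMN..."
    print(e.getSqlState())    # e.g. a SQLSTATE such as "42703"
    # Server-side stack trace as a string, or None if it was not sent.
    print(e.getStackTrace())
    # Whether str(e) appends the "JVM stacktrace" section is now driven by the
    # client-checked confs spark.sql.connect.serverStacktrace.enabled and
    # spark.sql.pyspark.jvmStacktrace.enabled.
    print(str(e))
```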
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Updated the existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43667 from grundprinzip/SPARK-XXXX-ex.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 3 +-
.../SparkConnectFetchErrorDetailsHandler.scala | 6 +-
.../spark/sql/connect/utils/ErrorUtils.scala | 14 ++
.../service/FetchErrorDetailsHandlerSuite.scala | 14 +-
.../service/SparkConnectSessionHolderSuite.scala | 102 ++++++------
python/pyspark/errors/exceptions/base.py | 2 +-
python/pyspark/errors/exceptions/captured.py | 2 +-
python/pyspark/errors/exceptions/connect.py | 178 ++++++++++++++++++---
python/pyspark/sql/connect/client/core.py | 13 +-
.../sql/tests/connect/test_connect_basic.py | 25 +--
10 files changed, 258 insertions(+), 101 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index b9fa415034c3..10c928f13041 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -136,8 +136,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
assert(
ex.getStackTrace
.find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
- .isDefined
- == isServerStackTraceEnabled)
+ .isDefined)
}
}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
index 17a6e9e434f3..b5a3c986d169 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
@@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
-import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.sql.internal.SQLConf
/**
 * Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler
@@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler(
ErrorUtils.throwableToFetchErrorDetailsResponse(
st = error,
- serverStackTraceEnabled = sessionHolder.session.conf.get(
-          Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get(
- SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
+ serverStackTraceEnabled = true)
}
.getOrElse(FetchErrorDetailsResponse.newBuilder().build())
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 744fa3c8aa1a..7cb555ca47ec 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging {
"classes",
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
+    // Add the SQL State and Error Class to the response metadata of the ErrorInfoObject.
+ st match {
+ case e: SparkThrowable =>
+ val state = e.getSqlState
+ if (state != null && state.nonEmpty) {
+ errorInfo.putMetadata("sqlState", state)
+ }
+ val errorClass = e.getErrorClass
+ if (errorClass != null && errorClass.nonEmpty) {
+ errorInfo.putMetadata("errorClass", errorClass)
+ }
+ case _ =>
+ }
+
    if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) {
// Generate a new unique key for this exception.
val errorId = UUID.randomUUID().toString
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
index 40439a217230..ebcd1de60057 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
@@ -103,15 +103,11 @@ class FetchErrorDetailsHandlerSuite extends SharedSparkSession with ResourceHelp
assert(response.getErrors(1).getErrorTypeHierarchy(1) == classOf[Throwable].getName)
assert(response.getErrors(1).getErrorTypeHierarchy(2) == classOf[Object].getName)
assert(!response.getErrors(1).hasCauseIdx)
- if (serverStacktraceEnabled) {
- assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length)
- assert(
- response.getErrors(1).getStackTraceCount == testError.getCause.getStackTrace.length)
- } else {
- assert(response.getErrors(0).getStackTraceCount == 0)
- assert(response.getErrors(1).getStackTraceCount == 0)
- }
+ assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length)
+ assert(
+ response.getErrors(1).getStackTraceCount == testError.getCause.getStackTrace.length)
+
} finally {
sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key)
sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 910c2a2650c6..9845cee31037 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -169,6 +169,56 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
accumulator = null)
}
+ test("python listener process: process terminates after listener is
removed") {
+ // scalastyle:off assume
+ assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
+ // scalastyle:on assume
+
+ val sessionHolder = SessionHolder.forTesting(spark)
+ try {
+ SparkConnectService.start(spark.sparkContext)
+
+ val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
+
+ val id1 = "listener_removeListener_test_1"
+ val id2 = "listener_removeListener_test_2"
+ val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+ val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+
+ sessionHolder.cacheListenerById(id1, listener1)
+ spark.streams.addListener(listener1)
+ sessionHolder.cacheListenerById(id2, listener2)
+ spark.streams.addListener(listener2)
+
+ val (runner1, runner2) = (listener1.runner, listener2.runner)
+
+ // assert both python processes are running
+ assert(!runner1.isWorkerStopped().get)
+ assert(!runner2.isWorkerStopped().get)
+
+ // remove listener1
+ spark.streams.removeListener(listener1)
+ sessionHolder.removeCachedListener(id1)
+ // assert listener1's python process is not running
+ eventually(timeout(30.seconds)) {
+ assert(runner1.isWorkerStopped().get)
+ assert(!runner2.isWorkerStopped().get)
+ }
+
+ // remove listener2
+ spark.streams.removeListener(listener2)
+ sessionHolder.removeCachedListener(id2)
+ eventually(timeout(30.seconds)) {
+ // assert listener2's python process is not running
+ assert(runner2.isWorkerStopped().get)
+ // all listeners are removed
+ assert(spark.streams.listListeners().isEmpty)
+ }
+ } finally {
+ SparkConnectService.stop()
+ }
+ }
+
test("python foreachBatch process: process terminates after query is
stopped") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
@@ -232,58 +282,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
assert(spark.streams.listListeners().length == 1) // only process termination listener
} finally {
SparkConnectService.stop()
+ // Wait for things to calm down.
+ Thread.sleep(4.seconds.toMillis)
// remove process termination listener
spark.streams.listListeners().foreach(spark.streams.removeListener)
}
}
-
- test("python listener process: process terminates after listener is
removed") {
- // scalastyle:off assume
- assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
- // scalastyle:on assume
-
- val sessionHolder = SessionHolder.forTesting(spark)
- try {
- SparkConnectService.start(spark.sparkContext)
-
- val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
-
- val id1 = "listener_removeListener_test_1"
- val id2 = "listener_removeListener_test_2"
- val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
- val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
-
- sessionHolder.cacheListenerById(id1, listener1)
- spark.streams.addListener(listener1)
- sessionHolder.cacheListenerById(id2, listener2)
- spark.streams.addListener(listener2)
-
- val (runner1, runner2) = (listener1.runner, listener2.runner)
-
- // assert both python processes are running
- assert(!runner1.isWorkerStopped().get)
- assert(!runner2.isWorkerStopped().get)
-
- // remove listener1
- spark.streams.removeListener(listener1)
- sessionHolder.removeCachedListener(id1)
- // assert listener1's python process is not running
- eventually(timeout(30.seconds)) {
- assert(runner1.isWorkerStopped().get)
- assert(!runner2.isWorkerStopped().get)
- }
-
- // remove listener2
- spark.streams.removeListener(listener2)
- sessionHolder.removeCachedListener(id2)
- eventually(timeout(30.seconds)) {
- // assert listener2's python process is not running
- assert(runner2.isWorkerStopped().get)
- // all listeners are removed
- assert(spark.streams.listListeners().isEmpty)
- }
- } finally {
- SparkConnectService.stop()
- }
- }
}
diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index 1d09a68dffbf..518a2d99ce88 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -75,7 +75,7 @@ class PySparkException(Exception):
"""
return self.message_parameters
- def getSqlState(self) -> None:
+ def getSqlState(self) -> Optional[str]:
"""
Returns an SQLSTATE as a string.
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index d62b7d24347e..55ed7ab3a6d5 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -107,7 +107,7 @@ class CapturedException(PySparkException):
else:
return None
- def getSqlState(self) -> Optional[str]: # type: ignore[override]
+ def getSqlState(self) -> Optional[str]:
assert SparkContext._gateway is not None
gw = SparkContext._gateway
if self._origin is not None and is_instance_of(
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index 423fb2c6f0ac..2558c425469a 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -46,55 +46,155 @@ class SparkConnectException(PySparkException):
def convert_exception(
- info: "ErrorInfo", truncated_message: str, resp:
Optional[pb2.FetchErrorDetailsResponse]
+ info: "ErrorInfo",
+ truncated_message: str,
+ resp: Optional[pb2.FetchErrorDetailsResponse],
+ display_server_stacktrace: bool = False,
) -> SparkConnectException:
classes = []
+ sql_state = None
+ error_class = None
+
+ stacktrace: Optional[str] = None
+
if "classes" in info.metadata:
classes = json.loads(info.metadata["classes"])
+ if "sqlState" in info.metadata:
+ sql_state = info.metadata["sqlState"]
+
+ if "errorClass" in info.metadata:
+ error_class = info.metadata["errorClass"]
+
if resp is not None and resp.HasField("root_error_idx"):
message = resp.errors[resp.root_error_idx].message
stacktrace = _extract_jvm_stacktrace(resp)
else:
message = truncated_message
- stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else ""
-
- if len(stacktrace) > 0:
- message += f"\n\nJVM stacktrace:\n{stacktrace}"
+ stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None
+ display_server_stacktrace = display_server_stacktrace if stacktrace is not None else False
if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
- return ParseException(message)
+ return ParseException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
# Order matters. ParseException inherits AnalysisException.
elif "org.apache.spark.sql.AnalysisException" in classes:
- return AnalysisException(message)
+ return AnalysisException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
- return StreamingQueryException(message)
+ return StreamingQueryException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
- return QueryExecutionException(message)
+ return QueryExecutionException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
# Order matters. NumberFormatException inherits IllegalArgumentException.
elif "java.lang.NumberFormatException" in classes:
- return NumberFormatException(message)
+ return NumberFormatException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "java.lang.IllegalArgumentException" in classes:
- return IllegalArgumentException(message)
+ return IllegalArgumentException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "java.lang.ArithmeticException" in classes:
- return ArithmeticException(message)
+ return ArithmeticException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "java.lang.UnsupportedOperationException" in classes:
- return UnsupportedOperationException(message)
+ return UnsupportedOperationException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
- return ArrayIndexOutOfBoundsException(message)
+ return ArrayIndexOutOfBoundsException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "java.time.DateTimeException" in classes:
- return DateTimeException(message)
+ return DateTimeException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "org.apache.spark.SparkRuntimeException" in classes:
- return SparkRuntimeException(message)
+ return SparkRuntimeException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "org.apache.spark.SparkUpgradeException" in classes:
- return SparkUpgradeException(message)
+ return SparkUpgradeException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
elif "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
"\n An exception was thrown from the Python worker. "
"Please see the stack trace below.\n%s" % message
)
+ # Make sure that the generic SparkException is handled last.
+ elif "org.apache.spark.SparkException" in classes:
+ return SparkException(
+ message,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
else:
- return SparkConnectGrpcException(message, reason=info.reason)
+ return SparkConnectGrpcException(
+ message,
+ reason=info.reason,
+ error_class=error_class,
+ sql_state=sql_state,
+ server_stacktrace=stacktrace,
+ display_server_stacktrace=display_server_stacktrace,
+ )
def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
@@ -106,7 +206,7 @@ def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None:
message = f"{error.error_type_hierarchy[0]}: {error.message}"
if len(lines) == 0:
- lines.append(message)
+ lines.append(error.error_type_hierarchy[0])
else:
lines.append(f"Caused by: {message}")
for elem in error.stack_trace:
@@ -135,16 +235,48 @@ class SparkConnectGrpcException(SparkConnectException):
error_class: Optional[str] = None,
message_parameters: Optional[Dict[str, str]] = None,
reason: Optional[str] = None,
+ sql_state: Optional[str] = None,
+ server_stacktrace: Optional[str] = None,
+ display_server_stacktrace: bool = False,
) -> None:
self.message = message # type: ignore[assignment]
if reason is not None:
self.message = f"({reason}) {self.message}"
+ # PySparkException has the assumption that error_class and message_parameters are
+ # only occurring together. If only one is set, we assume the message to be fully
+ # parsed.
+ tmp_error_class = error_class
+ tmp_message_parameters = message_parameters
+ if error_class is not None and message_parameters is None:
+ tmp_error_class = None
+ elif error_class is None and message_parameters is not None:
+ tmp_message_parameters = None
+
super().__init__(
message=self.message,
- error_class=error_class,
- message_parameters=message_parameters,
+ error_class=tmp_error_class,
+ message_parameters=tmp_message_parameters,
)
+ self.error_class = error_class
+ self._sql_state: Optional[str] = sql_state
+ self._stacktrace: Optional[str] = server_stacktrace
+ self._display_stacktrace: bool = display_server_stacktrace
+
+ def getSqlState(self) -> Optional[str]:
+ if self._sql_state is not None:
+ return self._sql_state
+ else:
+ return super().getSqlState()
+
+ def getStackTrace(self) -> Optional[str]:
+ return self._stacktrace
+
+ def __str__(self) -> str:
+ desc = self.message
+ if self._display_stacktrace:
+ desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
+ return desc
class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
@@ -223,3 +355,7 @@ class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException
"""
Exception thrown because of Spark upgrade from Spark Connect.
"""
+
+
+class SparkException(SparkConnectGrpcException):
+ """ """
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 11a1112ad1fe..cef0ea4f305d 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1564,6 +1564,14 @@ class SparkConnectClient(object):
except grpc.RpcError:
return None
+ def _display_server_stack_trace(self) -> bool:
+ from pyspark.sql.connect.conf import RuntimeConf
+
+ conf = RuntimeConf(self)
+ if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
+ return True
+ return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
+
def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server side, certain
@@ -1594,7 +1602,10 @@ class SparkConnectClient(object):
d.Unpack(info)
raise convert_exception(
- info, status.message, self._fetch_enriched_error(info)
+ info,
+ status.message,
+ self._fetch_enriched_error(info),
+ self._display_server_stack_trace(),
) from None
raise SparkConnectGrpcException(status.message) from None
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f024a03c2686..daf6772e52bf 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3378,35 +3378,37 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
"""select from_json(
'{"d": "02-29"}', 'd date', map('dateFormat',
'MM-dd'))"""
).collect()
- self.assertTrue("JVM stacktrace" in e.exception.message)
- self.assertTrue("org.apache.spark.SparkUpgradeException:" in
e.exception.message)
+ self.assertTrue("JVM stacktrace" in str(e.exception))
+ self.assertTrue("org.apache.spark.SparkUpgradeException" in
str(e.exception))
self.assertTrue(
"at org.apache.spark.sql.errors.ExecutionErrors"
- ".failToParseDateTimeInNewParserError" in
e.exception.message
+ ".failToParseDateTimeInNewParserError" in str(e.exception)
)
- self.assertTrue("Caused by: java.time.DateTimeException:" in
e.exception.message)
+ self.assertTrue("Caused by: java.time.DateTimeException:" in
str(e.exception))
def test_not_hitting_netty_header_limit(self):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException):
- self.spark.sql("select " + "test" * 10000).collect()
+ self.spark.sql("select " + "test" * 1).collect()
def test_error_stack_trace(self):
with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
- self.assertTrue("JVM stacktrace" in e.exception.message)
+ self.assertTrue("JVM stacktrace" in str(e.exception))
+ self.assertIsNotNone(e.exception.getStackTrace())
self.assertTrue(
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis"
in e.exception.message
+ "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis"
in str(e.exception)
)
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
- self.assertFalse("JVM stacktrace" in e.exception.message)
+ self.assertFalse("JVM stacktrace" in str(e.exception))
+ self.assertIsNone(e.exception.getStackTrace())
self.assertFalse(
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis"
in e.exception.message
+ "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis"
in str(e.exception)
)
# Create a new session with a different stack trace size.
@@ -3421,9 +3423,10 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True)
with self.assertRaises(AnalysisException) as e:
spark.sql("select x").collect()
- self.assertTrue("JVM stacktrace" in e.exception.message)
+ self.assertTrue("JVM stacktrace" in str(e.exception))
+ self.assertIsNotNone(e.exception.getStackTrace())
self.assertFalse(
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in
e.exception.message
+ "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in
str(e.exception)
)
spark.stop()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]