This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d62a8e63e86d [SPARK-44463][SS][CONNECT] Improve error handling for Connect streaming Python worker
d62a8e63e86d is described below
commit d62a8e63e86d5d50974bd699cfa49f102a7acf28
Author: bogao007 <[email protected]>
AuthorDate: Wed Sep 20 12:44:30 2023 +0900
[SPARK-44463][SS][CONNECT] Improve error handling for Connect streaming Python worker
### What changes were proposed in this pull request?
Handle errors inside the streaming Python workers (foreach_batch_worker and listener_worker) and propagate them to the server side:
- Write 0 to the Python worker's outfile if no error occurs.
- Write -2 followed by the traceback to the outfile if an error occurs, so the server side can read it.
This follows the approach already used by [another existing Python worker](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/python/pyspark/sql/worker/analyze_udtf.py#L157-L160).
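For illustration only, here is a minimal, self-contained sketch of that worker-side protocol, using the existing `pyspark.serializers` helpers with an in-memory buffer standing in for the worker's real outfile (the `run_one_batch` helper and the buffer are assumptions of this sketch, not part of the patch):
```
import io
import traceback

from pyspark.serializers import SpecialLengths, write_int, write_with_length


def run_one_batch(user_fn, outfile):
    # Mirror the worker pattern: write 0 on success, or
    # SpecialLengths.PYTHON_EXCEPTION_THROWN (-2) plus the traceback on failure.
    try:
        user_fn()
        write_int(0, outfile)
    except BaseException:
        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
        write_with_length(traceback.format_exc().encode("utf-8"), outfile)
    outfile.flush()


def failing_fn():
    raise RuntimeError("this should fail the query")


buf = io.BytesIO()
run_one_batch(failing_fn, buf)
print(buf.getvalue()[:4])  # b'\xff\xff\xff\xfe', i.e. -2 in big-endian
```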
### Why are the changes needed?
Without this change, there is no error handling in the streaming Python workers.
The server side is [expecting](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala#L128-L129) 0 to be written by the [Python worker's](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
[...]
If we remove the [lines](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala#L128-L129) that read the Python worker's output, the streaming query would succeed even when the foreachBatch function raises an error, which is not the desired behavior.
With this PR, errors are propagated from the Python worker to the server, so they fail the streaming query.
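For context, the check the server now performs can be mimicked in a few lines of Python against an in-memory buffer; the message text below is made up for the example, and the real server-side logic is the Scala code in the diff further down:
```
import io

from pyspark.serializers import SpecialLengths, read_int, write_int, write_with_length

# Simulate a worker that hit an error: -2 followed by a length-prefixed traceback.
buf = io.BytesIO()
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, buf)
write_with_length(b"RuntimeError: this should fail the query", buf)
buf.seek(0)

ret = read_int(buf)
if ret == 0:
    print("Python worker completed the batch")
elif ret == SpecialLengths.PYTHON_EXCEPTION_THROWN:
    length = read_int(buf)
    msg = buf.read(length).decode("utf-8")
    # The Scala side throws here, which terminates the streaming query.
    print(f"Found error inside foreachBatch Python process: {msg}")
else:
    print(f"Unexpected return value {ret} from the Python worker.")
```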
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Enabled the `test_streaming_foreach_batch_propagates_python_errors` test.
Also did manual testing.
ForeachBatch:
```
>>> def collectBatch(df, id):
...     raise RuntimeError("this should fail the query")
>>> df = spark.readStream.format("text").load("python/test_support/sql/streaming")
>>> q = df.writeStream.foreachBatch(collectBatch).start()
```
```
23/09/18 14:21:12 ERROR MicroBatchExecution: Query [id = 8168dc4d-02cc-4ddd-996c-96667d928b88, runId = 04829434-767e-4d13-b4c2-e45ce8932223] terminated with error
java.lang.IllegalStateException: Found error inside foreachBatch Python process: Traceback (most recent call last):
  File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py", line 76, in main
  File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py", line 69, in process
  File "<stdin>", line 2, in collectBatch
RuntimeError: this should fail the query
    at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$pythonForeachBatchWrapper$1(StreamingForeachBatchHelper.scala:137)
    at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$pythonForeachBatchWrapper$1$adapted(StreamingForeachBatchHelper.scala:115)
    at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$dataFrameCachingWrapper$1(StreamingForeachBatchHelper.scala:70)
    at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$dataFrameCachingWrapper$1$adapted(StreamingForeachBatchHelper.scala:60)
    at org.apache.spark.sql.execution.streaming.sources.ForeachBatchSink.addBatch(ForeachBatchSink.scala:34)
    at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runBatch$17(MicroBatchExecution.scala:732)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$6(SQLExecution.scala:150)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:241)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$1(SQLExecution.scala:116)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId0(SQLExecution.scala:72)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:196)
    at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runBatch$16(MicroBatchExecution.scala:729)
    at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken(ProgressReporter.scala:427)
    at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken$(ProgressReporter.scala:425)
    at org.apache.spark.sql.execution.streaming.StreamExecution.reportTimeTaken(StreamExecution.scala:67)
    at org.apache.spark.sql.execution.streaming.MicroBatchExecution.runBatch(MicroBatchExecution.scala:729)
    at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runActivatedStream$2(MicroBatchExecution.scala:286)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken(ProgressReporter.scala:427)
    at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken$(ProgressReporter.scala:425)
    at org.apache.spark.sql.execution.streaming.StreamExecution.reportTimeTaken(StreamExecution.scala:67)
    at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runActivatedStream$1(MicroBatchExecution.scala:249)
    at org.apache.spark.sql.execution.streaming.ProcessingTimeExecutor.execute(TriggerExecutor.scala:67)
    at org.apache.spark.sql.execution.streaming.MicroBatchExecution.runActivatedStream(MicroBatchExecution.scala:239)
    at org.apache.spark.sql.execution.streaming.StreamExecution.$anonfun$runStream$1(StreamExecution.scala:311)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
    at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runStream(StreamExecution.scala:289)
    at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.$anonfun$run$1(StreamExecution.scala:211)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
    at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:211)
```
StreamingQueryListener:
```
>>> class TestListener(StreamingQueryListener):
...     def onQueryStarted(self, event):
...         raise RuntimeError("this should fail the listener")
...     def onQueryProgress(self, event):
...         pass
...     def onQueryIdle(self, event):
...         pass
...     def onQueryTerminated(self, event):
...         pass
...
>>> test_listener = TestListener()
>>> spark.streams.addListener(test_listener)
>>> df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
>>> query = df.writeStream.format("noop").queryName("test").start()
>>> query.stop()
```
```
23/09/18 14:18:56 ERROR StreamingQueryListenerBus: Listener PythonStreamingQueryListener threw an exception
java.lang.IllegalStateException: Found error inside Streaming query listener Python process for function onQueryStarted: Traceback (most recent call last):
  File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 90, in main
  File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 78, in process
  File "<stdin>", line 3, in onQueryStarted
RuntimeError: this should fail the listener
    at org.apache.spark.sql.connect.planner.PythonStreamingQueryListener.handlePythonWorkerError(StreamingQueryListenerHelper.scala:88)
    at org.apache.spark.sql.connect.planner.PythonStreamingQueryListener.onQueryStarted(StreamingQueryListenerHelper.scala:50)
    at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.doPostEvent(StreamingQueryListenerBus.scala:131)
    at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.doPostEvent(StreamingQueryListenerBus.scala:43)
    at org.apache.spark.util.ListenerBus.postToAll(ListenerBus.scala:117)
    at org.apache.spark.util.ListenerBus.postToAll$(ListenerBus.scala:101)
    at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.postToAll(StreamingQueryListenerBus.scala:88)
    at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.post(StreamingQueryListenerBus.scala:77)
    at org.apache.spark.sql.streaming.StreamingQueryManager.postListenerEvent(StreamingQueryManager.scala:231)
    at org.apache.spark.sql.execution.streaming.StreamExecution.postEvent(StreamExecution.scala:408)
    at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runStream(StreamExecution.scala:283)
    at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.$anonfun$run$1(StreamExecution.scala:211)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
    at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:211)
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #42986 from bogao007/error-handling.
Authored-by: bogao007 <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../planner/StreamingForeachBatchHelper.scala | 30 ++++++++++++++--
.../planner/StreamingQueryListenerHelper.scala | 42 ++++++++++++++++++++--
.../streaming/worker/foreach_batch_worker.py | 10 ++++--
.../connect/streaming/worker/listener_worker.py | 9 ++++-
.../connect/streaming/test_parity_foreach_batch.py | 1 -
python/pyspark/sql/worker/analyze_udtf.py | 23 ++----------
python/pyspark/util.py | 31 +++++++++++++++-
python/pyspark/worker.py | 24 ++-----------
8 files changed, 116 insertions(+), 54 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index c30e08bc39dd..5ef0aea6b61c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.connect.planner
+import java.io.EOFException
+import java.nio.charset.StandardCharsets
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentMap
@@ -23,7 +25,8 @@ import java.util.concurrent.ConcurrentMap
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
-import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
+import org.apache.spark.SparkException
+import org.apache.spark.api.python.{PythonException, PythonRDD, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.service.SessionHolder
@@ -125,8 +128,29 @@ object StreamingForeachBatchHelper extends Logging {
      dataOut.writeLong(args.batchId)
      dataOut.flush()
-      val ret = dataIn.readInt()
-      logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: $ret)")
+      try {
+        dataIn.readInt() match {
+          case 0 =>
+            logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: 0)")
+          case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+            val exLength = dataIn.readInt()
+            val obj = new Array[Byte](exLength)
+            dataIn.readFully(obj)
+            val msg = new String(obj, StandardCharsets.UTF_8)
+            throw new PythonException(
+              s"Found error inside foreachBatch Python process: $msg",
+              null)
+          case otherValue =>
+            throw new IllegalStateException(
+              s"Unexpected return value $otherValue from the " +
+                s"Python worker.")
+        }
+      } catch {
+        // TODO: Better handling (e.g. retries) on exceptions like EOFException to avoid
+        //  transient errors, same for StreamingQueryListenerHelper.
+        case eof: EOFException =>
+          throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+      }
    }
    (dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder), RunnerCleaner(runner))
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index 01339a8a1b47..886aeab3befd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -17,7 +17,12 @@
package org.apache.spark.sql.connect.planner
-import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
+import java.io.EOFException
+import java.nio.charset.StandardCharsets
+
+import org.apache.spark.SparkException
+import org.apache.spark.api.python.{PythonException, PythonRDD, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
import org.apache.spark.sql.streaming.StreamingQueryListener
@@ -27,7 +32,8 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
 * When a new event is received, it is serialized to json, and passed to the python process.
 */
class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder: SessionHolder)
-    extends StreamingQueryListener {
+    extends StreamingQueryListener
+    with Logging {
  private val port = SparkConnectService.localPort
  private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
@@ -38,33 +44,63 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
    sessionHolder.sessionId,
    "pyspark.sql.connect.streaming.worker.listener_worker")
-  val (dataOut, _) = runner.init()
+  val (dataOut, dataIn) = runner.init()
  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(0)
    dataOut.flush()
+    handlePythonWorkerError("onQueryStarted")
  }
  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(1)
    dataOut.flush()
+    handlePythonWorkerError("onQueryProgress")
  }
  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(2)
    dataOut.flush()
+    handlePythonWorkerError("onQueryIdle")
  }
  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(3)
    dataOut.flush()
+    handlePythonWorkerError("onQueryTerminated")
  }
  private[spark] def stopListenerProcess(): Unit = {
    runner.stop()
  }
+
+  // TODO: Reuse the same method in StreamingForeachBatchHelper to avoid duplication.
+  private def handlePythonWorkerError(functionName: String): Unit = {
+    try {
+      dataIn.readInt() match {
+        case 0 =>
+          logInfo(s"Streaming query listener function $functionName completed (ret: 0)")
+        case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+          val exLength = dataIn.readInt()
+          val obj = new Array[Byte](exLength)
+          dataIn.readFully(obj)
+          val msg = new String(obj, StandardCharsets.UTF_8)
+          throw new PythonException(
+            s"Found error inside Streaming query listener Python " +
+              s"process for function $functionName: $msg",
+            null)
+        case otherValue =>
+          throw new IllegalStateException(
+            s"Unexpected return value $otherValue from the " +
+              s"Python worker.")
+      }
+    } catch {
+      case eof: EOFException =>
+        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+    }
+  }
}
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index 72037f1263db..06534e355de7 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -30,6 +30,7 @@ from pyspark.serializers import (
)
from pyspark import worker
from pyspark.sql import SparkSession
+from pyspark.util import handle_worker_exception
from typing import IO
from pyspark.worker_util import check_python_version
@@ -69,8 +70,13 @@ def main(infile: IO, outfile: IO) -> None:
    while True:
        df_ref_id = utf8_deserializer.loads(infile)
        batch_id = read_long(infile)
-        process(df_ref_id, int(batch_id))  # TODO(SPARK-44463): Propagate error to the user.
-        write_int(0, outfile)
+        # Handle errors inside Python worker. Write 0 to outfile if no errors and write -2 with
+        # traceback string if error occurs.
+        try:
+            process(df_ref_id, int(batch_id))
+            write_int(0, outfile)
+        except BaseException as e:
+            handle_worker_exception(e, outfile)
        outfile.flush()
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index c026945767d9..ed38a7884359 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -31,6 +31,7 @@ from pyspark.serializers import (
)
from pyspark import worker
from pyspark.sql import SparkSession
+from pyspark.util import handle_worker_exception
from typing import IO
from pyspark.sql.streaming.listener import (
@@ -83,7 +84,13 @@ def main(infile: IO, outfile: IO) -> None:
    while True:
        event = utf8_deserializer.loads(infile)
        event_type = read_int(infile)
-        process(event, int(event_type))  # TODO(SPARK-44463): Propagate error to the user.
+        # Handle errors inside Python worker. Write 0 to outfile if no errors and write -2 with
+        # traceback string if error occurs.
+        try:
+            process(event, int(event_type))
+            write_int(0, outfile)
+        except BaseException as e:
+            handle_worker_exception(e, outfile)
        outfile.flush()
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
index c174bd53f8e9..30f7bb8c2df9 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
@@ -23,7 +23,6 @@ from pyspark.errors import PySparkPicklingError
class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
-    @unittest.skip("SPARK-44463: Error handling needs improvement in connect foreachBatch")
    def test_streaming_foreach_batch_propagates_python_errors(self):
        super().test_streaming_foreach_batch_propagates_python_errors()
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 29665b586a36..194cd3db7655 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -18,7 +18,6 @@
import inspect
import os
import sys
-import traceback
from typing import Dict, List, IO, Tuple
from pyspark.accumulators import _accumulatorRegistry
@@ -33,7 +32,7 @@ from pyspark.serializers import (
)
from pyspark.sql.types import _parse_datatype_json_string
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
-from pyspark.util import try_simplify_traceback
+from pyspark.util import handle_worker_exception
from pyspark.worker_util import (
check_python_version,
read_command,
@@ -146,25 +145,7 @@ def main(infile: IO, outfile: IO) -> None:
        write_int(2, outfile)
    except BaseException as e:
-        try:
-            exc_info = None
-            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
-                tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
-                if tb is not None:
-                    e.__cause__ = None
-                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
-            if exc_info is None:
-                exc_info = traceback.format_exc()
-
-            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-            write_with_length(exc_info.encode("utf-8"), outfile)
-        except IOError:
-            # JVM close the socket
-            pass
-        except BaseException:
-            # Write the error to stderr if it happened while serializing
-            print("PySpark worker failed with exception:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+        handle_worker_exception(e, outfile)
        sys.exit(-1)
    send_accumulator_updates(outfile)
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 87f808549d1a..47f5933079e2 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -26,7 +26,7 @@ import threading
import traceback
import typing
from types import TracebackType
-from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple, Union
+from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, Union
from pyspark.errors import PySparkRuntimeError
@@ -382,6 +382,35 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non
    return f  # type: ignore[return-value]
+def handle_worker_exception(e: BaseException, outfile: IO) -> None:
+    """
+    Handles exception for Python worker which writes SpecialLengths.PYTHON_EXCEPTION_THROWN (-2)
+    and exception traceback info to outfile. JVM could then read from the outfile and perform
+    exception handling there.
+    """
+    from pyspark.serializers import write_int, write_with_length, SpecialLengths
+
+    try:
+        exc_info = None
+        if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
+            tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
+            if tb is not None:
+                e.__cause__ = None
+                exc_info = "".join(traceback.format_exception(type(e), e, tb))
+        if exc_info is None:
+            exc_info = traceback.format_exc()
+
+        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
+        write_with_length(exc_info.encode("utf-8"), outfile)
+    except IOError:
+        # JVM close the socket
+        pass
+    except BaseException:
+        # Write the error to stderr if it happened while serializing
+        print("PySpark worker failed with exception:", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+
+
class InheritableThread(threading.Thread):
"""
    Thread that is recommended to be used in PySpark when the pinned thread mode is
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 026c07fb9988..a3c7bbb59ddf 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -25,7 +25,6 @@ from inspect import getfullargspec
import json
from typing import Any, Callable, Iterable, Iterator
-import traceback
import faulthandler
from pyspark.accumulators import _accumulatorRegistry
@@ -34,7 +33,6 @@ from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.resource import ResourceInformation
from pyspark.rdd import PythonEvalType
from pyspark.serializers import (
- write_with_length,
write_int,
read_long,
read_bool,
@@ -54,7 +52,7 @@ from pyspark.sql.pandas.serializers import (
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import BinaryType, Row, StringType, StructType, _parse_datatype_json_string
-from pyspark.util import fail_on_stopiteration, try_simplify_traceback
+from pyspark.util import fail_on_stopiteration, handle_worker_exception
from pyspark import shuffle
from pyspark.errors import PySparkRuntimeError, PySparkTypeError
from pyspark.worker_util import (
@@ -1324,25 +1322,7 @@ def main(infile, outfile):
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
    except BaseException as e:
-        try:
-            exc_info = None
-            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
-                tb = try_simplify_traceback(sys.exc_info()[-1])
-                if tb is not None:
-                    e.__cause__ = None
-                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
-            if exc_info is None:
-                exc_info = traceback.format_exc()
-
-            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-            write_with_length(exc_info.encode("utf-8"), outfile)
-        except IOError:
-            # JVM close the socket
-            pass
-        except BaseException:
-            # Write the error to stderr if it happened while serializing
-            print("PySpark worker failed with exception:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+        handle_worker_exception(e, outfile)
        sys.exit(-1)
finally:
if faulthandler_log_path:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]