This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d62a8e63e86d [SPARK-44463][SS][CONNECT] Improve error handling for Connect streaming Python worker
d62a8e63e86d is described below

commit d62a8e63e86d5d50974bd699cfa49f102a7acf28
Author: bogao007 <[email protected]>
AuthorDate: Wed Sep 20 12:44:30 2023 +0900

    [SPARK-44463][SS][CONNECT] Improve error handling for Connect streaming Python worker
    
    ### What changes were proposed in this pull request?
    
    Handle errors inside streaming Python workers (foreach_batch_worker and listener_worker) and propagate them to the server side.
    - Write 0 to the Python worker's outfile if no error occurs.
    - If an error occurs, write -2 (`SpecialLengths.PYTHON_EXCEPTION_THROWN`) followed by the traceback to the outfile, so that the server side can read it.
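    A minimal, self-contained sketch of the worker-side framing described above. It assumes, as `pyspark.serializers.write_int` does, that integers are written as 4-byte big-endian values. `run_one_batch` and the local `write_int`/`write_with_length` are hypothetical stand-ins, not the PySpark functions, and the traceback formatting skips the `SPARK_SIMPLIFIED_TRACEBACK` and `IOError` handling that `handle_worker_exception` performs.

```python
import io
import struct
import traceback
from typing import BinaryIO, Callable

# Assumed value of SpecialLengths.PYTHON_EXCEPTION_THROWN (documented as -2 in this patch).
PYTHON_EXCEPTION_THROWN = -2


def write_int(value: int, stream: BinaryIO) -> None:
    # 4-byte big-endian signed int, readable by Java's DataInputStream.readInt().
    stream.write(struct.pack("!i", value))


def write_with_length(payload: bytes, stream: BinaryIO) -> None:
    # Length prefix followed by the raw bytes, which the JVM reads with readFully().
    write_int(len(payload), stream)
    stream.write(payload)


def run_one_batch(process: Callable[[], None], outfile: BinaryIO) -> None:
    """Run one unit of work and report its status (hypothetical helper)."""
    try:
        process()
        write_int(0, outfile)  # success marker the server waits for
    except BaseException:
        # Like the patch, catch BaseException; send the error marker, then the
        # UTF-8 traceback text for the server to surface.
        write_int(PYTHON_EXCEPTION_THROWN, outfile)
        write_with_length(traceback.format_exc().encode("utf-8"), outfile)
    outfile.flush()


if __name__ == "__main__":
    ok = io.BytesIO()
    run_one_batch(lambda: None, ok)
    print(ok.getvalue())  # b'\x00\x00\x00\x00'

    def fail() -> None:
        raise RuntimeError("this should fail the query")

    err = io.BytesIO()
    run_one_batch(fail, err)
    print(err.getvalue()[:4])  # b'\xff\xff\xff\xfe'
```

    Writing the status after every batch, rather than only on failure, keeps the server's blocking `readInt()` in lockstep with the worker.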
    
    This follows the existing error handling [here](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/python/pyspark/sql/worker/analyze_udtf.py#L157-L160) in another Python worker.
    
    ### Why are the changes needed?
    
    Without this change, there's no error handling in streaming Python workers. The server side is [expecting](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala#L128-L129) 0 to be written in the [Python worker's](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
 [...]
    
    If we removed the [lines](https://github.com/apache/spark/blob/981312284f0776ca847c8d21411f74a72c639b22/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala#L128-L129) that read the Python worker's output, the streaming query would succeed even if there were an error in the foreachBatch function, which is not the desired behavior.
    
    With this PR, errors from the Python worker are propagated to the server, so that they fail the streaming query.
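    The server-side half can be sketched the same way. The snippet below is a Python mirror, not the actual Scala code, of the `dataIn.readInt() match` logic added to `StreamingForeachBatchHelper` and `PythonStreamingQueryListener`: 0 means success, -2 is followed by a length-prefixed UTF-8 traceback, any other value is a protocol error, and a stream that ends early is treated like the `EOFException` branch. The exception classes and function names are hypothetical, and the framing assumes 4-byte big-endian integers as read by `DataInputStream.readInt()`.

```python
import io
import struct
from typing import BinaryIO

PYTHON_EXCEPTION_THROWN = -2  # assumed SpecialLengths.PYTHON_EXCEPTION_THROWN


class PythonWorkerError(Exception):
    """Hypothetical stand-in for the JVM-side PythonException."""


class WorkerCrashedError(Exception):
    """Hypothetical stand-in for the SparkException raised on EOFException."""


def read_fully(stream: BinaryIO, n: int) -> bytes:
    data = stream.read(n)
    if len(data) < n:
        # Java's readInt()/readFully() throw EOFException when the worker is gone.
        raise WorkerCrashedError("Python worker exited unexpectedly (crashed)")
    return data


def read_int(stream: BinaryIO) -> int:
    return struct.unpack("!i", read_fully(stream, 4))[0]


def check_worker_status(stream: BinaryIO) -> None:
    """Mirror of the `dataIn.readInt() match { ... }` block in the Scala helpers."""
    status = read_int(stream)
    if status == 0:
        return  # the batch (or listener event) completed normally
    if status == PYTHON_EXCEPTION_THROWN:
        length = read_int(stream)
        msg = read_fully(stream, length).decode("utf-8")
        raise PythonWorkerError(f"Found error inside foreachBatch Python process: {msg}")
    # Any other value means the two sides disagree about the protocol.
    raise RuntimeError(f"Unexpected return value {status} from the Python worker.")


if __name__ == "__main__":
    tb = b"Traceback (most recent call last):\nRuntimeError: boom\n"
    frame = struct.pack("!ii", PYTHON_EXCEPTION_THROWN, len(tb)) + tb
    try:
        check_worker_status(io.BytesIO(frame))
    except PythonWorkerError as e:
        print(e)
```

    Raising on the error marker, instead of only logging the returned value as before, is what lets a failing foreachBatch function fail the query.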
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes
    
    ### How was this patch tested?
    
    Enabled the `test_streaming_foreach_batch_propagates_python_errors` test.
    
    Did manual testing:
    ForeachBatch:
    ```
    >>> def collectBatch(df, id):
    ...             raise RuntimeError("this should fail the query")
    >>> df = spark.readStream.format("text").load("python/test_support/sql/streaming")
    >>> q = df.writeStream.foreachBatch(collectBatch).start()
    ```
    
    ```
    23/09/18 14:21:12 ERROR MicroBatchExecution: Query [id = 8168dc4d-02cc-4ddd-996c-96667d928b88, runId = 04829434-767e-4d13-b4c2-e45ce8932223] terminated with error
    java.lang.IllegalStateException: Found error inside foreachBatch Python process: Traceback (most recent call last):
      File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py", line 76, in main
      File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py", line 69, in process
      File "<stdin>", line 2, in collectBatch
    RuntimeError: this should fail the query
    
            at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$pythonForeachBatchWrapper$1(StreamingForeachBatchHelper.scala:137)
            at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$pythonForeachBatchWrapper$1$adapted(StreamingForeachBatchHelper.scala:115)
            at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$dataFrameCachingWrapper$1(StreamingForeachBatchHelper.scala:70)
            at org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper$.$anonfun$dataFrameCachingWrapper$1$adapted(StreamingForeachBatchHelper.scala:60)
            at org.apache.spark.sql.execution.streaming.sources.ForeachBatchSink.addBatch(ForeachBatchSink.scala:34)
            at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runBatch$17(MicroBatchExecution.scala:732)
            at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$6(SQLExecution.scala:150)
            at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:241)
            at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId0$1(SQLExecution.scala:116)
            at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
            at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId0(SQLExecution.scala:72)
            at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:196)
            at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runBatch$16(MicroBatchExecution.scala:729)
            at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken(ProgressReporter.scala:427)
            at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken$(ProgressReporter.scala:425)
            at org.apache.spark.sql.execution.streaming.StreamExecution.reportTimeTaken(StreamExecution.scala:67)
            at org.apache.spark.sql.execution.streaming.MicroBatchExecution.runBatch(MicroBatchExecution.scala:729)
            at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runActivatedStream$2(MicroBatchExecution.scala:286)
            at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
            at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken(ProgressReporter.scala:427)
            at org.apache.spark.sql.execution.streaming.ProgressReporter.reportTimeTaken$(ProgressReporter.scala:425)
            at org.apache.spark.sql.execution.streaming.StreamExecution.reportTimeTaken(StreamExecution.scala:67)
            at org.apache.spark.sql.execution.streaming.MicroBatchExecution.$anonfun$runActivatedStream$1(MicroBatchExecution.scala:249)
            at org.apache.spark.sql.execution.streaming.ProcessingTimeExecutor.execute(TriggerExecutor.scala:67)
            at org.apache.spark.sql.execution.streaming.MicroBatchExecution.runActivatedStream(MicroBatchExecution.scala:239)
            at org.apache.spark.sql.execution.streaming.StreamExecution.$anonfun$runStream$1(StreamExecution.scala:311)
            at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
            at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
            at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runStream(StreamExecution.scala:289)
            at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.$anonfun$run$1(StreamExecution.scala:211)
            at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
            at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
            at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:211)
    ```
    StreamingQueryListener:
    ```
    >>> class TestListener(StreamingQueryListener):
    ...     def onQueryStarted(self, event):
    ...         raise RuntimeError("this should fail the listener")
    ...     def onQueryProgress(self, event):
    ...         pass
    ...     def onQueryIdle(self, event):
    ...         pass
    ...     def onQueryTerminated(self, event):
    ...         pass
    ...
    >>> test_listener = TestListener()
    >>> spark.streams.addListener(test_listener)
    >>> df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
    >>> query = df.writeStream.format("noop").queryName("test").start()
    >>> query.stop()
    ```
    
    ```
    23/09/18 14:18:56 ERROR StreamingQueryListenerBus: Listener PythonStreamingQueryListener threw an exception
    java.lang.IllegalStateException: Found error inside Streaming query listener Python process for function onQueryStarted: Traceback (most recent call last):
      File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 90, in main
      File "/Users/bo.gao/workplace/spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 78, in process
      File "<stdin>", line 3, in onQueryStarted
    RuntimeError: this should fail the listener
    
            at org.apache.spark.sql.connect.planner.PythonStreamingQueryListener.handlePythonWorkerError(StreamingQueryListenerHelper.scala:88)
            at org.apache.spark.sql.connect.planner.PythonStreamingQueryListener.onQueryStarted(StreamingQueryListenerHelper.scala:50)
            at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.doPostEvent(StreamingQueryListenerBus.scala:131)
            at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.doPostEvent(StreamingQueryListenerBus.scala:43)
            at org.apache.spark.util.ListenerBus.postToAll(ListenerBus.scala:117)
            at org.apache.spark.util.ListenerBus.postToAll$(ListenerBus.scala:101)
            at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.postToAll(StreamingQueryListenerBus.scala:88)
            at org.apache.spark.sql.execution.streaming.StreamingQueryListenerBus.post(StreamingQueryListenerBus.scala:77)
            at org.apache.spark.sql.streaming.StreamingQueryManager.postListenerEvent(StreamingQueryManager.scala:231)
            at org.apache.spark.sql.execution.streaming.StreamExecution.postEvent(StreamExecution.scala:408)
            at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runStream(StreamExecution.scala:283)
            at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.$anonfun$run$1(StreamExecution.scala:211)
            at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
            at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
            at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:211)
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #42986 from bogao007/error-handling.
    
    Authored-by: bogao007 <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../planner/StreamingForeachBatchHelper.scala      | 30 ++++++++++++++--
 .../planner/StreamingQueryListenerHelper.scala     | 42 ++++++++++++++++++++--
 .../streaming/worker/foreach_batch_worker.py       | 10 ++++--
 .../connect/streaming/worker/listener_worker.py    |  9 ++++-
 .../connect/streaming/test_parity_foreach_batch.py |  1 -
 python/pyspark/sql/worker/analyze_udtf.py          | 23 ++----------
 python/pyspark/util.py                             | 31 +++++++++++++++-
 python/pyspark/worker.py                           | 24 ++-----------
 8 files changed, 116 insertions(+), 54 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index c30e08bc39dd..5ef0aea6b61c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.connect.planner
 
+import java.io.EOFException
+import java.nio.charset.StandardCharsets
 import java.util.UUID
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.ConcurrentMap
@@ -23,7 +25,8 @@ import java.util.concurrent.ConcurrentMap
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
-import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
+import org.apache.spark.SparkException
+import org.apache.spark.api.python.{PythonException, PythonRDD, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -125,8 +128,29 @@ object StreamingForeachBatchHelper extends Logging {
       dataOut.writeLong(args.batchId)
       dataOut.flush()
 
-      val ret = dataIn.readInt()
-      logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: $ret)")
+      try {
+        dataIn.readInt() match {
+          case 0 =>
+            logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: 0)")
+          case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+            val exLength = dataIn.readInt()
+            val obj = new Array[Byte](exLength)
+            dataIn.readFully(obj)
+            val msg = new String(obj, StandardCharsets.UTF_8)
+            throw new PythonException(
+              s"Found error inside foreachBatch Python process: $msg",
+              null)
+          case otherValue =>
+            throw new IllegalStateException(
+              s"Unexpected return value $otherValue from the " +
+                s"Python worker.")
+        }
+      } catch {
+        // TODO: Better handling (e.g. retries) on exceptions like EOFException to avoid
+        // transient errors, same for StreamingQueryListenerHelper.
+        case eof: EOFException =>
+          throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+      }
     }
 
     (dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder), RunnerCleaner(runner))
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index 01339a8a1b47..886aeab3befd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -17,7 +17,12 @@
 
 package org.apache.spark.sql.connect.planner
 
-import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
+import java.io.EOFException
+import java.nio.charset.StandardCharsets
+
+import org.apache.spark.SparkException
+import org.apache.spark.api.python.{PythonException, PythonRDD, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 
@@ -27,7 +32,8 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
  * When a new event is received, it is serialized to json, and passed to the python process.
   */
 class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder: SessionHolder)
-    extends StreamingQueryListener {
+    extends StreamingQueryListener
+    with Logging {
 
   private val port = SparkConnectService.localPort
   private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
@@ -38,33 +44,63 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
     sessionHolder.sessionId,
     "pyspark.sql.connect.streaming.worker.listener_worker")
 
-  val (dataOut, _) = runner.init()
+  val (dataOut, dataIn) = runner.init()
 
   override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
     PythonRDD.writeUTF(event.json, dataOut)
     dataOut.writeInt(0)
     dataOut.flush()
+    handlePythonWorkerError("onQueryStarted")
   }
 
   override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
     PythonRDD.writeUTF(event.json, dataOut)
     dataOut.writeInt(1)
     dataOut.flush()
+    handlePythonWorkerError("onQueryProgress")
   }
 
   override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
     PythonRDD.writeUTF(event.json, dataOut)
     dataOut.writeInt(2)
     dataOut.flush()
+    handlePythonWorkerError("onQueryIdle")
   }
 
   override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
     PythonRDD.writeUTF(event.json, dataOut)
     dataOut.writeInt(3)
     dataOut.flush()
+    handlePythonWorkerError("onQueryTerminated")
   }
 
   private[spark] def stopListenerProcess(): Unit = {
     runner.stop()
   }
+
+  // TODO: Reuse the same method in StreamingForeachBatchHelper to avoid duplication.
+  private def handlePythonWorkerError(functionName: String): Unit = {
+    try {
+      dataIn.readInt() match {
+        case 0 =>
+          logInfo(s"Streaming query listener function $functionName completed (ret: 0)")
+        case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+          val exLength = dataIn.readInt()
+          val obj = new Array[Byte](exLength)
+          dataIn.readFully(obj)
+          val msg = new String(obj, StandardCharsets.UTF_8)
+          throw new PythonException(
+            s"Found error inside Streaming query listener Python " +
+              s"process for function $functionName: $msg",
+            null)
+        case otherValue =>
+          throw new IllegalStateException(
+            s"Unexpected return value $otherValue from the " +
+              s"Python worker.")
+      }
+    } catch {
+      case eof: EOFException =>
+        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+    }
+  }
 }
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index 72037f1263db..06534e355de7 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -30,6 +30,7 @@ from pyspark.serializers import (
 )
 from pyspark import worker
 from pyspark.sql import SparkSession
+from pyspark.util import handle_worker_exception
 from typing import IO
 from pyspark.worker_util import check_python_version
 
@@ -69,8 +70,13 @@ def main(infile: IO, outfile: IO) -> None:
     while True:
         df_ref_id = utf8_deserializer.loads(infile)
         batch_id = read_long(infile)
-        process(df_ref_id, int(batch_id))  # TODO(SPARK-44463): Propagate error to the user.
-        write_int(0, outfile)
+        # Handle errors inside Python worker. Write 0 to outfile if no errors and write -2 with
+        # traceback string if error occurs.
+        try:
+            process(df_ref_id, int(batch_id))
+            write_int(0, outfile)
+        except BaseException as e:
+            handle_worker_exception(e, outfile)
         outfile.flush()
 
 
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index c026945767d9..ed38a7884359 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -31,6 +31,7 @@ from pyspark.serializers import (
 )
 from pyspark import worker
 from pyspark.sql import SparkSession
+from pyspark.util import handle_worker_exception
 from typing import IO
 
 from pyspark.sql.streaming.listener import (
@@ -83,7 +84,13 @@ def main(infile: IO, outfile: IO) -> None:
     while True:
         event = utf8_deserializer.loads(infile)
         event_type = read_int(infile)
-        process(event, int(event_type))  # TODO(SPARK-44463): Propagate error to the user.
+        # Handle errors inside Python worker. Write 0 to outfile if no errors and write -2 with
+        # traceback string if error occurs.
+        try:
+            process(event, int(event_type))
+            write_int(0, outfile)
+        except BaseException as e:
+            handle_worker_exception(e, outfile)
         outfile.flush()
 
 
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
index c174bd53f8e9..30f7bb8c2df9 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
@@ -23,7 +23,6 @@ from pyspark.errors import PySparkPicklingError
 
 
 class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
-    @unittest.skip("SPARK-44463: Error handling needs improvement in connect foreachBatch")
     def test_streaming_foreach_batch_propagates_python_errors(self):
         super().test_streaming_foreach_batch_propagates_python_errors()
 
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 29665b586a36..194cd3db7655 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -18,7 +18,6 @@
 import inspect
 import os
 import sys
-import traceback
 from typing import Dict, List, IO, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
@@ -33,7 +32,7 @@ from pyspark.serializers import (
 )
 from pyspark.sql.types import _parse_datatype_json_string
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
-from pyspark.util import try_simplify_traceback
+from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
     check_python_version,
     read_command,
@@ -146,25 +145,7 @@ def main(infile: IO, outfile: IO) -> None:
                 write_int(2, outfile)
 
     except BaseException as e:
-        try:
-            exc_info = None
-            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
-                tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
-                if tb is not None:
-                    e.__cause__ = None
-                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
-            if exc_info is None:
-                exc_info = traceback.format_exc()
-
-            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-            write_with_length(exc_info.encode("utf-8"), outfile)
-        except IOError:
-            # JVM close the socket
-            pass
-        except BaseException:
-            # Write the error to stderr if it happened while serializing
-            print("PySpark worker failed with exception:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+        handle_worker_exception(e, outfile)
         sys.exit(-1)
 
     send_accumulator_updates(outfile)
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 87f808549d1a..47f5933079e2 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -26,7 +26,7 @@ import threading
 import traceback
 import typing
 from types import TracebackType
-from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple, Union
+from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, Union
 
 from pyspark.errors import PySparkRuntimeError
 
@@ -382,6 +382,35 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non
         return f  # type: ignore[return-value]
 
 
+def handle_worker_exception(e: BaseException, outfile: IO) -> None:
+    """
+    Handles exception for Python worker which writes SpecialLengths.PYTHON_EXCEPTION_THROWN (-2)
+    and exception traceback info to outfile. JVM could then read from the outfile and perform
+    exception handling there.
+    """
+    from pyspark.serializers import write_int, write_with_length, SpecialLengths
+
+    try:
+        exc_info = None
+        if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
+            tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
+            if tb is not None:
+                e.__cause__ = None
+                exc_info = "".join(traceback.format_exception(type(e), e, tb))
+        if exc_info is None:
+            exc_info = traceback.format_exc()
+
+        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
+        write_with_length(exc_info.encode("utf-8"), outfile)
+    except IOError:
+        # JVM close the socket
+        pass
+    except BaseException:
+        # Write the error to stderr if it happened while serializing
+        print("PySpark worker failed with exception:", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+
+
 class InheritableThread(threading.Thread):
     """
     Thread that is recommended to be used in PySpark when the pinned thread mode is
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 026c07fb9988..a3c7bbb59ddf 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -25,7 +25,6 @@ from inspect import getfullargspec
 import json
 from typing import Any, Callable, Iterable, Iterator
 
-import traceback
 import faulthandler
 
 from pyspark.accumulators import _accumulatorRegistry
@@ -34,7 +33,6 @@ from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.resource import ResourceInformation
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import (
-    write_with_length,
     write_int,
     read_long,
     read_bool,
@@ -54,7 +52,7 @@ from pyspark.sql.pandas.serializers import (
 )
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import BinaryType, Row, StringType, StructType, _parse_datatype_json_string
-from pyspark.util import fail_on_stopiteration, try_simplify_traceback
+from pyspark.util import fail_on_stopiteration, handle_worker_exception
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
 from pyspark.worker_util import (
@@ -1324,25 +1322,7 @@ def main(infile, outfile):
         TaskContext._setTaskContext(None)
         BarrierTaskContext._setTaskContext(None)
     except BaseException as e:
-        try:
-            exc_info = None
-            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
-                tb = try_simplify_traceback(sys.exc_info()[-1])
-                if tb is not None:
-                    e.__cause__ = None
-                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
-            if exc_info is None:
-                exc_info = traceback.format_exc()
-
-            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-            write_with_length(exc_info.encode("utf-8"), outfile)
-        except IOError:
-            # JVM close the socket
-            pass
-        except BaseException:
-            # Write the error to stderr if it happened while serializing
-            print("PySpark worker failed with exception:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+        handle_worker_exception(e, outfile)
         sys.exit(-1)
     finally:
         if faulthandler_log_path:

