This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new e8e330fbbca [SPARK-39218][SS][PYTHON] Make foreachBatch streaming query stop gracefully
e8e330fbbca is described below

commit e8e330fbbca5452e9af0a78e5f2cfae0cc6be134
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri May 20 13:02:17 2022 +0900

    [SPARK-39218][SS][PYTHON] Make foreachBatch streaming query stop gracefully
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to make the `foreachBatch` streaming query stop gracefully
    by recognizing interruption exceptions that surface through Py4J in
    `StreamExecution.isInterruptionException`, so they are treated as a normal
    query stop.
    
    Because there is no straightforward way to access the original JVM
    exception, this change relies on string pattern matching for now (see also
    "Why are the changes needed?" below). Py4J produces the matched error message
    in only one place
    (https://github.com/py4j/py4j/blob/master/py4j-python/src/py4j/protocol.py#L326-L328),
    so the approach should at least work.
    
    ### Why are the changes needed?
    
    In `foreachBatch`, the Python user-defined function in the microbatch runs
    to completion even when `StreamingQuery.stop` is invoked. However, when any
    Py4J access is attempted within the user-defined function:
    
    - With pinned thread mode disabled, the interruption does not reach the
      Python function, which runs to completion in a different thread.
    - With pinned thread mode enabled, the interrupt exception is raised in the
      same thread, and the Python function in turn raises a Py4J exception in
      that thread.
    
    The latter case is a problem because the interrupt exception is first thrown
    on the JVM side (`java.lang.InterruptedException`), then re-raised by the
    Python callback server (`py4j.protocol.Py4JJavaError`), and finally surfaces
    back in the JVM as `py4j.Py4JException`. Since `py4j.Py4JException` is not
    recognized by `StreamExecution.isInterruptionException`, the query does not
    stop gracefully.
    
    Therefore, this exception should be handled in
    `StreamExecution.isInterruptionException`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Such `foreachBatch` queries now stop gracefully instead of terminating
    with an exception.
    
    ### How was this patch tested?
    
    Manually tested with:
    
    ```python
    import time
    
    def func(batch_df, batch_id):
        time.sleep(10)
        print(batch_df.count())
    
    q = 
spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
    time.sleep(5)
    q.stop()
    ```
    
    Closes #36589 from HyukjinKwon/SPARK-39218.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 499de87b77944157828a6d905d9b9df37b7c9a67)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/tests/test_streaming.py                 | 10 ++++++++++
 .../spark/sql/execution/streaming/StreamExecution.scala    | 11 +++++++++++
 .../scala/org/apache/spark/sql/streaming/StreamSuite.scala | 14 ++++++++++++--
 3 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_streaming.py 
b/python/pyspark/sql/tests/test_streaming.py
index 4920423be22..809294d34c3 100644
--- a/python/pyspark/sql/tests/test_streaming.py
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -592,6 +592,16 @@ class StreamingTests(ReusedSQLTestCase):
             if q:
                 q.stop()
 
+    def test_streaming_foreachBatch_graceful_stop(self):
+        # SPARK-39218: Make foreachBatch streaming query stop gracefully
+        def func(batch_df, _):
+            batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000)
+
+        q = 
self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
+        time.sleep(3)  # 'rowsPerSecond' defaults to 1. Waits 3 secs out for 
the input.
+        q.stop()
+        self.assertIsNone(q.exception(), "No exception has to be propagated.")
+
     def test_streaming_read_from_table(self):
         with self.table("input_table", "this_query"):
             self.spark.sql("CREATE TABLE input_table (value string) USING 
parquet")
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index f9ae65cdc47..c7ce9f52e06 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -618,6 +618,13 @@ abstract class StreamExecution(
 object StreamExecution {
   val QUERY_ID_KEY = "sql.streaming.queryId"
   val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing"
+  val IO_EXCEPTION_NAMES = Seq(
+    classOf[InterruptedException].getName,
+    classOf[InterruptedIOException].getName,
+    classOf[ClosedByInterruptException].getName)
+  val PROXY_ERROR = (
+    "py4j.protocol.Py4JJavaError: An error occurred while calling" +
+    s".+(\\r\\n|\\r|\\n): (${IO_EXCEPTION_NAMES.mkString("|")})").r
 
   @scala.annotation.tailrec
   def isInterruptionException(e: Throwable, sc: SparkContext): Boolean = e 
match {
@@ -647,6 +654,10 @@ object StreamExecution {
       } else {
         false
       }
+    // py4j.Py4JException - with pinned thread mode on, the exception can be 
interrupted by Py4J
+    //                      access, for example, in 
`DataFrameWriter.foreachBatch`. See also
+    //                      SPARK-39218.
+    case e: py4j.Py4JException => 
PROXY_ERROR.findFirstIn(e.getMessage).isDefined
     case _ =>
       false
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 71e8ae74fe2..f2031b94231 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -1175,8 +1175,18 @@ class StreamSuite extends StreamTest {
     new ClosedByInterruptException,
     new UncheckedIOException("test", new ClosedByInterruptException),
     new ExecutionException("test", new InterruptedException),
-    new UncheckedExecutionException("test", new InterruptedException))) {
-    test(s"view ${e.getClass.getSimpleName} as a normal query stop") {
+    new UncheckedExecutionException("test", new InterruptedException)) ++
+    Seq(
+      classOf[InterruptedException].getName,
+      classOf[InterruptedIOException].getName,
+      classOf[ClosedByInterruptException].getName).map { s =>
+    new py4j.Py4JException(
+      s"""
+        |py4j.protocol.Py4JJavaError: An error occurred while calling 
o44.count.
+        |: $s
+        |""".stripMargin)
+    }) {
+    test(s"view ${e.getClass.getSimpleName} [${e.getMessage}] as a normal 
query stop") {
       ThrowingExceptionInCreateSource.createSourceLatch = new CountDownLatch(1)
       ThrowingExceptionInCreateSource.exception = e
       val query = spark


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to