yaooqinn commented on code in PR #42779:
URL: https://github.com/apache/spark/pull/42779#discussion_r1325554786


##########
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala:
##########
@@ -79,4 +92,196 @@ class SparkConnectSessionHolderSuite extends 
SharedSparkSession {
       sessionHolder.getDataFrameOrThrow(key1)
     }
   }
+
+  private def streamingForeachBatchFunction(pysparkPythonPath: String): 
Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          IntegratedUDFTestUtils.pythonExec,
+          "-c",
+          "from pyspark.serializers import CloudPickleSerializer; " +
+            s"f = open('$path', 'wb');" +
+            "f.write(CloudPickleSerializer().dumps((" +
+            "lambda df, batchId: batchId)))"),
+        None,
+        "PYTHONPATH" -> pysparkPythonPath).!!
+      binaryFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
+  private def streamingQueryListenerFunction(pysparkPythonPath: String): 
Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    val pythonScript =
+      """
+        |from pyspark.sql.streaming.listener import StreamingQueryListener
+        |
+        |class MyListener(StreamingQueryListener):
+        |    def onQueryStarted(e):
+        |        pass
+        |
+        |    def onQueryIdle(e):
+        |        pass
+        |
+        |    def onQueryProgress(e):
+        |        pass
+        |
+        |    def onQueryTerminated(e):
+        |        pass
+        |
+        |listener = MyListener()
+      """.stripMargin
+    withTempPath { codePath =>
+      Files.write(codePath.toPath, 
pythonScript.getBytes(StandardCharsets.UTF_8))
+      withTempPath { path =>
+        Process(
+          Seq(
+            IntegratedUDFTestUtils.pythonExec,
+            "-c",
+            "from pyspark.serializers import CloudPickleSerializer; " +
+              s"f = open('$path', 'wb');" +
+              s"exec(open('$codePath', 'r').read());" +
+              "f.write(CloudPickleSerializer().dumps(listener))"),
+          None,
+          "PYTHONPATH" -> pysparkPythonPath).!!
+        binaryFunc = Files.readAllBytes(path.toPath)
+      }
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
+  private def dummyPythonFunction(sessionHolder: SessionHolder)(
+      fcn: String => Array[Byte]): SimplePythonFunction = {
+    val sparkPythonPath =
+      
s"${IntegratedUDFTestUtils.pysparkPythonPath}:${IntegratedUDFTestUtils.pythonPath}"
+
+    SimplePythonFunction(
+      command = fcn(sparkPythonPath),
+      envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
+      pythonIncludes = 
sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+      pythonExec = IntegratedUDFTestUtils.pythonExec,
+      pythonVer = IntegratedUDFTestUtils.pythonVer,
+      broadcastVars = Lists.newArrayList(),
+      accumulator = null)
+  }
+
+  test("python foreachBatch process: process terminates after query is 
stopped") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPythonUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = 
dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction)
+      val (fn1, cleaner1) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, 
sessionHolder)
+      val (fn2, cleaner2) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, 
sessionHolder)
+
+      val query1 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q1")
+        .foreachBatch(fn1)
+        .start()
+
+      val query2 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q2")
+        .foreachBatch(fn2)
+        .start()
+
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query1, cleaner1)
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query2, cleaner2)
+
+      val (runner1, runner2) = (cleaner1.runner, cleaner2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+      // stop query1
+      query1.stop()
+      // assert query1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // stop query2
+      query2.stop()
+      eventually(timeout(30.seconds)) {
+        // assert query2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+      }
+
+      assert(spark.streams.active.isEmpty) // no running query
+      assert(spark.streams.listListeners().length == 1) // only process 
termination listener
+    } finally {
+      SparkConnectService.stop()
+      // remove process termination listener
+      spark.streams.removeListener(spark.streams.listListeners()(0))
+    }
+  }
+
+  test("python listener process: process terminates after listener is 
removed") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPythonUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = 
dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
+
+      val id1 = "listener_removeListener_test_1"
+      val id2 = "listener_removeListener_test_2"
+      val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)

Review Comment:
   this test also crashed at this line
   ```
   [info] - python listener process: process terminates after listener is 
removed *** FAILED *** (427 milliseconds)
   [info]   java.io.EOFException:
   [info]   at java.io.DataInputStream.readInt(DataInputStream.java:392)
   [info]   at 
org.apache.spark.api.python.StreamingPythonRunner.init(StreamingPythonRunner.scala:101)
   [info]   at 
org.apache.spark.sql.connect.planner.PythonStreamingQueryListener.<init>(StreamingQueryListenerHelper.scala:41)
   [info]   at 
org.apache.spark.sql.connect.service.SparkConnectSessionHolderSuite.$anonfun$new$11(SparkConnectSessionHodlerSuite.scala:251)
   [info]   at 
org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127)
   [info]   at 
org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282)
   [info]   at 
org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231)
   [info]   at 
org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to