This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5c36c580477 [SPARK-44433][PYTHON][CONNECT][SS][FOLLOWUP] Terminate listener process with `removeListener` and improvements 5c36c580477 is described below commit 5c36c58047724885864cb781f17038a6b9c94513 Author: Wei Liu <wei....@databricks.com> AuthorDate: Fri Aug 4 09:14:05 2023 +0900 [SPARK-44433][PYTHON][CONNECT][SS][FOLLOWUP] Terminate listener process with `removeListener` and improvements ### What changes were proposed in this pull request? This is a followup to #42116. It addresses the following issues: 1. When `removeListener` is called on a listener, the python process was previously left running; now it also gets stopped. 2. When `removeListener` is called multiple times on the same listener in non-connect mode, subsequent calls are no-ops. Before this PR, connect mode threw an error instead, which doesn't align with the existing behavior; this PR addresses it. 3. Set the socket timeout to None (\infty, i.e. no timeout) for `foreachBatch_worker` and `listener_worker`, because there could be a long time between microbatches. Without this, the socket times out and the worker is unable to process new data. ``` scala> Streaming query listener worker is starting with url sc://localhost:15002/;user_id=wei.liu and sessionId 886191f0-2b64-4c44-b067-de511f04b42d. 
Traceback (most recent call last): File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main return _run_code(code, main_globals, None, File "/usr/lib/python3.9/runpy.py", line 87, in _run_code exec(code, run_globals) File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 95, in <module> File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 82, in main File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/serializers.py", line 557, in loads File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/serializers.py", line 594, in read_int File "/usr/lib/python3.9/socket.py", line 704, in readinto return self._sock.recv_into(b) socket.timeout: timed out ``` ### Why are the changes needed? Necessary improvements ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual test + unit test Closes #42283 from WweiL/SPARK-44433-listener-process-termination. 
Authored-by: Wei Liu <wei....@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../sql/streaming/StreamingQueryListener.scala | 28 ----------------- .../sql/connect/planner/SparkConnectPlanner.scala | 12 +++++--- .../planner/StreamingForeachBatchHelper.scala | 10 +++--- .../planner/StreamingQueryListenerHelper.scala | 21 +++++++------ .../spark/sql/connect/service/SessionHolder.scala | 19 +++++++----- .../spark/api/python/StreamingPythonRunner.scala | 36 ++++++++++++++++------ .../streaming/worker/foreachBatch_worker.py | 4 ++- .../connect/streaming/worker/listener_worker.py | 4 ++- .../connect/streaming/test_parity_listener.py | 7 +++++ 9 files changed, 77 insertions(+), 64 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index e2f3be02ad3..404bd1b078b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -75,34 +75,6 @@ abstract class StreamingQueryListener extends Serializable { def onQueryTerminated(event: QueryTerminatedEvent): Unit } -/** - * Py4J allows a pure interface so this proxy is required. 
- */ -private[spark] trait PythonStreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit - - def onQueryProgress(event: QueryProgressEvent): Unit - - def onQueryIdle(event: QueryIdleEvent): Unit - - def onQueryTerminated(event: QueryTerminatedEvent): Unit -} - -private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener) - extends StreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event) - - def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event) - - override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event) - - def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event) -} - /** * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 3.5.0 diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f4b33ae961a..7136476b515 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3097,10 +3097,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER => val listenerId = command.getRemoveListener.getId - val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId) - session.streams.removeListener(listener) - sessionHolder.removeCachedListener(listenerId) - respBuilder.setRemoveListener(true) + sessionHolder.getListener(listenerId) match { + case Some(listener) => + session.streams.removeListener(listener) + 
sessionHolder.removeCachedListener(listenerId) + respBuilder.setRemoveListener(true) + case None => + respBuilder.setRemoveListener(false) + } case StreamingQueryManagerCommand.CommandCase.LIST_LISTENERS => respBuilder.getListListenersBuilder diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 998faf327d0..4f1037b86c9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -87,11 +87,13 @@ object StreamingForeachBatchHelper extends Logging { val port = SparkConnectService.localPort val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" - val runner = StreamingPythonRunner(pythonFn, connectUrl) + val runner = StreamingPythonRunner( + pythonFn, + connectUrl, + sessionHolder.sessionId, + "pyspark.sql.connect.streaming.worker.foreachBatch_worker") val (dataOut, dataIn) = - runner.init( - sessionHolder.sessionId, - "pyspark.sql.connect.streaming.worker.foreachBatch_worker") + runner.init() val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala index d915bc93496..9b2a931ec4a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.streaming.StreamingQueryListener /** * A helper class for handling 
StreamingQueryListener related functionality in Spark Connect. Each * instance of this class starts a python process, inside which has the python handling logic. - * When new a event is received, it is serialized to json, and passed to the python process. + * When a new event is received, it is serialized to json, and passed to the python process. */ class PythonStreamingQueryListener( listener: SimplePythonFunction, @@ -32,12 +32,15 @@ class PythonStreamingQueryListener( pythonExec: String) extends StreamingQueryListener { - val port = SparkConnectService.localPort - val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" - val runner = StreamingPythonRunner(listener, connectUrl) + private val port = SparkConnectService.localPort + private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" + private val runner = StreamingPythonRunner( + listener, + connectUrl, + sessionHolder.sessionId, + "pyspark.sql.connect.streaming.worker.listener_worker") - val (dataOut, _) = - runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker") + val (dataOut, _) = runner.init() override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { PythonRDD.writeUTF(event.json, dataOut) @@ -63,7 +66,7 @@ class PythonStreamingQueryListener( dataOut.flush() } - // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes. - // Similar to foreachBatch when we need to exit the process when the query ends. - // In listener semantics, we need to exit the process when removeListener is called. 
+ private[spark] def stopListenerProcess(): Unit = { + runner.stop() + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 310bb9208c2..29134f0dc0d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.{SystemClock} import org.apache.spark.util.Utils @@ -220,20 +221,22 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } /** - * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, throw - * [[InvalidPlanInput]]. + * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, return + * None. */ - private[connect] def getListenerOrThrow(id: String): StreamingQueryListener = { + private[connect] def getListener(id: String): Option[StreamingQueryListener] = { Option(listenerCache.get(id)) - .getOrElse { - throw InvalidPlanInput(s"No listener with id $id is found in the session $sessionId") - } } /** - * Removes corresponding StreamingQueryListener by ID. + * Removes corresponding StreamingQueryListener by ID. Terminates the python process if it's a + * Spark Connect PythonStreamingQueryListener. 
*/ - private[connect] def removeCachedListener(id: String): StreamingQueryListener = { + private[connect] def removeCachedListener(id: String): Unit = { + listenerCache.get(id) match { + case pyListener: PythonStreamingQueryListener => pyListener.stopListenerProcess() + case _ => // do nothing + } listenerCache.remove(id) } diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index d4fd9485675..f14289f984a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -29,27 +29,36 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTH private[spark] object StreamingPythonRunner { - def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = { - new StreamingPythonRunner(func, connectUrl) + def apply( + func: PythonFunction, + connectUrl: String, + sessionId: String, + workerModule: String + ): StreamingPythonRunner = { + new StreamingPythonRunner(func, connectUrl, sessionId, workerModule) } } -private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String) - extends Logging { +private[spark] class StreamingPythonRunner( + func: PythonFunction, + connectUrl: String, + sessionId: String, + workerModule: String) extends Logging { private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val envVars: java.util.Map[String, String] = func.envVars private val pythonExec: String = func.pythonExec + private var pythonWorker: Option[Socket] = None protected val pythonVer: String = func.pythonVer /** * Initializes the Python worker for streaming functions. Sets up Spark Connect session * to be used with the functions. 
*/ - def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = { - logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec") + def init(): (DataOutputStream, DataInputStream) = { + logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec") val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") @@ -60,9 +69,9 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str conf.set(PYTHON_USE_DAEMON, false) envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) - val pythonWorkerFactory = - new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap) - val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker() + val (worker, _) = env.createPythonWorker( + pythonExec, workerModule, envVars.asScala.toMap) + pythonWorker = Some(worker) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) @@ -85,4 +94,13 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str (dataOut, dataIn) } + + /** + * Stops the Python worker. + */ + def stop(): Unit = { + pythonWorker.foreach { worker => + SparkEnv.get.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) + } + } } diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py index 054788539f2..48a9848de40 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py @@ -76,7 +76,9 @@ if __name__ == "__main__": # Read information about how to connect back to the JVM from the environment. 
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) + # There could be a long time between each micro batch. + sock.settimeout(None) write_int(os.getpid(), sock_file) sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index 8eb310461b6..7aef911426d 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -89,7 +89,9 @@ if __name__ == "__main__": # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) + # There could be a long time between each listener event. + sock.settimeout(None) write_int(os.getpid(), sock_file) sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 547462d4da6..4bf58bf7807 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -60,6 +60,10 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes try: self.spark.streams.addListener(test_listener) + # This ensures the read socket on the server won't crash (i.e. 
because of timeout) + # when there hasn't been a new event for a long time + time.sleep(30) + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() q = df.writeStream.format("noop").queryName("test").start() @@ -76,6 +80,9 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes finally: self.spark.streams.removeListener(test_listener) + # Remove again to verify this won't throw any error + self.spark.streams.removeListener(test_listener) + if __name__ == "__main__": import unittest --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org