This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new c380da1357b [SPARK-44433][3.5X] Terminate foreach batch runner when streaming query terminates c380da1357b is described below commit c380da1357b20f55c8e80a515fc024e1b3b380cc Author: Raghu Angadi <raghu.ang...@databricks.com> AuthorDate: Fri Aug 18 10:39:42 2023 -0700 [SPARK-44433][3.5X] Terminate foreach batch runner when streaming query terminates [This is the 3.5x port of #42460 in master. It resolves a couple of conflicts.] This terminates the Python worker created for `foreachBatch` when the streaming query terminates. All of the tracking is done inside the connect server (inside `StreamingForeachBatchHelper`). How this works: * (A) The helper class returns a cleaner (an `AutoCloseable`) to the connect server when the foreachBatch function is set up (happens before starting the query). * (B) If the query fails to start, the server directly invokes the cleaner. * (C) If the query starts up, the server registers the cleaner with `streamingRunnerCleanerCache` in the `SessionHolder`. * (D) The cache keeps a mapping of query to cleaner. * It registers a streaming listener (only once per session), which invokes the cleaner when a query terminates. * There is also a final cleanup when the SessionHolder expires. This ensures the Python process created for a streaming query is properly terminated when a query terminates. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Unit tests are added for `CleanerCache` - Existing unit tests for foreachBatch. - Manual test to verify the Python process is terminated in different cases. - Unit tests don't really verify that the process is terminated. There will be a follow-up PR to verify this. Closes #42555 from rangadi/pr-terminate-3.5x. 
Authored-by: Raghu Angadi <raghu.ang...@databricks.com> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../sql/connect/planner/SparkConnectPlanner.scala | 37 +++++-- .../planner/StreamingForeachBatchHelper.scala | 109 ++++++++++++++++++--- .../spark/sql/connect/service/SessionHolder.scala | 9 +- .../service/SparkConnectStreamingQueryCache.scala | 21 ++-- .../planner/StreamingForeachBatchHelperSuite.scala | 80 +++++++++++++++ .../spark/api/python/StreamingPythonRunner.scala | 9 +- 6 files changed, 230 insertions(+), 35 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f3e87b7067d..5120073e2f0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.planner import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try +import scala.util.control.NonFatal import com.google.common.base.Throwables import com.google.common.collect.{Lists, Maps} @@ -2853,11 +2854,17 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } } + // This is filled when a foreach batch runner started for Python. 
+ var foreachBatchRunnerCleaner: Option[StreamingForeachBatchHelper.RunnerCleaner] = None + if (writeOp.hasForeachBatch) { val foreachBatchFn = writeOp.getForeachBatch.getFunctionCase match { case StreamingForeachFunction.FunctionCase.PYTHON_FUNCTION => val pythonFn = transformPythonFunction(writeOp.getForeachBatch.getPythonFunction) - StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder) + val (fn, cleaner) = + StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder) + foreachBatchRunnerCleaner = Some(cleaner) + fn case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION => val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType]( @@ -2872,16 +2879,26 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { writer.foreachBatch(foreachBatchFn) } - val query = writeOp.getPath match { - case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName) - case "" => writer.start() - case path => writer.start(path) - } + val query = + try { + writeOp.getPath match { + case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName) + case "" => writer.start() + case path => writer.start(path) + } + } catch { + case NonFatal(ex) => // Failed to start the query, clean up foreach runner if any. + logInfo(s"Removing foreachBatch worker, query failed to start for session $sessionId.") + foreachBatchRunnerCleaner.foreach(_.close()) + throw ex + } - // Register the new query so that the session and query references are cached. - SparkConnectService.streamingSessionManager.registerNewStreamingQuery( - sessionHolder = SessionHolder(userId = userId, sessionId = sessionId, session), - query = query) + // Register the new query so that its reference is cached and is stopped on session timeout. + SparkConnectService.streamingSessionManager.registerNewStreamingQuery(sessionHolder, query) + // Register the runner with the query if Python foreachBatch is enabled. 
+ foreachBatchRunnerCleaner.foreach { cleaner => + sessionHolder.streamingRunnerCleanerCache.registerCleanerForQuery(query, cleaner) + } executeHolder.eventsManager.postFinished() val result = WriteStreamOperationStartResult diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 4f1037b86c9..21e4adb9896 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.connect.planner import java.util.UUID +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ConcurrentMap + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.StreamingQueryListener /** * A helper class for handling ForeachBatch related functionality in Spark Connect servers @@ -31,6 +38,16 @@ object StreamingForeachBatchHelper extends Logging { type ForeachBatchFnType = (DataFrame, Long) => Unit + case class RunnerCleaner(runner: StreamingPythonRunner) extends AutoCloseable { + override def close(): Unit = { + try runner.stop() + catch { + case NonFatal(ex) => // Exception is not propagated. 
+ logWarning("Error while stopping streaming Python worker", ex) + } + } + } + private case class FnArgsWithId(dfId: String, df: DataFrame, batchId: Long) /** @@ -83,7 +100,7 @@ object StreamingForeachBatchHelper extends Logging { */ def pythonForeachBatchWrapper( pythonFn: SimplePythonFunction, - sessionHolder: SessionHolder): ForeachBatchFnType = { + sessionHolder: SessionHolder): (ForeachBatchFnType, RunnerCleaner) = { val port = SparkConnectService.localPort val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" @@ -92,8 +109,7 @@ object StreamingForeachBatchHelper extends Logging { connectUrl, sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.foreachBatch_worker") - val (dataOut, dataIn) = - runner.init() + val (dataOut, dataIn) = runner.init() val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { @@ -102,6 +118,9 @@ object StreamingForeachBatchHelper extends Logging { // This is because MicroBatch execution clones the session during start. // The session attached to the foreachBatch dataframe is different from the one the one // the query was started with. `sessionHolder` here contains the latter. + // Another issue with not creating new session id: foreachBatch worker keeps + // the session alive. The session mapping at Connect server does not expire and query + // keeps running even if the original client disappears. This keeps the query running. PythonRDD.writeUTF(args.dfId, dataOut) dataOut.writeLong(args.batchId) @@ -111,17 +130,79 @@ object StreamingForeachBatchHelper extends Logging { logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: $ret)") } - dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder) + (dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder), RunnerCleaner(runner)) } - // TODO(SPARK-44433): Improve termination of Processes - // The goal is that when a query is terminated, the python process associated with foreachBatch - // should be terminated. 
One way to do that is by registering streaming query listener: - // After pythonForeachBatchWrapper() is invoked by the SparkConnectPlanner. - // At that time, we don't have the streaming queries yet. - // Planner should call back into this helper with the query id when it starts it immediately - // after. Save the query id to StreamingPythonRunner mapping. This mapping should be - // part of the SessionHolder. - // When a query is terminated, check the mapping and terminate any associated runner. - // These runners should be terminated when a session is deleted (due to timeout, etc). + /** + * This manages cache from queries to cleaner for runners used for streaming queries. This is + * used in [[SessionHolder]]. + */ + class CleanerCache(sessionHolder: SessionHolder) { + + private case class CacheKey(queryId: String, runId: String) + + // Mapping from streaming (queryId, runId) to runner cleaner. Used for Python foreachBatch. + private val cleanerCache: ConcurrentMap[CacheKey, AutoCloseable] = new ConcurrentHashMap() + + private lazy val streamingListener = { // Initialized on first registered query + val listener = new StreamingRunnerCleanerListener + sessionHolder.session.streams.addListener(listener) + logInfo(s"Registered runner clean up listener for session ${sessionHolder.sessionId}") + listener + } + + private[connect] def registerCleanerForQuery( + query: StreamingQuery, + cleaner: AutoCloseable): Unit = { + + streamingListener // Access to initialize + val key = CacheKey(query.id.toString, query.runId.toString) + + Option(cleanerCache.putIfAbsent(key, cleaner)) match { + case Some(_) => + throw new IllegalStateException(s"Unexpected: a cleaner for query $key is already set") + case None => // Inserted. Normal. + } + } + + /** Cleans up all the registered runners. */ + private[connect] def cleanUpAll(): Unit = { + // Clean up all remaining registered runners. 
+ cleanerCache.keySet().asScala.foreach(cleanupStreamingRunner(_)) + } + + private def cleanupStreamingRunner(key: CacheKey): Unit = { + Option(cleanerCache.remove(key)).foreach { cleaner => + logInfo(s"Cleaning up runner for queryId ${key.queryId} runId ${key.runId}.") + cleaner.close() + } + } + + /** + * An internal streaming query listener that cleans up Python runner (if there is any) when a + * query is terminated. + */ + private class StreamingRunnerCleanerListener extends StreamingQueryListener { + override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {} + + override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {} + + override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { + val key = CacheKey(event.id.toString, event.runId.toString) + cleanupStreamingRunner(key) + } + } + + private[connect] def listEntriesForTesting(): Map[(String, String), AutoCloseable] = { + cleanerCache + .entrySet() + .asScala + .map { e => + (e.getKey.queryId, e.getKey.runId) -> e.getValue + } + .toMap + } + + private[connect] def listenerForTesting: StreamingQueryListener = streamingListener + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index b828d78710f..2034a97fce9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener +import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper import 
org.apache.spark.sql.streaming.StreamingQueryListener -import org.apache.spark.util.{SystemClock} +import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils /** @@ -55,6 +56,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private lazy val listenerCache: ConcurrentMap[String, StreamingQueryListener] = new ConcurrentHashMap() + // Handles Python process clean up for streaming queries. Initialized on first use in a query. + private[connect] lazy val streamingRunnerCleanerCache = + new StreamingForeachBatchHelper.CleanerCache(this) + /** Add ExecuteHolder to this session. Called only by SparkConnectExecutionManager. */ private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = { val oldExecute = executions.putIfAbsent(executeHolder.operationId, executeHolder) @@ -153,9 +158,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId") artifactManager.cleanUpResources() eventManager.postClosed() - // Clean up running queries SparkConnectService.streamingSessionManager.cleanupRunningQueries(this) + streamingRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. 
} /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 1b834648c51..04c38a1991a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -113,7 +113,12 @@ private[connect] class SparkConnectStreamingQueryCache( if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) { if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) { logInfo(s"Stopping the query with id ${k.queryId} since the session has timed out") - v.query.stop() + try { + v.query.stop() + } catch { + case NonFatal(ex) => + logWarning(s"Failed to stop the query ${k.queryId}. Error is ignored.", ex) + } } } } @@ -180,13 +185,17 @@ private[connect] class SparkConnectStreamingQueryCache( case Some(_) => // Inactive query waiting for expiration. Do nothing. logInfo(s"Waiting for the expiration for $id in session ${v.sessionId}") - case None => // Active query, check if it is stopped. Keep the session alive. + case None => // Active query, check if it is stopped. Enable timeout if it is stopped. val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty if (!isActive) { logInfo(s"Marking query $id in session ${v.sessionId} inactive.") val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) + // To consider: Clean up any runner registered for this query with the session holder + // for this session. Useful in case listener events are delayed (such delays are + // seen in practice, especially when users have heavy processing inside listeners). 
+ // Currently such workers would be cleaned up when the connect session expires. } } } @@ -196,9 +205,6 @@ private[connect] class SparkConnectStreamingQueryCache( private[connect] object SparkConnectStreamingQueryCache { - case class SessionCacheKey(userId: String, sessionId: String) - case class SessionCacheValue(session: SparkSession) - case class QueryCacheKey(queryId: String, runId: String) case class QueryCacheValue( @@ -207,5 +213,8 @@ private[connect] object SparkConnectStreamingQueryCache { session: SparkSession, // Holds the reference to the session. query: StreamingQuery, // Holds the reference to the query. expiresAtMs: Option[Long] = None // Expiry time for a stopped query. - ) + ) { + override def toString(): String = + s"[session id: $sessionId, query id: ${query.id}, run id: ${query.runId}]" + } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelperSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelperSuite.scala new file mode 100644 index 00000000000..820a1b04795 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelperSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.planner + +import java.util.UUID + +import org.mockito.Mockito.times +import org.mockito.Mockito.verify +import org.mockito.Mockito.when +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.sql.test.SharedSparkSession + +class StreamingForeachBatchHelperSuite extends SharedSparkSession with MockitoSugar { + + private def mockQuery(): StreamingQuery = { + val query = mock[StreamingQuery] + val (queryId, runId) = (UUID.randomUUID(), UUID.randomUUID()) + when(query.id).thenReturn(queryId) + when(query.runId).thenReturn(runId) + query + } + + test("CleanerCache functionality: register queries, terminate, full cleanup") { + + val cleaner1 = mock[AutoCloseable] + val cleaner2 = mock[AutoCloseable] + + val query1 = mockQuery() + val query2 = mockQuery() + + val cache = new StreamingForeachBatchHelper.CleanerCache(SessionHolder.forTesting(spark)) + + cache.registerCleanerForQuery(query1, cleaner1) + + // Verify listener is registered. + assert(spark.streams.listListeners().contains(cache.listenerForTesting)) + + cache.registerCleanerForQuery(query2, cleaner2) + + assert(cache.listEntriesForTesting().size == 2) + + // No calls to close yet. + verify(cleaner1, times(0)).close() + + // Terminate query1 + val terminatedEvent = + new StreamingQueryListener.QueryTerminatedEvent(id = query1.id, runId = query1.runId, None) + cache.listenerForTesting.onQueryTerminated(terminatedEvent) + + // This should close 'cleaner1' and remove it from the cache. + verify(cleaner1, times(1)).close() + assert(cache.listEntriesForTesting().size == 1) + + // Clean up remaining entries + verify(cleaner2, times(0)).close() // cleaner2 is not closed yet. 
+ cache.cleanUpAll() // It should be closed now. + verify(cleaner2, times(1)).close() + + // No more entries left in it now. + assert(cache.listEntriesForTesting().isEmpty) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index cddda6fb7a7..1e0feb74cfd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -59,7 +59,7 @@ private[spark] class StreamingPythonRunner( * to be used with the functions. */ def init(): (DataOutputStream, DataInputStream) = { - logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec") + logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec)") val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") @@ -98,7 +98,7 @@ private[spark] class StreamingPythonRunner( new BufferedInputStream(pythonWorker.get.getInputStream, bufferSize)) val resFromPython = dataIn.readInt() - logInfo(s"Runner initialization returned $resFromPython") + logInfo(s"Runner initialization succeeded (returned $resFromPython).") (dataOut, dataIn) } @@ -108,7 +108,10 @@ private[spark] class StreamingPythonRunner( */ def stop(): Unit = { pythonWorker.foreach { worker => - pythonWorkerFactory.foreach(_.stopWorker(worker)) + pythonWorkerFactory.foreach { factory => + factory.stopWorker(worker) + factory.stop() + } } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org