This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2bc9573e94f2 [SPARK-47819][CONNECT] Use asynchronous callback for execution cleanup 2bc9573e94f2 is described below commit 2bc9573e94f29cd5394429b623e30c4386a473ba Author: Xi Lyu <xi....@databricks.com> AuthorDate: Fri Apr 12 08:48:40 2024 -0400 [SPARK-47819][CONNECT] Use asynchronous callback for execution cleanup ### What changes were proposed in this pull request? Expired sessions are regularly checked and cleaned up by a maintenance thread. However, currently, this process is synchronous. Therefore, in rare cases, interrupting the execution thread of a query in a session can take hours, causing the entire maintenance process to stall, resulting in a large amount of memory not being cleared. We address this by introducing asynchronous callbacks for execution cleanup, avoiding synchronous joins of execution threads, and preventing the maintenance thread from stalling in the above scenarios. To be more specific, instead of calling `runner.join()` in `ExecuteHolder.close()`, we set a post-cleanup function as the callback through `runner.processOnCompletion`, which will be called asynchronously once the execution runner is completed or interrupted. In this way, the maintenan [...] ### Why are the changes needed? In the rare cases mentioned above, performance can be severely affected. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests and a new test `Async cleanup callback gets called after the execution is closed` in `SparkConnectServiceE2ESuite.scala`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46027 from xi-db/SPARK-47819-async-cleanup. 
Authored-by: Xi Lyu <xi....@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../connect/execution/ExecuteThreadRunner.scala | 33 ++++++++++++++++------ .../spark/sql/connect/service/ExecuteHolder.scala | 16 ++++++++--- .../connect/planner/SparkConnectServiceSuite.scala | 7 ++++- .../service/SparkConnectServiceE2ESuite.scala | 23 +++++++++++++++ 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 56776819dac9..37c3120a8ff4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.connect.execution +import scala.concurrent.{ExecutionContext, Promise} import scala.jdk.CollectionConverters._ +import scala.util.Try import scala.util.control.NonFatal import com.google.protobuf.Message @@ -30,7 +32,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService} import org.apache.spark.sql.connect.utils.ErrorUtils -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * This class launches the actual execution in an execution thread. The execution pushes the @@ -38,10 +40,12 @@ import org.apache.spark.util.Utils */ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { + private val promise: Promise[Unit] = Promise[Unit]() + // The newly created thread will inherit all InheritableThreadLocals used by Spark, // e.g. SparkContext.localProperties. 
If considering implementing a thread-pool, // forwarding of thread locals needs to be taken into account. - private val executionThread: Thread = new ExecutionThread() + private val executionThread: ExecutionThread = new ExecutionThread(promise) private var started: Boolean = false @@ -63,11 +67,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends } } - /** Joins the background execution thread after it is finished. */ - private[connect] def join(): Unit = { - // only called when the execution is completed or interrupted. - assert(completed || interrupted) - executionThread.join() + /** + * Register a callback that gets executed after completion/interruption of the execution thread. + */ + private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = { + promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext) } /** @@ -276,10 +280,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends .build() } - private class ExecutionThread + private class ExecutionThread(onCompletionPromise: Promise[Unit]) extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") { override def run(): Unit = { - execute() + try { + execute() + onCompletionPromise.success(()) + } catch { + case NonFatal(e) => + onCompletionPromise.failure(e) + } } } } + +private[connect] object ExecuteThreadRunner { + private implicit val namedExecutionContext: ExecutionContext = ExecutionContext + .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback")) +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index f03f81326064..3112d12bb0e6 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -117,6 +117,9 @@ private[connect] class ExecuteHolder( : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] = new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]() + /** For testing. Whether the async completion callback is called. */ + @volatile private[connect] var completionCallbackCalled: Boolean = false + /** * Start the execution. The execution is started in a background thread in ExecuteThreadRunner. * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the @@ -238,8 +241,15 @@ private[connect] class ExecuteHolder( if (closedTimeMs.isEmpty) { // interrupt execution, if still running. runner.interrupt() - // wait for execution to finish, to make sure no more results get pushed to responseObserver - runner.join() + // Do not wait for the execution to finish, clean up resources immediately. + runner.processOnCompletion { _ => + completionCallbackCalled = true + // The execution may not immediately get interrupted, clean up any remaining resources when + // it does. 
+ responseObserver.removeAll() + // post closed to UI + eventsManager.postClosed() + } // interrupt any attached grpcResponseSenders grpcResponseSenders.foreach(_.interrupt()) // if there were still any grpcResponseSenders, register detach time @@ -249,8 +259,6 @@ private[connect] class ExecuteHolder( } // remove all cached responses from observer responseObserver.removeAll() - // post closed to UI - eventsManager.postClosed() closedTimeMs = Some(System.currentTimeMillis()) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 63cebd452364..af18fca9dd21 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector} import org.apache.arrow.vector.ipc.ArrowStreamReader import org.mockito.Mockito.when import org.scalatest.Tag +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkContext, SparkEnv} @@ -884,8 +886,11 @@ class SparkConnectServiceSuite assert(executeHolder.eventsManager.hasError.isDefined) } def onCompleted(producedRowCount: Option[Long] = None): Unit = { - assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) + // The eventsManager is closed asynchronously + Eventually.eventually(timeout(1.seconds)) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + } } def onCanceled(): Unit = { assert(executeHolder.eventsManager.hasCanceled.contains(true)) diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala index 33560cd53f6b..cb0bd8f771eb 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala @@ -91,6 +91,29 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest { } } + test("Async cleanup callback gets called after the execution is closed") { + withClient(UUID.randomUUID().toString, defaultUserId) { client => + val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY)) + // just creating the iterator is lazy, trigger query1 and query2 to be sent. + query1.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 1) + } + val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head + // Close session + client.releaseSession() + // Check that queries get cancelled + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 0) + // SparkConnectService.sessionManager. + } + // Check the async execute cleanup get called + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(executeHolder1.completionCallbackCalled) + } + } + } + private def testReleaseSessionTwoSessions( sessionIdA: String, userIdA: String, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org