This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 02dc372314f [SPARK-44626][SS][CONNECT] Followup on streaming query 
termination when client session is timed out for Spark Connect
02dc372314f is described below

commit 02dc372314f4018be3c432f0334231c36da8a948
Author: bogao007 <bo....@databricks.com>
AuthorDate: Thu Aug 3 06:53:52 2023 +0900

    [SPARK-44626][SS][CONNECT] Followup on streaming query termination when 
client session is timed out for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Removed the keep-alive feature from `SparkConnectStreamingQueryCache` so that the
    session mapping of running queries can expire after a client session times out.
    This is needed for a change that terminates running queries when the client
    session times out.
    
    ### Why are the changes needed?
    
    This is a follow-up of https://issues.apache.org/jira/browse/SPARK-44432.
    Without removing the keep-alive feature, the session mapping for running
    queries cannot expire correctly, because the periodic keep-alive task keeps
    accessing it.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Tested manually and verified that the change works as expected. Adding a
    unit test is difficult because the session mapping logic cannot easily be
    mocked. We want to get this into Spark 3.5 as soon as possible, and the
    change has already been verified manually.
    
    Create a streaming query on the client side and then disconnect:
    ```
     val q = 
spark.readStream.format("rate").load().writeStream.format("console").start()
    
     exit
    ```
    
    Verify on the server side that the running query is stopped after a while:
    ```
    23/08/01 13:26:34 INFO SparkConnectStreamingQueryCache: Stopping the query 
with id 787cdaed-0e6a-4496-a4a3-78b77789193b since the session has timed out
    23/08/01 13:26:34 INFO DAGScheduler: Asked to cancel job group 
c615e5a7-78ff-462b-b24c-bba55259f8f5
    23/08/01 13:26:34 INFO MicroBatchExecution: Async log purge executor pool 
for query [id = 787cdaed-0e6a-4496-a4a3-78b77789193b, runId = 
c615e5a7-78ff-462b-b24c-bba55259f8f5] has been shutdown
    23/08/01 13:26:34 INFO MicroBatchExecution: Deleting checkpoint 
file:/private/var/folders/b0/f9jmmrrx5js7xsswxyf58nwr0000gp/T/temporary-05b9e890-ec53-45df-a3ca-65c4c01a143d.
    23/08/01 13:26:34 INFO DAGScheduler: Asked to cancel job group 
c615e5a7-78ff-462b-b24c-bba55259f8f5
    23/08/01 13:26:34 INFO MicroBatchExecution: Query [id = 
787cdaed-0e6a-4496-a4a3-78b77789193b, runId = 
c615e5a7-78ff-462b-b24c-bba55259f8f5] was stopped
    ```
    
    Closes #42280 from bogao007/terminate-followup.
    
    Authored-by: bogao007 <bo....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 79938ee2a717d5384eface254f376169e9e5bf64)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/connect/service/SparkConnectService.scala  |  5 +--
 .../service/SparkConnectStreamingQueryCache.scala  | 45 ++++++++--------------
 .../SparkConnectStreamingQueryCacheSuite.scala     | 21 ++--------
 3 files changed, 21 insertions(+), 50 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 5f60f4b1f37..6e607037e6c 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -278,10 +278,7 @@ object SparkConnectService extends Logging {
     cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, 
SessionHolder]()
 
   private[connect] val streamingSessionManager =
-    new SparkConnectStreamingQueryCache(sessionKeepAliveFn = { case (userId, 
sessionId) =>
-      // Use getIfPresent() rather than get() to prevent accidental loading.
-      userSessionMapping.getIfPresent((userId, sessionId))
-    })
+    new SparkConnectStreamingQueryCache()
 
   private class RemoveSessionListener extends RemovalListener[SessionCacheKey, 
SessionHolder] {
     override def onRemoval(
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
index 21a65d4ed99..1b834648c51 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -38,13 +38,10 @@ import org.apache.spark.util.SystemClock
  * no longer active), it is cached for 1 hour so that it is accessible from 
the client side. It
  * runs a background thread to run a periodic task that does the following:
  *   - Check the status of the queries, and drops those that expired (1 hour 
after being stopped).
- *   - Keep the associated session active by invoking supplied function 
`sessionKeepAliveFn`.
  *
  * This class helps with supporting following semantics for streaming query 
sessions:
- *   - Keep the session and session mapping at connect server alive as long as 
a streaming query
- *     is active. Even if the client side has disconnected.
- *     - This matches how streaming queries behave in Spark. The queries 
continue to run if
- *       notebook or job session is lost.
+ *   - If the session mapping on connect server side is expired, stop all the 
running queries that
+ *     are associated with that session.
  *   - Once a query is stopped, the reference and mappings are maintained for 
1 hour and will be
  *     accessible from the client. This allows time for client to fetch 
status. If the client
  *     continues to access the query, it stays in the cache until 1 hour of 
inactivity.
@@ -52,7 +49,6 @@ import org.apache.spark.util.SystemClock
  * Note that these semantics are evolving and might change before being 
finalized in Connect.
  */
 private[connect] class SparkConnectStreamingQueryCache(
-    val sessionKeepAliveFn: (String, String) => Unit, // (userId, sessionId) 
=> Unit.
     val clock: Clock = new SystemClock(),
     private val stoppedQueryInactivityTimeout: Duration = 1.hour, // 
Configurable for testing.
     private val sessionPollingPeriod: Duration = 1.minute // Configurable for 
testing.
@@ -71,11 +67,11 @@ private[connect] class SparkConnectStreamingQueryCache(
 
       queryCache.put(QueryCacheKey(query.id.toString, query.runId.toString), 
value) match {
         case Some(existing) => // Query is being replace. Not really expected.
-          log.warn(
+          logWarning(
             s"Replacing existing query in the cache (unexpected). Query Id: 
${query.id}." +
               s"Existing value $existing, new value $value.")
         case None =>
-          log.info(s"Adding new query to the cache. Query Id ${query.id}, 
value $value.")
+          logInfo(s"Adding new query to the cache. Query Id ${query.id}, value 
$value.")
       }
 
       schedulePeriodicChecks() // Starts the scheduler thread if it hasn't 
started.
@@ -108,13 +104,17 @@ private[connect] class SparkConnectStreamingQueryCache(
   }
 
   /**
-   * Terminate all the running queries attached to the given sessionHolder. 
This is used when
-   * session is expired and we need to cleanup resources of that session.
+   * Terminate all the running queries attached to the given sessionHolder and 
remove them from
+   * the queryCache. This is used when session is expired and we need to 
cleanup resources of that
+   * session.
    */
   def cleanupRunningQueries(sessionHolder: SessionHolder): Unit = {
     for ((k, v) <- queryCache) {
       if (v.userId.equals(sessionHolder.userId) && 
v.sessionId.equals(sessionHolder.sessionId)) {
-        v.query.stop()
+        if (v.query.isActive && 
Option(v.session.streams.get(k.queryId)).nonEmpty) {
+          logInfo(s"Stopping the query with id ${k.queryId} since the session 
has timed out")
+          v.query.stop()
+        }
       }
     }
   }
@@ -144,13 +144,13 @@ private[connect] class SparkConnectStreamingQueryCache(
     scheduledExecutor match {
       case Some(_) => // Already running.
       case None =>
-        log.info(s"Starting thread for polling streaming sessions every 
$sessionPollingPeriod")
+        logInfo(s"Starting thread for polling streaming sessions every 
$sessionPollingPeriod")
         scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
         scheduledExecutor.get.scheduleAtFixedRate(
           () => {
             try periodicMaintenance()
             catch {
-              case NonFatal(ex) => log.warn("Unexpected exception in periodic 
task", ex)
+              case NonFatal(ex) => logWarning("Unexpected exception in 
periodic task", ex)
             }
           },
           sessionPollingPeriod.toMillis,
@@ -163,14 +163,9 @@ private[connect] class SparkConnectStreamingQueryCache(
    * Periodic maintenance task to do the following:
    *   - Update status of query if it is inactive. Sets an expiry time for 
such queries
    *   - Drop expired queries from the cache.
-   *   - Poll sessions associated with the cached queries in order keep them 
alive in connect
-   *     service' mapping (by invoking `sessionKeepAliveFn`).
    */
   private def periodicMaintenance(): Unit = {
 
-    // Gather sessions to keep alive and invoke supplied function outside the 
lock.
-    val sessionsToKeepAlive = mutable.HashSet[(String, String)]()
-
     queryCacheLock.synchronized {
       val nowMs = clock.getTimeMillis()
 
@@ -179,29 +174,23 @@ private[connect] class SparkConnectStreamingQueryCache(
         v.expiresAtMs match {
 
           case Some(ts) if nowMs >= ts => // Expired. Drop references.
-            log.info(s"Removing references for $id in session ${v.sessionId} 
after expiry period")
+            logInfo(s"Removing references for $id in session ${v.sessionId} 
after expiry period")
             queryCache.remove(k)
 
-          case Some(_) => // Inactive query waiting for expiration. Keep the 
session alive.
-            sessionsToKeepAlive.add((v.userId, v.sessionId))
+          case Some(_) => // Inactive query waiting for expiration. Do nothing.
+            logInfo(s"Waiting for the expiration for $id in session 
${v.sessionId}")
 
           case None => // Active query, check if it is stopped. Keep the 
session alive.
-            sessionsToKeepAlive.add((v.userId, v.sessionId))
-
             val isActive = v.query.isActive && 
Option(v.session.streams.get(id)).nonEmpty
 
             if (!isActive) {
-              log.info(s"Marking query $id in session ${v.sessionId} 
inactive.")
+              logInfo(s"Marking query $id in session ${v.sessionId} inactive.")
               val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis
               queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs)))
             }
         }
       }
     }
-
-    for ((userId, sessionId) <- sessionsToKeepAlive) {
-      sessionKeepAliveFn(userId, sessionId)
-    }
   }
 }
 
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
index 4251d2badb9..ed3da2c0f71 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
@@ -18,11 +18,10 @@
 package org.apache.spark.sql.connect.service
 
 import java.util.UUID
-import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.duration.DurationInt
 
-import org.mockito.Mockito.{verify, when}
+import org.mockito.Mockito.when
 import org.scalatest.concurrent.Eventually.eventually
 import org.scalatest.concurrent.Futures.timeout
 import org.scalatestplus.mockito.MockitoSugar
@@ -36,9 +35,8 @@ import org.apache.spark.util.ManualClock
 class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with 
MockitoSugar {
 
   // Creates a manager with short durations for periodic check and expiry.
-  private def createSessionManager(keepAliveFn: (String, String) => Unit) = {
+  private def createSessionManager() = {
     new SparkConnectStreamingQueryCache(
-      keepAliveFn,
       clock = new ManualClock(),
       stoppedQueryInactivityTimeout = 1.minute, // This is on manual clock.
       sessionPollingPeriod = 20.milliseconds // This is real clock. Used for 
periodic task.
@@ -48,8 +46,6 @@ class SparkConnectStreamingQueryCacheSuite extends 
SparkFunSuite with MockitoSug
   test("Session cache functionality with a streaming query") {
     // Verifies common happy path for the query cache. Runs a query through 
its life cycle.
 
-    val numKeepAliveCalls = new AtomicInteger(0)
-
     val queryId = UUID.randomUUID().toString
     val runId = UUID.randomUUID().toString
     val mockSession = mock[SparkSession]
@@ -59,11 +55,7 @@ class SparkConnectStreamingQueryCacheSuite extends 
SparkFunSuite with MockitoSug
     val sessionHolder =
       SessionHolder(userId = "test_user_1", sessionId = "test_session_1", 
session = mockSession)
 
-    val sessionMgr = createSessionManager(keepAliveFn = { case (userId, 
sessionId) =>
-      assert(userId == sessionHolder.userId)
-      assert(sessionId == sessionHolder.sessionId)
-      numKeepAliveCalls.incrementAndGet()
-    })
+    val sessionMgr = createSessionManager()
 
     val clock = sessionMgr.clock.asInstanceOf[ManualClock]
 
@@ -77,11 +69,6 @@ class SparkConnectStreamingQueryCacheSuite extends 
SparkFunSuite with MockitoSug
 
     sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery)
 
-    eventually(timeout(1.minute)) {
-      // Verify keep alive function is called a few times.
-      assert(numKeepAliveCalls.get() >= 5)
-    }
-
     sessionMgr.getCachedValue(queryId, runId) match {
       case Some(v) =>
         assert(v.sessionId == sessionHolder.sessionId)
@@ -96,8 +83,6 @@ class SparkConnectStreamingQueryCacheSuite extends 
SparkFunSuite with MockitoSug
     assert(sessionMgr.getCachedQuery(queryId, runId, 
mockSession).contains(mockQuery))
 
     // Cleanup the query and verify if stop() method has been called.
-    sessionMgr.cleanupRunningQueries(sessionHolder)
-    verify(mockQuery).stop()
     when(mockQuery.isActive).thenReturn(false)
 
     val expectedExpiryTimeMs = sessionMgr.clock.getTimeMillis() + 
1.minute.toMillis


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to