This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 074c866603c [SPARK-42940][SS][CONNECT] Improve session management for 
streaming queries
074c866603c is described below

commit 074c866603c8744eebe80c18d508332066c3fa80
Author: Raghu Angadi <[email protected]>
AuthorDate: Fri Apr 28 19:56:32 2023 -0700

    [SPARK-42940][SS][CONNECT] Improve session management for streaming queries
    
    ### What changes were proposed in this pull request?
    This fixes a couple of important issues related to session management for streaming queries.
    
    1. Session mapping should be maintained at connect server as long as the 
streaming query is active, even if there are no accesses from the client side. 
Currently the session mapping is dropped after 1 hour of inactivity.
    2. When a streaming query is stopped, the Spark session drops its reference to the streaming query object. That implies it cannot be accessed by a remote spark-connect client. It is a common usage pattern for users to access a streaming query after it is stopped (e.g. to check its metrics, any exception if it failed, etc.).
       - This is not a problem in legacy mode since the user code in the REPL 
keeps the reference. This is no longer the case in Spark-Connect.
    
    *Solution*: This PR adds `SparkConnectStreamingQueryCache` that does the 
following:
      * Each new streaming query is registered with this cache.
      * It runs a periodic task that checks the status of these queries and 
polls session mapping in connect-server so that the session stays alive.
      * When a query is stopped, it is cached for 1 more hour so that it can be accessed from the remote client.
      * The full semantics are codified in the scaladoc. See [this 
comment](https://github.com/apache/spark/pull/40937/files#r1176846545) for more 
details.
    
    ### Why are the changes needed?
      - Explained in the above description.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    - Unit tests
    - Manual testing
    
    Closes #40937 from rangadi/session-mgmt.
    
    Authored-by: Raghu Angadi <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/streaming/StreamingQuerySuite.scala  |   4 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  20 +-
 .../sql/connect/service/SparkConnectService.scala  |   6 +
 .../service/SparkConnectStreamHandler.scala        |   6 +-
 .../service/SparkConnectStreamingQueryCache.scala  | 210 +++++++++++++++++++++
 .../connect/planner/SparkConnectPlannerSuite.scala |   2 +-
 .../plugin/SparkConnectPluginRegistrySuite.scala   |   2 +-
 .../SparkConnectStreamingQueryCacheSuite.scala     | 155 +++++++++++++++
 8 files changed, 399 insertions(+), 6 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 9061dcadd63..f0c12d212c2 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -75,6 +75,10 @@ class StreamingQuerySuite extends RemoteSparkSession with 
SQLHelper {
       } finally {
         // Don't wait for any processed data. Otherwise the test could take 
multiple seconds.
         query.stop()
+
+        // The query should still be accessible after stopped.
+        assert(!query.isActive)
+        assert(query.recentProgress.nonEmpty)
       }
     }
   }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0234ef5b1cd..3b4dd13cc5d 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -47,7 +47,9 @@ import 
org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, 
UdfPacket}
 import 
org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
-import org.apache.spark.sql.connect.service.{SparkConnectService, 
SparkConnectStreamHandler}
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.connect.service.SparkConnectService
+import org.apache.spark.sql.connect.service.SparkConnectStreamHandler
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -1884,6 +1886,7 @@ class SparkConnectPlanner(val session: SparkSession) {
 
   def process(
       command: proto.Command,
+      userId: String,
       sessionId: String,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     command.getCommandTypeCase match {
@@ -1902,6 +1905,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(
           command.getWriteStreamOperationStart,
+          userId,
           sessionId,
           responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2198,6 +2202,7 @@ class SparkConnectPlanner(val session: SparkSession) {
 
   def handleWriteStreamOperationStart(
       writeOp: WriteStreamOperationStart,
+      userId: String,
       sessionId: String,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val plan = transformRelation(writeOp.getInput)
@@ -2241,6 +2246,11 @@ class SparkConnectPlanner(val session: SparkSession) {
       case path => writer.start(path)
     }
 
+    // Register the new query so that the session and query references are 
cached.
+    SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
+      sessionHolder = SessionHolder(userId = userId, sessionId = sessionId, 
session),
+      query = query)
+
     val result = WriteStreamOperationStartResult
       .newBuilder()
       .setQueryId(
@@ -2272,7 +2282,12 @@ class SparkConnectPlanner(val session: SparkSession) {
       .newBuilder()
       .setQueryId(command.getQueryId)
 
-    val query = Option(session.streams.get(id)) match {
+    // Find the query in connect service level cache, otherwise check 
session's active streams.
+    val query = SparkConnectService.streamingSessionManager
+      .getCachedQuery(id, runId, session) // Common case: query is cached in 
the cache.
+      .orElse { // Else try to find it in active streams. Mostly will not be 
found here either.
+        Option(session.streams.get(id))
+      } match {
       case Some(query) if query.runId.toString == runId =>
         query
       case Some(query) =>
@@ -2281,7 +2296,6 @@ class SparkConnectPlanner(val session: SparkSession) {
             s"does not match one on the server ${query.runId}. The query might 
have restarted.")
       case None =>
         throw new IllegalArgumentException(s"Streaming query $id is not found")
-      // TODO(SPARK-42962): Handle this better. May be cache stopped queries 
for a few minutes.
     }
 
     command.getCommandCase match {
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 09a3ff39698..df3c1fd7b05 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -268,6 +268,12 @@ object SparkConnectService {
   private val userSessionMapping =
     cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, 
SessionHolder]()
 
+  private[connect] val streamingSessionManager =
+    new SparkConnectStreamingQueryCache(sessionKeepAliveFn = { case (userId, 
sessionId) =>
+      // Use getIfPresent() rather than get() to prevent accidental loading.
+      userSessionMapping.getIfPresent((userId, sessionId))
+    })
+
   // Simple builder for creating the cache of Sessions.
   private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): 
CacheBuilder[Object, Object] = {
     var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker())
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index f08dfba5e28..c544f484381 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -98,7 +98,11 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp
   private def handleCommand(session: SparkSession, request: 
ExecutePlanRequest): Unit = {
     val command = request.getPlan.getCommand
     val planner = new SparkConnectPlanner(session)
-    planner.process(command, request.getSessionId, responseObserver)
+    planner.process(
+      command = command,
+      userId = request.getUserContext.getUserId,
+      sessionId = request.getSessionId,
+      responseObserver = responseObserver)
     responseObserver.onCompleted()
   }
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
new file mode 100644
index 00000000000..133686df018
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.concurrent.Executors
+import java.util.concurrent.ScheduledExecutorService
+import java.util.concurrent.TimeUnit
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable
+import scala.concurrent.duration.Duration
+import scala.concurrent.duration.DurationInt
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ * Caches Spark-Connect streaming query references and the sessions. When a query is stopped (i.e.
+ * no longer active), it is cached for `stoppedQueryInactivityTimeout` (1 hour by default) so that
+ * it is accessible from the client side. It runs a background daemon thread that performs the
+ * following periodic task every `sessionPollingPeriod` (1 minute by default):
+ *   - Check the status of the queries, and drop those that expired (i.e. the timeout elapsed
+ *     since they were found stopped or were last accessed).
+ *   - Keep the associated sessions active by invoking the supplied function `sessionKeepAliveFn`.
+ *
+ * This class helps with supporting the following semantics for streaming query sessions:
+ *   - Keep the session and session mapping at the connect server alive as long as a streaming
+ *     query is active, even if the client side has disconnected.
+ *     - This matches how streaming queries behave in Spark. The queries continue to run if the
+ *       notebook or job session is lost.
+ *   - Once a query is stopped, the reference and mappings are maintained for 1 hour and will be
+ *     accessible from the client. This allows time for the client to fetch status. If the client
+ *     continues to access the query, it stays in the cache until 1 hour of inactivity.
+ *
+ * Note that these semantics are evolving and might change before being finalized in Connect.
+ *
+ * @param sessionKeepAliveFn invoked with `(userId, sessionId)` for every session that has a
+ *                           cached query; called from the background thread, outside the lock.
+ * @param clock clock used to compute expiry times of stopped queries.
+ * @param stoppedQueryInactivityTimeout how long a stopped query stays cached after it was found
+ *                                      inactive or was last accessed.
+ * @param sessionPollingPeriod period of the background maintenance task.
+ */
+private[connect] class SparkConnectStreamingQueryCache(
+    val sessionKeepAliveFn: (String, String) => Unit, // (userId, sessionId) => Unit.
+    val clock: Clock = new SystemClock(),
+    private val stoppedQueryInactivityTimeout: Duration = 1.hour, // Configurable for testing.
+    private val sessionPollingPeriod: Duration = 1.minute // Configurable for testing.
+) extends Logging {
+
+  import SparkConnectStreamingQueryCache._
+
+  /**
+   * Registers a newly started streaming query together with the session it belongs to, and
+   * starts the periodic maintenance thread if it is not running yet.
+   */
+  def registerNewStreamingQuery(sessionHolder: SessionHolder, query: StreamingQuery): Unit = {
+    queryCacheLock.synchronized {
+      val value = QueryCacheValue(
+        userId = sessionHolder.userId,
+        sessionId = sessionHolder.sessionId,
+        session = sessionHolder.session,
+        query = query,
+        expiresAtMs = None)
+
+      queryCache.put(QueryCacheKey(query.id.toString, query.runId.toString), value) match {
+        case Some(existing) => // Query is being replaced. Not really expected.
+          log.warn(
+            s"Replacing existing query in the cache (unexpected). Query Id: ${query.id}. " +
+              s"Existing value $existing, new value $value.")
+        case None =>
+          log.info(s"Adding new query to the cache. Query Id ${query.id}, value $value.")
+      }
+
+      schedulePeriodicChecks() // Starts the scheduler thread if it hasn't started.
+    }
+  }
+
+  /**
+   * Returns [[StreamingQuery]] if it is cached and the session matches the cached query. It
+   * ensures that the session associated with it matches the session passed into the call. If the
+   * query is inactive (i.e. it has a cache expiry time set), this access extends its expiry time.
+   * So if a client keeps accessing a query, it stays in the cache.
+   */
+  def getCachedQuery(
+      queryId: String,
+      runId: String,
+      session: SparkSession): Option[StreamingQuery] = {
+    val key = QueryCacheKey(queryId, runId)
+    queryCacheLock.synchronized {
+      queryCache.get(key).flatMap { v =>
+        if (v.session == session) {
+          v.expiresAtMs.foreach { _ =>
+            // Extend the expiry time as the client is accessing it.
+            val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis
+            queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs)))
+          }
+          Some(v.query)
+        } else None // Rare: the client may be accessing it from a different session.
+      }
+    }
+  }
+
+  // Visible for testing. Takes the lock since the maintenance thread mutates the cache.
+  private[service] def getCachedValue(queryId: String, runId: String): Option[QueryCacheValue] =
+    queryCacheLock.synchronized {
+      queryCache.get(QueryCacheKey(queryId, runId))
+    }
+
+  // Visible for testing.
+  private[service] def shutdown(): Unit = {
+    // Detach the executor under the lock, but wait for it outside the lock: the periodic task
+    // acquires `queryCacheLock`, so waiting while holding the lock could stall termination
+    // until the timeout.
+    val executorOpt = queryCacheLock.synchronized {
+      val current = scheduledExecutor
+      scheduledExecutor = None
+      current
+    }
+    executorOpt.foreach { executor =>
+      executor.shutdown()
+      executor.awaitTermination(1, TimeUnit.MINUTES)
+    }
+  }
+
+  private val queryCacheLock = new Object
+
+  @GuardedBy("queryCacheLock")
+  private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue]
+
+  @GuardedBy("queryCacheLock")
+  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+
+  /** Schedules periodic checks if they are not already scheduled. */
+  private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized {
+    if (scheduledExecutor.isEmpty) {
+      log.info(s"Starting thread for polling streaming sessions every $sessionPollingPeriod")
+      // A named daemon thread: easy to find in thread dumps and it never blocks JVM exit.
+      val executor = Executors.newSingleThreadScheduledExecutor(
+        new java.util.concurrent.ThreadFactory {
+          override def newThread(r: Runnable): Thread = {
+            val thread = new Thread(r, "spark-connect-streaming-query-cache")
+            thread.setDaemon(true)
+            thread
+          }
+        })
+      executor.scheduleAtFixedRate(
+        () => {
+          // An exception escaping the task would suppress all subsequent executions.
+          try periodicMaintenance()
+          catch {
+            case NonFatal(ex) => log.warn("Unexpected exception in periodic task", ex)
+          }
+        },
+        sessionPollingPeriod.toMillis,
+        sessionPollingPeriod.toMillis,
+        TimeUnit.MILLISECONDS)
+      scheduledExecutor = Some(executor)
+    }
+  }
+
+  /**
+   * Periodic maintenance task to do the following:
+   *   - Update the status of a query if it is inactive. Sets an expiry time for such queries.
+   *   - Drop expired queries from the cache.
+   *   - Poll sessions associated with the cached queries in order to keep them alive in the
+   *     connect service's mapping (by invoking `sessionKeepAliveFn`).
+   */
+  private def periodicMaintenance(): Unit = {
+
+    // Gather sessions to keep alive and invoke supplied function outside the lock.
+    val sessionsToKeepAlive = mutable.HashSet[(String, String)]()
+
+    queryCacheLock.synchronized {
+      val nowMs = clock.getTimeMillis()
+
+      // Iterate over a snapshot: the loop body removes and updates entries, and mutating a
+      // mutable.HashMap while iterating over it is undefined behavior.
+      for ((k, v) <- queryCache.toList) {
+        val id = k.queryId
+        v.expiresAtMs match {
+
+          case Some(ts) if nowMs >= ts => // Expired. Drop references.
+            log.info(s"Removing references for $id in session ${v.sessionId} after expiry period")
+            queryCache.remove(k)
+
+          case Some(_) => // Inactive query waiting for expiration. Keep the session alive.
+            sessionsToKeepAlive.add((v.userId, v.sessionId))
+
+          case None => // Active query, check if it is stopped. Keep the session alive.
+            sessionsToKeepAlive.add((v.userId, v.sessionId))
+
+            val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty
+
+            if (!isActive) {
+              log.info(s"Marking query $id in session ${v.sessionId} inactive.")
+              val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis
+              queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs)))
+            }
+        }
+      }
+    }
+
+    // Isolate each call so one failing session does not stop the others from being kept alive.
+    for ((userId, sessionId) <- sessionsToKeepAlive) {
+      try sessionKeepAliveFn(userId, sessionId)
+      catch {
+        case NonFatal(ex) =>
+          log.warn(s"Failed to keep session $sessionId of user $userId alive", ex)
+      }
+    }
+  }
+}
+
+private[connect] object SparkConnectStreamingQueryCache {
+
+  /** Cache key identifying a single run of a streaming query. */
+  case class QueryCacheKey(queryId: String, runId: String)
+
+  /**
+   * Cached entry for one streaming query run. Holding `session` and `query` here keeps both
+   * referenced for as long as the entry stays in the cache.
+   *
+   * @param expiresAtMs `None` while the query is considered active; once the query is found
+   *                    stopped, the clock time (in ms) at or after which the entry is dropped.
+   */
+  case class QueryCacheValue(
+      userId: String,
+      sessionId: String,
+      session: SparkSession,
+      query: StreamingQuery,
+      expiresAtMs: Option[Long] = None)
+
+  /** Key of a user session: the pair of user id and session id. */
+  case class SessionCacheKey(userId: String, sessionId: String)
+
+  /** Value wrapping a [[SparkSession]]. */
+  case class SessionCacheValue(session: SparkSession)
+}
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 8dac0b166b6..88b4be16e5a 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -55,7 +55,7 @@ trait SparkConnectPlanTest extends SharedSparkSession {
   }
 
   def transform(cmd: proto.Command): Unit = {
-    new SparkConnectPlanner(spark).process(cmd, "clientId", new MockObserver())
+    new SparkConnectPlanner(spark).process(cmd, "clientId", "sessionId", new 
MockObserver())
   }
 
   def readRel: proto.Relation =
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index 39fc90fd002..d61b54c67c2 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -195,7 +195,7 @@ class SparkConnectPluginRegistrySuite extends 
SharedSparkSession with SparkConne
               .build()))
         .build()
 
-      new SparkConnectPlanner(spark).process(plan, "clientId", new 
MockObserver())
+      new SparkConnectPlanner(spark).process(plan, "clientId", "sessionId", 
new MockObserver())
       
assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
     }
   }
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
new file mode 100644
index 00000000000..36f284ec3ca
--- /dev/null
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.duration.DurationInt
+
+import org.mockito.Mockito.when
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.streaming.StreamingQueryManager
+import org.apache.spark.util.ManualClock
+
+class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSugar {
+
+  // Creates a manager with short durations for periodic check and expiry.
+  private def createSessionManager(keepAliveFn: (String, String) => Unit) = {
+    new SparkConnectStreamingQueryCache(
+      keepAliveFn,
+      clock = new ManualClock(),
+      stoppedQueryInactivityTimeout = 1.minute, // This is on manual clock.
+      sessionPollingPeriod = 20.milliseconds // This is real clock. Used for periodic task.
+    )
+  }
+
+  test("Session cache functionality with a streaming query") {
+    // Verifies common happy path for the query cache. Runs a query through its life cycle.
+
+    val numKeepAliveCalls = new AtomicInteger(0)
+
+    val queryId = UUID.randomUUID().toString
+    val runId = UUID.randomUUID().toString
+    val mockSession = mock[SparkSession]
+    val mockQuery = mock[StreamingQuery]
+    val mockStreamingQueryManager = mock[StreamingQueryManager]
+
+    val sessionHolder =
+      SessionHolder(userId = "test_user_1", sessionId = "test_session_1", session = mockSession)
+
+    val sessionMgr = createSessionManager(keepAliveFn = { case (userId, sessionId) =>
+      assert(userId == sessionHolder.userId)
+      assert(sessionId == sessionHolder.sessionId)
+      numKeepAliveCalls.incrementAndGet()
+    })
+
+    // Always stop the background thread, even when an assertion below fails.
+    try {
+      val clock = sessionMgr.clock.asInstanceOf[ManualClock]
+
+      when(mockQuery.id).thenReturn(UUID.fromString(queryId))
+      when(mockQuery.runId).thenReturn(UUID.fromString(runId))
+      when(mockQuery.isActive).thenReturn(true) // Query is active.
+      when(mockSession.streams).thenReturn(mockStreamingQueryManager)
+      when(mockStreamingQueryManager.get(queryId)).thenReturn(mockQuery)
+
+      // Register the query.
+
+      sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery)
+
+      eventually(timeout(1.minute)) {
+        // Verify keep alive function is called a few times.
+        assert(numKeepAliveCalls.get() >= 5)
+      }
+
+      sessionMgr.getCachedValue(queryId, runId) match {
+        case Some(v) =>
+          assert(v.sessionId == sessionHolder.sessionId)
+          assert(v.expiresAtMs.isEmpty, "No expiry time should be set for active query")
+
+        case None => fail("Query should be found")
+      }
+
+      // Verify query is returned only with the correct session, not with a different session.
+      assert(sessionMgr.getCachedQuery(queryId, runId, mock[SparkSession]).isEmpty)
+      // Query is returned when correct session is used
+      assert(sessionMgr.getCachedQuery(queryId, runId, mockSession).contains(mockQuery))
+
+      // Stop the query.
+      when(mockQuery.isActive).thenReturn(false)
+
+      val expectedExpiryTimeMs = sessionMgr.clock.getTimeMillis() + 1.minute.toMillis
+
+      // The query should have 'expiresAtMs' set now.
+      eventually(timeout(1.minute)) {
+        val expiresAtOpt = sessionMgr.getCachedValue(queryId, runId).flatMap(_.expiresAtMs)
+        assert(expiresAtOpt.contains(expectedExpiryTimeMs))
+      }
+
+      // Verify that expiry time gets extended when the query is accessed.
+      val prevExpiryTimeMs = sessionMgr.getCachedValue(queryId, runId).get.expiresAtMs.get
+
+      clock.advance(30.seconds.toMillis)
+
+      // Access the query. This should advance expiry time by 30 seconds.
+      assert(sessionMgr.getCachedQuery(queryId, runId, mockSession).contains(mockQuery))
+      val expiresAtMs = sessionMgr.getCachedValue(queryId, runId).get.expiresAtMs.get
+      assert(expiresAtMs == prevExpiryTimeMs + 30.seconds.toMillis)
+
+      // During this time ensure that query can be restarted with a new runId.
+
+      val restartedRunId = UUID.randomUUID().toString
+      val restartedQuery = mock[StreamingQuery]
+      when(restartedQuery.id).thenReturn(UUID.fromString(queryId))
+      when(restartedQuery.runId).thenReturn(UUID.fromString(restartedRunId))
+      when(restartedQuery.isActive).thenReturn(true)
+      when(mockStreamingQueryManager.get(queryId)).thenReturn(restartedQuery)
+
+      sessionMgr.registerNewStreamingQuery(sessionHolder, restartedQuery)
+
+      // Both queries should exist in the cache.
+      assert(sessionMgr.getCachedValue(queryId, runId).map(_.query).contains(mockQuery))
+      assert(
+        sessionMgr.getCachedValue(queryId, restartedRunId).map(_.query).contains(restartedQuery))
+
+      // Advance time by 1 minute and verify the first query is dropped from the cache.
+      clock.advance(1.minute.toMillis)
+      eventually(timeout(1.minute)) {
+        assert(sessionMgr.getCachedValue(queryId, runId).isEmpty)
+      }
+
+      // Stop the restarted query and verify it gets dropped from the cache too.
+      when(restartedQuery.isActive).thenReturn(false)
+      eventually(timeout(1.minute)) {
+        assert(sessionMgr.getCachedValue(queryId, restartedRunId).flatMap(_.expiresAtMs).nonEmpty)
+      }
+
+      // Advance time by one more minute and restarted query should be dropped.
+      clock.advance(1.minute.toMillis)
+      eventually(timeout(1.minute)) {
+        assert(sessionMgr.getCachedValue(queryId, restartedRunId).isEmpty)
+      }
+    } finally {
+      sessionMgr.shutdown()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to