This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f086c2327c36 [SPARK-47174][CONNECT][SS][1/2] Server side SparkConnectListenerBusListener for Client side streaming query listener
f086c2327c36 is described below

commit f086c2327c36c396ae5d886afd3ef613650c6b0d
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Fri Apr 12 10:08:45 2024 +0900

    [SPARK-47174][CONNECT][SS][1/2] Server side SparkConnectListenerBusListener for Client side streaming query listener
    
    ### What changes were proposed in this pull request?
    
    Server side `SparkConnectListenerBusListener` implementation for the client side listener. There is only one such listener per `SessionHolder`.
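    
    Below is a minimal, self-contained sketch (illustrative names, not the Spark Connect API) of the blocking pattern the server side handler uses: the handling thread of ADD_LISTENER_BUS_LISTENER parks on a `CountDownLatch` so its long-running gRPC response stream stays open for streaming events back, and REMOVE_LISTENER_BUS_LISTENER counts the latch down so that thread can send the final `ResultComplete`.
    
    ```scala
    import java.util.concurrent.CountDownLatch
    
    object ListenerBusLatchSketch {
      def main(args: Array[String]): Unit = {
        val latch = new CountDownLatch(1)
    
        // ADD_LISTENER_BUS_LISTENER: register the listener, then block so the
        // long-running response stream stays open.
        val addHandler = new Thread(() => {
          println("listener added; streaming events until removed")
          latch.await() // block instead of completing the RPC
          println("latch released; sending the final ResultComplete")
        })
        addHandler.start()
    
        // REMOVE_LISTENER_BUS_LISTENER: clean up and release the blocked thread.
        latch.countDown()
        addHandler.join()
      }
    }
    ```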
    
    ### Why are the changes needed?
    
    Moves the streaming query listener to the client side.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45988 from WweiL/SPARK-47174-client-side-listener-1.
    
    Authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  34 ++-
 ...SparkConnectStreamingQueryListenerHandler.scala | 121 +++++++++++
 .../spark/sql/connect/service/SessionHolder.scala  |   7 +
 .../service/SparkConnectListenerBusListener.scala  | 156 ++++++++++++++
 .../SparkConnectListenerBusListenerSuite.scala     | 240 +++++++++++++++++++++
 5 files changed, 555 insertions(+), 3 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 96db45c5c63e..5e7f3b74c299 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2551,6 +2551,11 @@ class SparkConnectPlanner(
         handleStreamingQueryManagerCommand(
           command.getStreamingQueryManagerCommand,
           responseObserver)
+      case proto.Command.CommandTypeCase.STREAMING_QUERY_LISTENER_BUS_COMMAND =>
+        val handler = new SparkConnectStreamingQueryListenerHandler(executeHolder)
+        handler.handleListenerCommand(
+          command.getStreamingQueryListenerBusCommand,
+          responseObserver)
       case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
         handleGetResourcesCommand(responseObserver)
       case proto.Command.CommandTypeCase.CREATE_RESOURCE_PROFILE_COMMAND =>
@@ -3118,7 +3123,7 @@ class SparkConnectPlanner(
     }
     executeHolder.eventsManager.postFinished()
 
-    val result = WriteStreamOperationStartResult
+    val resultBuilder = WriteStreamOperationStartResult
       .newBuilder()
       .setQueryId(
         StreamingQueryInstanceId
@@ -3127,14 +3132,37 @@ class SparkConnectPlanner(
           .setRunId(query.runId.toString)
           .build())
       .setName(Option(query.name).getOrElse(""))
-      .build()
+
+    // The QueryStartedEvent for this query is sent to the client and handled by the
+    // client side listeners before the client's DataStreamWriter.start() returns,
+    // ensuring that the onQueryStarted callback fires before start() returns, as the
+    // onQueryStarted API requires.
+    // The flow is:
+    // 1. On the server side, the query is started above.
+    // 2. Per the contract of the onQueryStarted API, the QueryStartedEvent is added to
+    //    streamingServersideListenerHolder.streamingQueryStartedEventCache by the onQueryStarted
+    //    callback of streamingServersideListenerHolder.streamingQueryServerSideListener.
+    // 3. The QueryStartedEvent is sent to the client.
+    // 4. The client side listener handles the QueryStartedEvent and calls the onQueryStarted API
+    //    before the client side DataStreamWriter.start() returns.
+    // This way the onQueryStarted API is guaranteed to be called before the start() call in Connect.
+    val queryStartedEvent = Option(
+      sessionHolder.streamingServersideListenerHolder.streamingQueryStartedEventCache.remove(
+        query.runId.toString))
+    queryStartedEvent.foreach { e =>
+      logDebug(
+        s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+          s"${executeHolder.operationId}][query id: ${query.id}][query runId: ${query.runId}] " +
+          "Adding QueryStartedEvent to response")
+      resultBuilder.setQueryStartedEventJson(e.json)
+    }
 
     responseObserver.onNext(
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setServerSideSessionId(sessionHolder.serverSessionId)
-        .setWriteStreamOperationStartResult(result)
+        .setWriteStreamOperationStartResult(resultBuilder.build())
         .build())
   }
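
The hunk above implements the ordering handshake described in its comment. A reduced, self-contained sketch of that handshake, with simplified hypothetical event and response types (only the runId-keyed cache mirrors ServerSideListenerHolder.streamingQueryStartedEventCache):

```scala
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

object QueryStartedHandshakeSketch {
  // Hypothetical, simplified stand-ins for the listener event and proto response.
  final case class QueryStartedEvent(runId: String, json: String)
  final case class StartResult(queryId: String, queryStartedEventJson: Option[String])

  // Keyed by query runId, like streamingQueryStartedEventCache.
  val cache: ConcurrentMap[String, QueryStartedEvent] = new ConcurrentHashMap()

  // The server side onQueryStarted callback caches the event (step 2 above).
  def onQueryStarted(event: QueryStartedEvent): Unit = {
    cache.put(event.runId, event)
  }

  // The WriteStreamOperationStart handler drains the cached event into its response
  // (step 3), so the client can run onQueryStarted before start() returns (step 4).
  def buildStartResult(queryId: String, runId: String): StartResult =
    StartResult(queryId, Option(cache.remove(runId)).map(_.json))

  def main(args: Array[String]): Unit = {
    onQueryStarted(QueryStartedEvent("run-1", """{"event":"started"}"""))
    assert(buildStartResult("q-1", "run-1").queryStartedEventJson.isDefined)
    assert(buildStartResult("q-1", "run-1").queryStartedEventJson.isEmpty) // already drained
  }
}
```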
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala
new file mode 100644
index 000000000000..94f01026b7a5
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import scala.util.control.NonFatal
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.connect.proto.StreamingQueryListenerBusCommand
+import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.service.ExecuteHolder
+
+/**
+ * Handle long-running streaming query listener events.
+ */
+class SparkConnectStreamingQueryListenerHandler(executeHolder: ExecuteHolder) extends Logging {
+
+  val sessionHolder = executeHolder.sessionHolder
+
+  private[connect] def userId: String = sessionHolder.userId
+
+  private[connect] def sessionId: String = sessionHolder.sessionId
+
+  /**
+   * The handler logic. The handler of ADD_LISTENER_BUS_LISTENER uses the
+   * streamingQueryListenerLatch to block the handling thread, preventing it from sending back
+   * the final ResultComplete response.
+   *
+   * The handler of REMOVE_LISTENER_BUS_LISTENER cleans up the server side listener resources
+   * and counts down the latch, allowing the handling thread of the original
+   * ADD_LISTENER_BUS_LISTENER command to proceed to send back the final ResultComplete response.
+   */
+  def handleListenerCommand(
+      command: StreamingQueryListenerBusCommand,
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+
+    val listenerHolder = sessionHolder.streamingServersideListenerHolder
+
+    command.getCommandCase match {
+      case StreamingQueryListenerBusCommand.CommandCase.ADD_LISTENER_BUS_LISTENER =>
+        listenerHolder.isServerSideListenerRegistered match {
+          case true =>
+            logWarning(
+              s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+                s"${executeHolder.operationId}] Redundant server side listener 
added. Exiting.")
+            return
+          case false =>
+            // This transfers sending back the response to the client until
+            // the long running command is terminated, either by
+            // errors in streamingQueryServerSideListener.send,
+            // or client issues a REMOVE_LISTENER_BUS_LISTENER call.
+            listenerHolder.init(responseObserver)
+            // Send back listener added response
+            val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
+            val listenerAddedResult = respBuilder
+              .setListenerBusListenerAdded(true)
+              .build()
+            try {
+              responseObserver.onNext(
+                ExecutePlanResponse
+                  .newBuilder()
+                  .setSessionId(sessionHolder.sessionId)
+                  .setServerSideSessionId(sessionHolder.serverSessionId)
+                  .setStreamingQueryListenerEventsResult(listenerAddedResult)
+                  .build())
+            } catch {
+              case NonFatal(e) =>
+                logError(
+                  s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+                    s"${executeHolder.operationId}] Error sending listener 
added response.",
+                  e)
+                listenerHolder.cleanUp()
+                return
+            }
+        }
+        logInfo(s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+          s"${executeHolder.operationId}] Server side listener added. Now 
blocking until " +
+          "all client side listeners are removed or there is error 
transmitting the event back.")
+        // Block the handling thread, and have serverListener continuously 
send back new events
+        listenerHolder.streamingQueryListenerLatch.await()
+        logInfo(s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+          s"${executeHolder.operationId}] Server side listener long-running 
handling thread ended.")
+      case StreamingQueryListenerBusCommand.CommandCase.REMOVE_LISTENER_BUS_LISTENER =>
+        listenerHolder.isServerSideListenerRegistered match {
+          case true =>
+            sessionHolder.streamingServersideListenerHolder.cleanUp()
+          case false =>
+            logWarning(
+              s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+                s"${executeHolder.operationId}] No active server side listener 
bus listener " +
+                s"but received remove listener call. Exiting.")
+            return
+        }
+      case StreamingQueryListenerBusCommand.CommandCase.COMMAND_NOT_SET =>
+        throw new IllegalArgumentException("Missing command in StreamingQueryListenerBusCommand")
+    }
+    // If this thread is the handling thread of the original ADD_LISTENER_BUS_LISTENER command,
+    // this is reached when the latch is counted down (either through a
+    // REMOVE_LISTENER_BUS_LISTENER command, or because the long-lived gRPC call throws).
+    // If this thread is the handling thread of a REMOVE_LISTENER_BUS_LISTENER command,
+    // this is reached right away.
+    executeHolder.eventsManager.postFinished()
+  }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index ef79cdcce8ff..306b89148583 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -92,6 +92,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
   private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
     new StreamingForeachBatchHelper.CleanerCache(this)
 
+  private[connect] lazy val streamingServersideListenerHolder = new ServerSideListenerHolder(this)
+
   def key: SessionKey = SessionKey(userId, sessionId)
 
   // Returns the server side session ID and asserts that it must be different from the client-side
@@ -267,6 +269,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers.
     removeAllListeners() // Removes all listeners and stops Python listener processes if necessary.
 
+    // If there is a server side listener, clean up related resources.
+    if (streamingServersideListenerHolder.isServerSideListenerRegistered) {
+      streamingServersideListenerHolder.cleanUp()
+    }
+
     // Clean up all executions.
     // After closedTimeMs is defined, SessionHolder.addExecuteHolder() will not allow new executions
     // to be added for this session anymore. Because both SessionHolder.addExecuteHolder() and
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
new file mode 100644
index 000000000000..1b6c5179871d
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
+
+import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.connect.proto.StreamingQueryEventType
+import org.apache.spark.connect.proto.StreamingQueryListenerEvent
+import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.util.ArrayImplicits._
+
+/**
+ * A holder for the server side listener and related resources. There should be only one such
+ * holder for each sessionHolder.
+ */
+private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
+  // The server side listener that streams streaming query events back to the client.
+  // There is only one listener per sessionHolder, and it is responsible for the events
+  // of all streaming queries in the SparkSession.
+  var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None
+  // The countdown latch that holds the long-running listener thread before sending ResultComplete.
+  var streamingQueryListenerLatch = new CountDownLatch(1)
+  // The cache for QueryStartedEvent: the key is the query runId and the value is the actual
+  // QueryStartedEvent. The event for the corresponding query is sent back to the client with
+  // the WriteStreamOperationStart response, so that the client can handle the event before
+  // DataStreamWriter.start() returns. This special handling is to satisfy the contract of
+  // onQueryStarted in StreamingQueryListener.
+  val streamingQueryStartedEventCache
+      : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap()
+
+  def isServerSideListenerRegistered: Boolean = streamingQueryServerSideListener.isDefined
+
+  /**
+   * Initializes the server side listener and related resources. This method is called when
+   * the first ADD_LISTENER_BUS_LISTENER command is received. The listener is attached to the
+   * responseObserver from the first executeThread (the long-running thread), so the lifecycle
+   * of the responseObserver is the same as the lifecycle of the listener.
+   *
+   * @param responseObserver
+   *   the responseObserver created from the first long-running executeThread.
+   */
+  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+    val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
+    sessionHolder.session.streams.addListener(serverListener)
+    streamingQueryServerSideListener = Some(serverListener)
+    streamingQueryListenerLatch = new CountDownLatch(1)
+  }
+
+  /**
+   * Cleans up the server side listener and related resources. This method is called when the
+   * REMOVE_LISTENER_BUS_LISTENER command is received or when responseObserver.onNext throws an
+   * exception. It removes the listener from the session and clears the cache. It also counts
+   * down the latch, so the long-running thread can proceed to send back the final
+   * ResultComplete response.
+   */
+  def cleanUp(): Unit = {
+    streamingQueryServerSideListener.foreach { listener =>
+      sessionHolder.session.streams.removeListener(listener)
+    }
+    streamingQueryStartedEventCache.clear()
+    streamingQueryServerSideListener = None
+    streamingQueryListenerLatch.countDown()
+  }
+}
+
+/**
+ * A customized StreamingQueryListener used in Spark Connect for the client-side listeners. Upon
+ * the invocation of each callback function, it serializes the event to JSON and sends it to
+ * the client.
+ */
+private[sql] class SparkConnectListenerBusListener(
+    serverSideListenerHolder: ServerSideListenerHolder,
+    responseObserver: StreamObserver[ExecutePlanResponse])
+    extends StreamingQueryListener
+    with Logging {
+
+  val sessionHolder = serverSideListenerHolder.sessionHolder
+  // The method used to stream events back to the client. The event is serialized to JSON
+  // and sent to the client. The responseObserver is the one from the first executeThread
+  // (the long-running thread), which is held open by the streamingQueryListenerLatch.
+  // If any exception is thrown while transmitting an event back, the listener is removed,
+  // all related resources are cleaned up, and the long-running thread proceeds to send
+  // the final ResultComplete response.
+  private def send(eventJson: String, eventType: StreamingQueryEventType): Unit = {
+    val event = StreamingQueryListenerEvent
+      .newBuilder()
+      .setEventJson(eventJson)
+      .setEventType(eventType)
+      .build()
+
+    val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
+    val eventResult = respBuilder
+      .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
+      .build()
+
+    try {
+      responseObserver.onNext(
+        ExecutePlanResponse
+          .newBuilder()
+          .setSessionId(sessionHolder.sessionId)
+          .setServerSideSessionId(sessionHolder.serverSessionId)
+          .setStreamingQueryListenerEventsResult(eventResult)
+          .build())
+    } catch {
+      case NonFatal(e) =>
+        logError(
+          s"[SessionId: ${sessionHolder.sessionId}][UserId: 
${sessionHolder.userId}] " +
+            s"Removing SparkConnectListenerBusListener and terminating the 
long-running thread " +
+            s"because of exception: $e")
+        // This likely means that the client is not responsive even with 
retry, we should
+        // remove this listener and cleanup resources.
+        serverSideListenerHolder.cleanUp()
+    }
+  }
+
+  // The QueryStartedEvent is sent to the client along with the WriteStreamOperationStartResult.
+  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
+    serverSideListenerHolder.streamingQueryStartedEventCache.put(event.runId.toString, event)
+  }
+
+  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+    send(event.json, StreamingQueryEventType.QUERY_PROGRESS_EVENT)
+  }
+
+  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
+    send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT)
+  }
+
+  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
+    send(event.json, StreamingQueryEventType.QUERY_IDLE_EVENT)
+  }
+}
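
The send path above doubles as the failure detector for the long-running RPC. A reduced sketch of that error-handling shape, with a hypothetical observer type standing in for the gRPC StreamObserver:

```scala
import scala.util.control.NonFatal

object SendWithCleanupSketch {
  // Hypothetical stand-in for the gRPC response observer.
  trait Observer { def onNext(payload: String): Unit }

  final class Listener(observer: Observer, cleanUp: () => Unit) {
    def send(eventJson: String): Unit =
      try observer.onNext(eventJson)
      catch {
        case NonFatal(e) =>
          // The client is likely unresponsive: unregister the listener, clear the
          // cache, and count the latch down so the blocked handler thread finishes.
          println(s"removing listener after send failure: $e")
          cleanUp()
      }
  }

  def main(args: Array[String]): Unit = {
    val broken: Observer = _ => throw new RuntimeException("client gone")
    var cleaned = false
    new Listener(broken, () => cleaned = true).send("{}")
    assert(cleaned)
  }
}
```
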
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala
new file mode 100644
index 000000000000..4c2962fda507
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.DurationInt
+import scala.jdk.CollectionConverters._
+
+import io.grpc.stub.StreamObserver
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.connect.proto.{Command, ExecutePlanResponse}
+import org.apache.spark.sql.connect.planner.SparkConnectStreamingQueryListenerHandler
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener}
+import org.apache.spark.sql.streaming.Trigger.ProcessingTime
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SparkConnectListenerBusListenerSuite
+    extends SparkFunSuite
+    with SharedSparkSession
+    with MockitoSugar {
+
+  override def afterEach(): Unit = {
+    try {
+      spark.streams.active.foreach(_.stop())
+      spark.streams.listListeners().foreach(spark.streams.removeListener)
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  // A test listener that caches all events
+  private class CacheEventsStreamingQueryListener(
+      startEvents: ArrayBuffer[StreamingQueryListener.QueryStartedEvent],
+      otherEvents: ArrayBuffer[StreamingQueryListener.Event])
+      extends StreamingQueryListener {
+
+    override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
+      startEvents += event
+    }
+
+    override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+      otherEvents += event
+    }
+
+    override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
+      otherEvents += event
+    }
+
+    override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
+      otherEvents += event
+    }
+  }
+
+  private def verifyEventsSent(
+      fromCachedEventsListener: ArrayBuffer[StreamingQueryListener.Event],
+      fromListenerBusListener: ArrayBuffer[String]): Unit = {
+    assert(fromListenerBusListener.toSet === fromCachedEventsListener.map {
+      case e: StreamingQueryListener.QueryStartedEvent => e.json
+      case e: StreamingQueryListener.QueryProgressEvent => e.json
+      case e: StreamingQueryListener.QueryTerminatedEvent => e.json
+      case e: StreamingQueryListener.QueryIdleEvent => e.json
+    }.toSet)
+  }
+
+  private def startQuery(slow: Boolean = false): StreamingQuery = {
+    val dsw = spark.readStream.format("rate").load().writeStream.format("noop")
+    if (slow) {
+      dsw.trigger(ProcessingTime("20 seconds"))
+    }
+    dsw.start()
+  }
+
+  Seq(1, 5, 20).foreach { queryNum =>
+    test(
+      "Basic functionalities - onQueryStart, onQueryProgress, 
onQueryTerminated" +
+        s" - $queryNum queries") {
+      val sessionHolder = SessionHolder.forTesting(spark)
+      val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+      val eventJsonBuffer = ArrayBuffer.empty[String]
+      val startEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.QueryStartedEvent]
+      val otherEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.Event]
+
+      doAnswer((invocation: InvocationOnMock) => {
+        val argument = invocation.getArgument[ExecutePlanResponse](0)
+        val eventJson = argument.getStreamingQueryListenerEventsResult().getEvents(0).getEventJson
+        eventJsonBuffer += eventJson
+      }).when(responseObserver).onNext(any[ExecutePlanResponse]())
+
+      val listenerHolder = sessionHolder.streamingServersideListenerHolder
+      listenerHolder.init(responseObserver)
+      val cachedEventsListener =
+        new CacheEventsStreamingQueryListener(startEventsBuffer, otherEventsBuffer)
+
+      spark.streams.addListener(cachedEventsListener)
+
+      for (_ <- 1 to queryNum) startQuery()
+
+      // after all queries have made some progress
+      eventually(timeout(60.seconds), interval(2.seconds)) {
+        spark.streams.active.foreach { q =>
+          assert(q.lastProgress.batchId > 5)
+        }
+      }
+
+      // stops all queries
+      spark.streams.active.foreach(_.stop())
+
+      eventually(timeout(60.seconds), interval(500.milliseconds)) {
+        assert(eventJsonBuffer.nonEmpty)
+        assert(!listenerHolder.streamingQueryStartedEventCache.isEmpty)
+        verifyEventsSent(otherEventsBuffer, eventJsonBuffer)
+        assert(
+          startEventsBuffer.map(_.json).toSet ===
+            listenerHolder.streamingQueryStartedEventCache.asScala.map(_._2.json).toSet)
+      }
+    }
+  }
+
+  test("Basic functionalities - Slow query") {
+    val sessionHolder = SessionHolder.forTesting(spark)
+    val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+    val eventJsonBuffer = ArrayBuffer.empty[String]
+    val startEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.QueryStartedEvent]
+    val otherEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.Event]
+
+    doAnswer((invocation: InvocationOnMock) => {
+      val argument = invocation.getArgument[ExecutePlanResponse](0)
+      val eventJson = argument.getStreamingQueryListenerEventsResult().getEvents(0).getEventJson
+      eventJsonBuffer += eventJson
+    }).when(responseObserver).onNext(any[ExecutePlanResponse]())
+
+    val listenerHolder = sessionHolder.streamingServersideListenerHolder
+    listenerHolder.init(responseObserver)
+
+    val cachedEventsListener =
+      new CacheEventsStreamingQueryListener(startEventsBuffer, otherEventsBuffer)
+    spark.streams.addListener(cachedEventsListener)
+
+    // Slow query
+    val q = startQuery(true)
+
+    // after the slow query has made some progress
+    eventually(timeout(100.seconds), interval(7.seconds)) {
+      assert(q.lastProgress.batchId > 2)
+    }
+
+    q.stop()
+
+    eventually(timeout(60.seconds), interval(1.second)) {
+      assert(eventJsonBuffer.nonEmpty)
+      assert(!listenerHolder.streamingQueryStartedEventCache.isEmpty)
+      verifyEventsSent(otherEventsBuffer, eventJsonBuffer)
+      assert(
+        startEventsBuffer.map(_.json).toSet ===
+          listenerHolder.streamingQueryStartedEventCache.asScala.map(_._2.json).toSet)
+    }
+  }
+
+  test("Proper handling on onNext throw - initial response") {
+    val sessionHolder = SessionHolder.forTesting(spark)
+
+    val executeHolder = mock[ExecuteHolder]
+    when(executeHolder.sessionHolder).thenReturn(sessionHolder)
+    when(executeHolder.operationId).thenReturn("operationId")
+
+    val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+    doThrow(new RuntimeException("I'm dead"))
+      .when(responseObserver)
+      .onNext(any[ExecutePlanResponse]())
+
+    val listenerCntBeforeThrow = spark.streams.listListeners().size
+
+    val handler = new SparkConnectStreamingQueryListenerHandler(executeHolder)
+    val listenerBusCmdBuilder = Command.newBuilder().getStreamingQueryListenerBusCommandBuilder
+    val addListenerCommand = listenerBusCmdBuilder.setAddListenerBusListener(true).build()
+    handler.handleListenerCommand(addListenerCommand, responseObserver)
+
+    val listenerHolder = sessionHolder.streamingServersideListenerHolder
+    eventually(timeout(5.seconds), interval(500.milliseconds)) {
+      assert(
+        sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty)
+      assert(spark.streams.listListeners().size === listenerCntBeforeThrow)
+      assert(listenerHolder.streamingQueryStartedEventCache.isEmpty)
+      assert(listenerHolder.streamingQueryListenerLatch.getCount === 0)
+    }
+  }
+
+  test("Proper handling on onNext throw - query progress") {
+    val sessionHolder = SessionHolder.forTesting(spark)
+    val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+    doThrow(new RuntimeException("I'm dead"))
+      .when(responseObserver)
+      .onNext(any[ExecutePlanResponse]())
+
+    val listenerHolder = sessionHolder.streamingServersideListenerHolder
+    listenerHolder.init(responseObserver)
+    val listenerBusListener = listenerHolder.streamingQueryServerSideListener.get
+
+    // mock a QueryStartedEvent for cleanup test
+    val queryStartedEvent = new StreamingQueryListener.QueryStartedEvent(
+      UUID.randomUUID,
+      UUID.randomUUID,
+      "name",
+      "timestamp")
+    listenerHolder.streamingQueryStartedEventCache.put(
+      queryStartedEvent.runId.toString,
+      queryStartedEvent)
+
+    startQuery()
+
+    eventually(timeout(5.seconds), interval(500.milliseconds)) {
+      assert(!spark.streams.listListeners().contains(listenerBusListener))
+      assert(listenerHolder.streamingQueryStartedEventCache.isEmpty)
+      assert(listenerHolder.streamingQueryListenerLatch.getCount === 0)
+    }
+  }
+}

